Table of Contents
- Introducing covariate shift
- Detecting Covariate Shift: A naive approach
- When does univariate drift detection fail?
- Detecting covariate shift: a multivariate approach
- Principal Components Analysis
- Data reconstruction with PCA
- Comparing reconstruction errors
- Multivariate drift detection in Python
- Assumptions & limitations
- Data completeness
- Reconstruction error stability with no drift
- Shift (non)linearity
The ultimate purpose of any machine learning model is to bring value to its owner. Typically, this value comes in the form of the algorithm doing things better or faster (or both) than a human can. The cost of investment in developing and deploying the model is often high. In order to pay it back, the model needs to provide value in production for a long enough period of time. This can be impeded by covariate shift, a phenomenon causing the models in production to degrade over time. Let’s see how to detect it, and why the popular simple approaches are usually not good enough.
Introducing covariate shift
Covariate shift is a situation in which the distribution of the model’s input features in production changes compared to what the model has seen during training and validation.
Covariate shift is a change in the distribution of the model’s inputs between training and production data.
In most applications, it is a matter of time before covariate shift occurs. If you are modeling your customers, for instance, their behavior patterns will shift as the economy changes, as they get older, or as the customer base alters due to marketing campaigns. The only constant in life is change, as the Greek philosopher Heraclitus supposedly said.
The key to ensuring the models keep working well in production is to detect covariate shift early. How can we do this? Let’s find out!
Detecting Covariate Shift: A naive approach
We have said that covariate shift is a change in the distribution of the model's inputs. We know the distribution of each feature in the training data, and we should also be able to obtain it for the features in production. Why don't we just compare the two?
Actually, it is a valid approach. Tools such as TensorFlow Data Validation or Great Expectations allow us to easily compare distributions of the input features between training and production data, feature by feature.
The simplest way to establish whether the training and production distributions differ for a given feature is to compare their summary statistics, such as means, medians, etc.
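As a minimal sketch of this summary-statistics check, here it is with pandas on synthetic data; the feature names, distributions, and the size of the shift are made up for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 5000

# Synthetic training and production data; only "income" drifts (its mean rises).
train = pd.DataFrame({"age": rng.normal(40, 10, n), "income": rng.normal(50_000, 15_000, n)})
prod = pd.DataFrame({"age": rng.normal(40, 10, n), "income": rng.normal(60_000, 15_000, n)})

# Side-by-side summary statistics, one column per (dataset, feature) pair.
summary = pd.concat(
    {"train": train.agg(["mean", "median", "std"]),
     "prod": prod.agg(["mean", "median", "std"])},
    axis=1,
)
print(summary.round(1))

# Relative change of each feature's mean between training and production.
relative_change = (prod.mean() - train.mean()) / train.mean()
print(relative_change.round(3))
```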
A more sophisticated approach would be to directly calculate how much the two distributions differ using a (dis)similarity measure such as Earth mover’s distance or Jensen–Shannon divergence.
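Both measures are available in SciPy. Here is a sketch on synthetic data (the half-standard-deviation shift and the 50 shared histogram bins are arbitrary choices). Note that scipy.spatial.distance.jensenshannon returns the Jensen–Shannon distance, the square root of the divergence, and works on binned probability vectors rather than raw samples.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(1)
train_feature = rng.normal(0.0, 1.0, 10_000)
prod_feature = rng.normal(0.5, 1.0, 10_000)  # mean shifted by half a standard deviation

# Earth mover's (Wasserstein-1) distance works directly on the raw samples.
emd = wasserstein_distance(train_feature, prod_feature)

# Jensen-Shannon needs probability vectors, so bin both samples on shared edges.
edges = np.histogram_bin_edges(np.concatenate([train_feature, prod_feature]), bins=50)
p, _ = np.histogram(train_feature, bins=edges)
q, _ = np.histogram(prod_feature, bins=edges)
# SciPy normalizes the counts; base=2 bounds the distance between 0 and 1.
js_distance = jensenshannon(p, q, base=2)

print(f"EMD: {emd:.3f}, JS distance: {js_distance:.3f}")
```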
An even more statistically rigorous approach is to conduct a proper hypothesis test. The Kolmogorov-Smirnov test for continuous features and the Chi-squared test for categorical ones can help us establish whether the feature distributions differ in a significant manner.
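A sketch of both tests with SciPy on synthetic data; the feature names, category labels, and the 0.05 significance level are illustrative assumptions.

```python
import numpy as np
from scipy.stats import chi2_contingency, ks_2samp

rng = np.random.default_rng(2)

# Continuous feature: two-sample Kolmogorov-Smirnov test.
train_age = rng.normal(40, 10, 2000)
prod_age = rng.normal(43, 10, 2000)
ks_stat, ks_p = ks_2samp(train_age, prod_age)

# Categorical feature: chi-squared test on a (dataset x category) contingency table.
categories = np.array(["basic", "premium", "enterprise"])
train_plan = rng.choice(categories, 2000, p=[0.6, 0.3, 0.1])
prod_plan = rng.choice(categories, 2000, p=[0.4, 0.4, 0.2])
table = np.array([[np.sum(train_plan == c) for c in categories],
                  [np.sum(prod_plan == c) for c in categories]])
chi2, chi2_p, dof, _ = chi2_contingency(table)

alpha = 0.05
print(f"KS p-value: {ks_p:.2e}, drift: {ks_p < alpha}")
print(f"Chi-squared p-value: {chi2_p:.2e}, drift: {chi2_p < alpha}")
```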
The latter approach, however, has a serious drawback. Monitoring production data this way requires continuously running many tests on large data volumes. With that much data, even tiny, practically irrelevant differences become statistically significant, so many of the tests will signal significant effects that are not a reason to worry from the monitoring standpoint. In many cases, a statistically significant distribution change does not correspond to a covariate shift that degrades model performance.
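To see why large volumes are a problem, consider this sketch with made-up numbers: with half a million observations per sample, the KS test flags a mean shift of just 0.02 standard deviations as significant.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(3)
n = 500_000

# The production mean moves by only 0.02 standard deviations,
# a difference unlikely to matter for any model in practice.
train = rng.normal(0.0, 1.0, n)
prod = rng.normal(0.02, 1.0, n)

stat, p_value = ks_2samp(train, prod)
print(f"KS statistic: {stat:.4f}, p-value: {p_value:.1e}")
```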
But there is also an even more significant drawback that all of the approaches discussed above share: they are univariate, meaning that they treat each feature in isolation from the rest. In real life, we often observe a multivariate data drift, in which the distributions of the individual features don’t necessarily change, but their joint distribution does. Consider the following example.
When does univariate drift detection fail?
Imagine we have a dataset consisting of two features: age and income, and the data is collected over a number of weeks.
Both features have been standardized so that they can be neatly visualized together. The standardization used the mean and variance values from the general population, i.e., not calculated from these data, which means we don’t need to worry about any data leaks.
In the 10th week (blue), there is a pretty strong positive correlation between the two features. This makes sense: most people gain experience with age and get promoted, which results in a higher income, on average.
In the 16th week (yellow), the marginal univariate distribution of each feature is still standard normal as before, but the correlation pattern has changed dramatically; now, the two features are negatively correlated. Maybe the data from Week 16 was collected from older people. In this case, the older a person, the higher the likelihood they are retired and thus have a lower income.
Univariate drift detection methods fail when the univariate feature distributions remain unchanged while the joint distribution changes.
We are clearly experiencing a strong covariate shift; any machine learning models trained on the data from week 10 would perform very poorly on the data from week 16. However, the univariate drift detection methods we have discussed previously would not alert us to this dangerous shift.
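The example can be reproduced with synthetic data. This sketch assumes standard normal marginals and correlations of 0.8 and -0.8 (not the original data): the per-feature KS statistics stay small, while the correlation flips sign.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(4)
n = 5000


def sample_week(correlation: float) -> np.ndarray:
    """Draw standardized (age, income) pairs with standard normal marginals."""
    cov = [[1.0, correlation], [correlation, 1.0]]
    return rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=n)


week10 = sample_week(0.8)   # positive age-income correlation
week16 = sample_week(-0.8)  # negative correlation, same marginals

# Univariate checks: one KS test per feature.
for i, name in enumerate(["age", "income"]):
    stat, p = ks_2samp(week10[:, i], week16[:, i])
    print(f"{name}: KS statistic={stat:.3f}, p-value={p:.2f}")

# The joint structure, however, has flipped.
corr10 = np.corrcoef(week10, rowvar=False)[0, 1]
corr16 = np.corrcoef(week16, rowvar=False)[0, 1]
print(f"correlation week 10: {corr10:.2f}, week 16: {corr16:.2f}")
```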
Detecting covariate shift: a multivariate approach
In order to reliably detect covariate shift in any situation, we need (on top of the univariate methods) another method that can capture the changes in the joint distribution of all the features. A very simple yet clever approach to achieve this is based on the good old Principal Components Analysis (PCA).
Principal Components Analysis
Principal Components Analysis is a dimensionality reduction technique. It identifies the directions in the data space along which the data vary the most. These directions, called principal components, are ordered by the amount of variance they capture, from largest to smallest. This means that by keeping only the first few principal components, we can capture most of the variability of our data while reducing the number of dimensions. In other words, PCA projects the data onto a lower-dimensional space.
PCA maps original data into a lower-dimensional space, keeping most of the informative signal
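With scikit-learn, the share of variance captured by each component is exposed as explained_variance_ratio_. A sketch on synthetic data with two strongly correlated features and one independent feature (all distributions are arbitrary):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(5)
n = 2000

# Three features: two strongly correlated ones plus an independent, low-variance one.
shared = rng.normal(size=n)
data = np.column_stack([
    shared + 0.3 * rng.normal(size=n),
    shared + 0.3 * rng.normal(size=n),
    0.5 * rng.normal(size=n),
])

# All components, ordered by decreasing explained variance.
pca = PCA().fit(data)
print(pca.explained_variance_ratio_.round(3))

# A float n_components keeps just enough components to explain 90% of the variance.
pca_90 = PCA(n_components=0.9).fit(data)
print(pca_90.n_components_)
```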
Dimensionality reduction with PCA has many applications. It can be used as a pre-processing step before model training to reduce the number of features, as some machine learning models become very slow to train when there are too many features.
But speed is not the only gain from using PCA as a preprocessing step. In some cases, machine learning models trained on PCA-transformed data perform better. This is because when we reduce the dimensionality, we try to squeeze as much information as possible into fewer variables. This process is lossy, meaning that some information will inevitably be discarded; ideally, what gets discarded is mostly noise. PCA assumes that the useful information in a dataset is captured by its variance, so using uncorrelated principal components that capture most of the variance as features can lead to better decision boundaries.
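A sketch of PCA as a preprocessing step in a scikit-learn pipeline on a synthetic classification problem (the dataset parameters and the choice of 10 components are arbitrary). It also checks that the principal component scores are mutually uncorrelated; they are not guaranteed to be statistically independent. No claim is made here that PCA improves accuracy on this particular dataset.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=1000, n_features=50, n_informative=5,
                           n_redundant=10, random_state=0)

# PCA as a preprocessing step: 50 standardized features -> 10 principal components.
model = make_pipeline(StandardScaler(), PCA(n_components=10, svd_solver="full"),
                      LogisticRegression())
scores = cross_val_score(model, X, y, cv=5)
print(f"Mean CV accuracy: {scores.mean():.3f}")

# The component scores are mutually uncorrelated (zero off-diagonal correlations).
components = make_pipeline(StandardScaler(),
                           PCA(n_components=10, svd_solver="full")).fit_transform(X)
corr = np.corrcoef(components, rowvar=False)
off_diagonal = corr[~np.eye(10, dtype=bool)]
print(f"Largest |correlation| between components: {np.abs(off_diagonal).max():.1e}")
```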
Finally, PCA is also useful for visualizing and finding patterns in multidimensional datasets — after compressing the data to two or three features, we can easily plot them, color each data point by an attribute of interest, and spot interesting patterns.
Data reconstruction with PCA
Okay, but how does this dimensionality reduction help with drift detection, I hear you asking. Here is the idea.
The only piece of the puzzle that we’re missing is the fact that the PCA transformation can be (approximately) inverted. Once we have compressed the data to fewer features, we can use the PCA model to decompress them, that is: bring the dataset back to its original number of features. The result will not be exactly the same as the original data due to the lossy compression process, but it should be quite similar. Let’s call the difference between the original and decompressed data the reconstruction error.
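In scikit-learn, compression and decompression map to PCA.transform and PCA.inverse_transform. This sketch assumes the reconstruction error is the mean Euclidean distance between each original row and its reconstruction (one reasonable choice; the text does not pin down the exact definition). Keeping all components makes the reconstruction exact up to floating-point error.

```python
import numpy as np
from sklearn.decomposition import PCA


def reconstruction_error(pca: PCA, data: np.ndarray) -> float:
    """Mean Euclidean distance between rows and their PCA reconstructions."""
    compressed = pca.transform(data)                  # fewer columns
    decompressed = pca.inverse_transform(compressed)  # back to the original columns
    return float(np.linalg.norm(data - decompressed, axis=1).mean())


rng = np.random.default_rng(6)
cov = [[1.0, 0.8, 0.5], [0.8, 1.0, 0.4], [0.5, 0.4, 1.0]]
data = rng.multivariate_normal(np.zeros(3), cov, size=1000)

lossy = PCA(n_components=2).fit(data)     # drop one dimension
lossless = PCA(n_components=3).fit(data)  # keep every dimension

print(f"2 of 3 components: {reconstruction_error(lossy, data):.3f}")
print(f"3 of 3 components: {reconstruction_error(lossless, data):.2e}")
```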
Let’s go back to our example dataset. Recall that we compared Week 10 to Week 16 to find that a multivariate data drift has occurred in the meantime. Now, let’s focus on Weeks 9, 10, and 11, during which no covariate shift happens, that is: the joint multivariate distribution of the features is the same for all these three weeks.
If we learn the PCA mapping on Week 9, we can use it to compress and decompress both Week 10 and Week 11 to obtain similar, low reconstruction errors.
We don’t really care about the particular value of the error now. What’s important is that it is likely to be roughly the same for weeks 10 and 11, because the internal structure of the data stays the same, and the compression-decompression mapping learned by the PCA is still applicable.
Now, what would happen if we used our PCA model on Weeks 10 and 16, rather than 10 and 11? This time, the mapping learned by PCA will still hold for Week 10, but not for Week 16. Thus, the reconstruction error for Week 16 will be significantly larger than the one for Week 10.
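A synthetic sketch of this argument, assuming standardized two-feature weeks with correlation 0.8 (Weeks 9, 10, and 11) and -0.8 (Week 16), and a one-component PCA learned on Week 9 only; the helper names are hypothetical.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(7)


def sample_week(correlation: float, n: int = 2000) -> np.ndarray:
    """Standardized (age, income) pairs with the given correlation."""
    cov = [[1.0, correlation], [correlation, 1.0]]
    return rng.multivariate_normal([0.0, 0.0], cov, size=n)


def reconstruction_error(pca: PCA, data: np.ndarray) -> float:
    """Mean Euclidean distance between rows and their PCA reconstructions."""
    decompressed = pca.inverse_transform(pca.transform(data))
    return float(np.linalg.norm(data - decompressed, axis=1).mean())


week9, week10, week11 = (sample_week(0.8) for _ in range(3))  # no drift
week16 = sample_week(-0.8)                                     # correlation flipped

# Learn the compression (2 features -> 1 component) on Week 9 only.
pca = PCA(n_components=1).fit(week9)

errors = {week: reconstruction_error(pca, data)
          for week, data in [(10, week10), (11, week11), (16, week16)]}
print({week: round(err, 2) for week, err in errors.items()})
```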
The idea we have just discussed can be used to detect covariate shift in production. If we know what reconstruction error we can expect from the data without drift (think Week 10 error), we can compare it to the reconstruction error of the production data. If the latter is significantly larger, we can declare covariate shift to be present.
We can detect covariate shift by comparing the PCA reconstruction error from production data to its expected level.
The only two remaining questions are: what is the expected reconstruction error, and what does significantly larger mean?
Comparing reconstruction errors
One way to estimate the expected reconstruction error is to take a portion of the training data that we know not to exhibit covariate shift (let’s call it a reference data set), split it into chunks, and compute the reconstruction error on each chunk. This way, we get a range of error values from which we can compute the mean error and its standard deviation.
Then, a simple rule of thumb for declaring covariate shift is to check whether the reconstruction error of the production data falls outside the range obtained from the reference set: if it is more than three standard deviations away from the reference mean, a shift is very likely.
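A sketch of the reference chunking and the three-standard-deviation rule, using synthetic two-feature data; the chunk size, the number of chunks, and the helper chunk_errors are hypothetical choices.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(8)


def sample(correlation: float, n_rows: int) -> np.ndarray:
    """Standardized two-feature data with the given correlation."""
    cov = [[1.0, correlation], [correlation, 1.0]]
    return rng.multivariate_normal([0.0, 0.0], cov, size=n_rows)


def chunk_errors(pca: PCA, data: np.ndarray, chunk_size: int) -> np.ndarray:
    """Mean reconstruction error (Euclidean distance) of each consecutive chunk."""
    distances = np.linalg.norm(data - pca.inverse_transform(pca.transform(data)), axis=1)
    n_chunks = len(distances) // chunk_size  # an incomplete trailing chunk is dropped
    return distances[: n_chunks * chunk_size].reshape(n_chunks, chunk_size).mean(axis=1)


chunk_size = 500
reference = sample(0.8, 20 * chunk_size)  # assumed to be free of covariate shift
pca = PCA(n_components=1).fit(reference)

# Expected range: reference mean error plus or minus three standard deviations.
ref_errors = chunk_errors(pca, reference, chunk_size)
lower = ref_errors.mean() - 3 * ref_errors.std()
upper = ref_errors.mean() + 3 * ref_errors.std()

# Production data: five drift-free chunks followed by five drifted ones.
production = np.vstack([sample(0.8, 5 * chunk_size), sample(-0.8, 5 * chunk_size)])
prod_errors = chunk_errors(pca, production, chunk_size)
alerts = (prod_errors < lower) | (prod_errors > upper)
print(np.round(prod_errors, 3), alerts)
```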
Multivariate drift detection in Python
Let’s now take a look at how to implement all of these ideas in Python using the NannyML package.
We will be using Yandex’s weather dataset. Multivariate data drift is inherently present in weather measurements; as the seasons come and go, the way in which different climate components interact with each other tends to evolve, too. The dataset consists of over a hundred features describing various meteorological measures such as temperature, humidity, atmospheric pressure, etc.
First, let’s load and prepare the data and split it into two disjoint sets: September data (our reference set, on which PCA will be trained) and October data (our analysis set, which we will test against covariate shift). We will also need to parse the timestamp column to a proper pandas datetime format.
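The dataset’s file location and timestamp format aren’t given here, so this sketch uses a small synthetic stand-in: the column names timestamp, temperature, and humidity, the string timestamp format, and the 2019 dates are all assumptions. The parsing and the month-based split are the parts that carry over.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in: one row per day with a string timestamp and two made-up
# measurement columns (the real dataset has over a hundred features).
days = pd.date_range("2019-09-01", "2019-10-31", freq="D")
rng = np.random.default_rng(9)
raw = pd.DataFrame({
    "timestamp": days.strftime("%Y-%m-%d %H:%M:%S"),
    "temperature": rng.normal(12, 5, len(days)),
    "humidity": rng.uniform(30, 100, len(days)),
})

# Parse the timestamp strings into a proper pandas datetime column.
raw["timestamp"] = pd.to_datetime(raw["timestamp"])

# Two disjoint sets: September (reference) and October (analysis).
month = raw["timestamp"].dt.month
reference = raw[month == 9].reset_index(drop=True)
analysis = raw[month == 10].reset_index(drop=True)
print(len(reference), len(analysis))
```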
Next, we fit NannyML’s drift calculator, which does all the steps described above. We need to pass it the feature names to check (all but the timestamp column), the timestamp column name, and the chunk size based on which the expected reconstruction error will be estimated. Then, we simply fit it to the reference data and apply it to the analysis data to test for covariate shift.
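Rather than guess the exact NannyML call, here is a from-scratch sketch of the steps the calculator is described as performing: learn the PCA mapping on the reference data, compute the mean reconstruction error per chunk, and flag chunks outside the reference mean plus or minus three standard deviations. The class name ReconstructionDriftSketch, its parameters (loosely modeled on the arguments listed above), the standardization step, and keeping enough components to explain 65% of the variance are all assumptions for illustration, not NannyML’s implementation.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


class ReconstructionDriftSketch:
    """Hypothetical stand-in for a PCA reconstruction-error drift calculator.

    Assumes rows are already sorted by timestamp.
    """

    def __init__(self, column_names, timestamp_column_name, chunk_size,
                 n_components=0.65, n_std=3.0):
        self.column_names = list(column_names)
        self.timestamp_column_name = timestamp_column_name
        self.chunk_size = chunk_size
        self.n_std = n_std
        self.scaler = StandardScaler()
        self.pca = PCA(n_components=n_components)

    def _chunk_errors(self, data: pd.DataFrame) -> pd.DataFrame:
        """Mean Euclidean reconstruction error per chunk of `chunk_size` rows."""
        X = self.scaler.transform(data[self.column_names])
        distances = np.linalg.norm(X - self.pca.inverse_transform(self.pca.transform(X)), axis=1)
        rows = []
        for start in range(0, len(data) - self.chunk_size + 1, self.chunk_size):
            stop = start + self.chunk_size  # an incomplete trailing chunk is skipped
            rows.append({
                "chunk_start": data[self.timestamp_column_name].iloc[start],
                "reconstruction_error": distances[start:stop].mean(),
            })
        return pd.DataFrame(rows)

    def fit(self, reference: pd.DataFrame):
        # Standardize, then learn the PCA mapping on the reference data only.
        self.pca.fit(self.scaler.fit_transform(reference[self.column_names]))
        errors = self._chunk_errors(reference)["reconstruction_error"]
        self.lower_ = errors.mean() - self.n_std * errors.std()
        self.upper_ = errors.mean() + self.n_std * errors.std()
        return self

    def calculate(self, analysis: pd.DataFrame) -> pd.DataFrame:
        results = self._chunk_errors(analysis)
        results["lower_threshold"] = self.lower_
        results["upper_threshold"] = self.upper_
        error = results["reconstruction_error"]
        results["alert"] = (error < self.lower_) | (error > self.upper_)
        return results


rng = np.random.default_rng(10)
features = [f"f{i}" for i in range(6)]


def make_frame(start, n_rows, flip):
    """Hypothetical data: six features driven by one latent factor plus noise."""
    loadings = np.linspace(0.5, 1.0, 6)
    if flip:
        loadings[3:] *= -1  # same marginals, different correlation structure
    values = rng.normal(size=(n_rows, 1)) * loadings + 0.5 * rng.normal(size=(n_rows, 6))
    frame = pd.DataFrame(values, columns=features)
    frame.insert(0, "timestamp", pd.date_range(start, periods=n_rows, freq="min"))
    return frame


reference = make_frame("2019-09-01", 5000, flip=False)
analysis = pd.concat([make_frame("2019-10-01", 2000, flip=False),
                      make_frame("2019-10-15", 2000, flip=True)], ignore_index=True)

calc = ReconstructionDriftSketch(features, "timestamp", chunk_size=500)
results = calc.fit(reference).calculate(analysis)
print(results[["chunk_start", "reconstruction_error", "alert"]])
```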
The easiest way to make use of the generated results object is to render it as a plot.
The two dashed lines denote the expected interval for reconstruction error values, estimated from September data. As you can see, there is one anomalous observation towards the end of this month — perhaps a particularly strong thunderstorm or some other impactful event.
In the first days of October, there is no covariate shift; it seems the month started with weather similar to that of the last week of September. But then autumn hits, and the rest of the October data shows a strong covariate shift.
Although clear and visually appealing, the plot above is not a convenient way of detecting covariate shift as part of a pipeline with automated alerting. But it’s enough to call the .to_df() method on the results object to get a data frame with all the data behind the plot. It contains the alert column with a boolean flag denoting whether each chunk’s reconstruction error was out of the expected bounds. It’s a great way to add covariate shift detection to any existing data validation checks!
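In a pipeline, the boolean alert column can gate downstream steps. A minimal sketch, assuming a results frame with a flat column named alert (the exact column layout returned by .to_df() may differ between NannyML versions, so check it before relying on that name); the exception class and helper functions are hypothetical.

```python
import pandas as pd


class CovariateShiftAlert(Exception):
    """Hypothetical exception used to stop a pipeline when drift is flagged."""


def flagged_chunks(results: pd.DataFrame, alert_column: str = "alert") -> pd.DataFrame:
    """Rows of a results frame whose boolean alert flag is set."""
    return results[results[alert_column].astype(bool)]


def check_reconstruction_drift(results: pd.DataFrame, alert_column: str = "alert") -> None:
    """Raise CovariateShiftAlert if any chunk is outside the expected bounds."""
    flagged = flagged_chunks(results, alert_column)
    if not flagged.empty:
        raise CovariateShiftAlert(
            f"{len(flagged)} of {len(results)} chunks outside the expected reconstruction-error range"
        )


# Hypothetical results frame with a flat boolean "alert" column.
results = pd.DataFrame({"chunk_start": ["2019-10-01", "2019-10-08", "2019-10-15"],
                        "alert": [False, False, True]})
try:
    check_reconstruction_drift(results)
except CovariateShiftAlert as err:
    print(f"Blocking the pipeline: {err}")
```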
Assumptions & limitations
No discussion of a statistical method would be complete without delving deeper into the assumptions it makes. This will give us some indication of when the method works and when it doesn’t, and what its limitations are. Data drift detection based on the change in the PCA reconstruction error is no different.
Data completeness
First and foremost, our algorithm cannot handle missing values, so any missing values in the dataset must be dealt with first. This requirement is inherited from the underlying PCA, which needs all values to be observed in order to find the principal components.
The obvious remedy is to impute the missing values, if there are any, before running drift detection. Actually, the DataReconstructionDriftCalculator class we have used before accepts two arguments, imputer_categorical and imputer_continuous, through which we can specify the method for filling in the missing data for each of the two data types. We just need to pass an instance of scikit-learn’s SimpleImputer to each of them.
There is, however, a risk associated with this approach. Scikit-learn’s SimpleImputer is, well, simple. As of the time of writing, it can only perform four simple imputation strategies: it replaces the missing values with the mean, the median, or the mode of the other values in the column, or simply with a specified constant value.
If you have taken my DataCamp course on data imputation or read my blog post on the topic, then you know that these imputation strategies are not the best way to go. Mean, median, mode, and constant imputations all create the same two problems: they reduce the variance of the imputed variables and destroy the correlations between such variables and the rest of the data. This can hurt us twice: first, when we run PCA, a method based on the data variance, and second, when we hunt for a shift in the joint data distribution, which could have been affected by the imputed values.
Our drift detector doesn’t allow missing values. If present, they should be imputed beforehand, preferably with model-based methods.
Hence, it might be a better solution to employ one of the model-based imputation methods such as MICE before running drift detection.
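A sketch of the difference on synthetic data, comparing SimpleImputer with scikit-learn’s IterativeImputer, a MICE-inspired imputer (a single IterativeImputer run produces one imputed dataset, whereas MICE proper produces several). The correlation of 0.8 and the 30% missing-completely-at-random rate are illustrative assumptions. Mean imputation shrinks the variance and weakens the correlation; iterative imputation with posterior sampling largely preserves both.

```python
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401 (enables IterativeImputer)
from sklearn.impute import IterativeImputer, SimpleImputer

rng = np.random.default_rng(11)
n = 2000
data = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.8], [0.8, 1.0]], size=n)

# Make 30% of the second feature missing completely at random.
with_missing = data.copy()
missing = rng.random(n) < 0.3
with_missing[missing, 1] = np.nan

mean_imputed = SimpleImputer(strategy="mean").fit_transform(with_missing)
# sample_posterior=True draws imputations from the predictive distribution
# instead of plugging in point predictions.
iter_imputed = IterativeImputer(sample_posterior=True, random_state=0).fit_transform(with_missing)

for name, filled in [("complete", data), ("mean", mean_imputed), ("iterative", iter_imputed)]:
    variance = filled[:, 1].var()
    corr = np.corrcoef(filled, rowvar=False)[0, 1]
    print(f"{name:>9}: variance={variance:.2f}, correlation={corr:.2f}")
```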
Reconstruction error stability with no drift
Another important assumption made by our drift detector is that reconstruction error is stable in time in the absence of covariate shift. Let’s try to unpack this statement.
Let’s revisit the algorithm step by step. We started by learning the PCA mapping on the data from Week 9. Then, we said that we can use this mapping to compress and decompress data from the subsequent Weeks 10 and 11 (both with no drift), and that the resulting reconstruction errors will be low and similar in both cases. One thing we did not mention earlier is that the error computed on the same data that was used for learning the PCA mapping would be even lower. Finally, when we do this for Week 16, after the data has drifted, we find a higher reconstruction error.
The assumption of reconstruction error stability refers to the fact that while the error increases from Week 9 to Week 10, it remains stable in the following weeks, up until covariate shift happens. Let’s see why this assumption makes sense.
When we fit PCA to Week 9 data, we look for such directions in the multidimensional data space that capture the largest amount of variance. We can then use these principal components on the same data to obtain some value of the reconstruction error, let’s call it RE9.
When we use the same principal components to get the reconstruction error for Week 10, however, we will find RE10 to be greater than RE9. This is because the Week 9 components need not be optimal (i.e., capturing the most variance) for Week 10 data which features a different noise pattern. Such an effect is known as regression to the mean. This is analogous to how machine learning models typically perform better on the training set than on the test set.
The crucial assumption is that we expect RE10 ≈ RE11, so the error from Week 11 should be close to the one from Week 10. This is what allows us to spot the drift; we assume the error to remain stable for Weeks 10, 11, 12, etc., up until a point when it goes up, which would alert us to a possible drift.
Our drift detector assumes that in the absence of covariate shift, reconstruction errors will be stable over time (except for the period used for learning the PCA mapping).
The justification behind this assumption is that while RE9 is underestimated through overfitting to the particular Week 9 data, RE10 is the expected reconstruction error that we anticipate seeing for each new data sample (provided its joint distribution stays unchanged).
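The in-sample optimism is easiest to see with many features and few rows per week. This sketch assumes 30 features driven by three latent factors, 60 rows per week, and a 10-component PCA; all numbers are arbitrary. RE9 comes out lower than RE10 and RE11, which are close to each other.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(12)
n_features, n_rows = 30, 60
loadings = rng.normal(size=(3, n_features))  # three latent factors, fixed across weeks


def sample_week() -> np.ndarray:
    """One week of data from the same (drift-free) joint distribution."""
    factors = rng.normal(size=(n_rows, 3))
    return factors @ loadings + rng.normal(size=(n_rows, n_features))


def reconstruction_error(pca: PCA, data: np.ndarray) -> float:
    """Mean Euclidean distance between rows and their PCA reconstructions."""
    decompressed = pca.inverse_transform(pca.transform(data))
    return float(np.linalg.norm(data - decompressed, axis=1).mean())


week9, week10, week11 = sample_week(), sample_week(), sample_week()
pca = PCA(n_components=10).fit(week9)

re9 = reconstruction_error(pca, week9)    # in-sample: optimistic
re10 = reconstruction_error(pca, week10)  # out-of-sample
re11 = reconstruction_error(pca, week11)  # out-of-sample, same distribution
print(f"RE9={re9:.2f}, RE10={re10:.2f}, RE11={re11:.2f}")
```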
Shift (non)linearity
Arguably the most important limitation of our approach is in the type of shift that the PCA reconstruction error can and cannot capture. Namely, covariate shift detection as we have done it will not capture drift when there is a non-linear transformation that preserves the correlation between features while also preserving the mean and standard deviation of each feature.
PCA-based drift detector won’t work in case of a non-linear drift that preserves each feature’s mean and standard deviation, and the correlations between different features.
Let’s illustrate this with a simple example. Imagine we have a reference dataset consisting of just two features, x and y. Working with only two dimensions allows us to easily visualize the relationship between them.
There clearly exists a non-random relationship between our two features. Now, imagine that at some point, this relationship changes to the following one.
The range of possible values of y increased, while there is a low-density region in the middle of the joint distribution where no data points are to be found.
Note how the linear correlation coefficient is almost the same for this new dataset and the reference dataset above. If we were to run our drift detector on these data, the results would be as follows:
Drift detection for a dataset with non-linear covariate shift that cannot be captured by PCA.
The detector failed to capture the shift that we know has happened: the purple error line is well within the expected bounds. This is because the drift left the features’ means and standard deviations, as well as the correlation between the features, intact.
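This blind spot can be reproduced synthetically. The sketch below (a stand-in for the examples above, not the Datasaurus data) builds a two-cluster dataset and linearly rescales it so that its means, standard deviations, and correlation match a Gaussian reference exactly. It uses the mean squared reconstruction error, which depends only on the means and covariance, so the two errors coincide; with another error definition, such as the mean Euclidean distance, they would be close but not necessarily identical. A univariate KS test on one feature, by contrast, notices the change in shape.

```python
import numpy as np
from scipy.stats import ks_2samp
from sklearn.decomposition import PCA

rng = np.random.default_rng(13)
n = 2000

# Reference: a single Gaussian blob with correlation 0.7.
reference = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.7], [0.7, 1.0]], size=n)

# Raw "drifted" data: two well-separated clusters, a very different shape.
centers = np.where(rng.random(n)[:, None] < 0.5, [-3.0, -1.5], [3.0, 1.5])
raw = centers + 0.5 * rng.normal(size=(n, 2))

# Whiten the raw data, then recolor it with the reference mean and covariance.
L_raw = np.linalg.cholesky(np.cov(raw, rowvar=False))
L_ref = np.linalg.cholesky(np.cov(reference, rowvar=False))
whitened = (raw - raw.mean(axis=0)) @ np.linalg.inv(L_raw).T
analysis = whitened @ L_ref.T + reference.mean(axis=0)


def mean_squared_reconstruction_error(pca: PCA, data: np.ndarray) -> float:
    """Mean over rows of the squared Euclidean reconstruction error."""
    decompressed = pca.inverse_transform(pca.transform(data))
    return float(((data - decompressed) ** 2).sum(axis=1).mean())


pca = PCA(n_components=1).fit(reference)
print(f"reference MSE: {mean_squared_reconstruction_error(pca, reference):.4f}")
print(f"analysis  MSE: {mean_squared_reconstruction_error(pca, analysis):.4f}")

# A univariate test on the first feature does see the change in shape.
print(f"KS p-value (first feature): {ks_2samp(reference[:, 0], analysis[:, 0]).pvalue:.1e}")
```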
We can bring this argument to the extreme. Imagine our analysis dataset takes a dino shape (all data discussed in this section comes from the Datasaurus set: a collection of x-y pairs with the same linear correlation and basic summary statistics, but very different joint distributions).
As long as the linear correlation and the features’ mean and standard deviation continue to be the same, the PCA-based drift detector won’t work.
Drift detection for a dataset with non-linear covariate shift that cannot be captured by PCA.
In our two-dimensional example, we can easily observe the joint distribution of the data. In most real-life applications, however, we work with more features, which makes it harder to tell whether the shift was linear or not.
If you don’t have too many features and detect no multivariate drift, I suggest looking at their summary statistics and at the linear correlation between each pair of features. If these are similar for your reference and analysis data, that’s a sign the PCA-based drift detection might not be reliable, because any remaining shift would be of the kind it cannot capture.
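A sketch of that check with pandas; the helper moment_differences and the synthetic data are hypothetical. Here the analysis data are uniform rather than Gaussian, yet their means, standard deviations, and correlation nearly match the reference, which is exactly the situation in which a quiet PCA-based detector should not be taken as an all-clear.

```python
import numpy as np
import pandas as pd


def moment_differences(reference: pd.DataFrame, analysis: pd.DataFrame) -> dict:
    """Largest absolute differences in means, standard deviations and pairwise correlations."""
    return {
        "mean": float((reference.mean() - analysis.mean()).abs().max()),
        "std": float((reference.std() - analysis.std()).abs().max()),
        "correlation": float((reference.corr() - analysis.corr()).abs().to_numpy().max()),
    }


rng = np.random.default_rng(14)
cov = [[1.0, 0.6], [0.6, 1.0]]
reference = pd.DataFrame(rng.multivariate_normal([0.0, 0.0], cov, 3000), columns=["x", "y"])

# Analysis data: uniform instead of Gaussian, but with (approximately) the same
# means (0), standard deviations (1) and correlation (0.6) as the reference.
u = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(3000, 2))  # each column: mean 0, std 1
analysis = pd.DataFrame({"x": u[:, 0], "y": 0.6 * u[:, 0] + 0.8 * u[:, 1]})

print(moment_differences(reference, analysis))
```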
While covariate shift detection through data reconstruction with PCA cannot detect some cases of shift that we have just discussed, we can detect them with univariate drift detection methods. Stay tuned for a separate article about them!