Testing machine learning models with testthat


Automated testing is a huge part of software development. Once a project reaches a certain level of complexity, the only way that it can be maintained is if it has a set of tests that identify the main functionality and allow you to verify that functionality is intact. Without tests, it’s difficult or impossible to identify where errors are occurring, and to fix those errors without causing further problems.

  1. Tests codify your expectations at the point when you actually remember what you’re expecting
  2. They allow you to offload verification to a machine
  3. You can make changes to your code with confidence

Data science projects tend to be under-tested, which is unfortunate because they have all of the same complexity and maintainability issues as software projects. For instance, if the packages you use change, the data changes, or you go back and make some changes yourself, you run the risk of breaking your analysis in ways that can be very difficult to detect. There’s been a fair amount of excellent writing on testing functionality and testing data, but one of the most valuable types of data science testing is to test the statistical model itself.

Statistical models are often developed under rigorous observation but then deployed into a production system where they might not be monitored as closely. During the development process a data scientist will hopefully spend a lot of time interrogating the model to make sure that it’s not over-fitting, making use of bad data, or misbehaving in some other manner - but once the model is deployed they probably move on to other things and might not worry too much about whether the model is continuing to function as they expected. This leaves the model open to two big classes of bugs:

  1. The model’s inputs change. This might happen if the data collection process changes, the modelling function itself changes, or you change something about how the data is processed before it’s fed into the model. Sometimes this will result in dramatic failures that are easy to identify, such as noticeably bad predictions that users complain about. But it might also produce subtle problems that you can’t easily detect: the model might just produce slightly worse predictions, or do so only for a subset of users.

  2. The model is refit on invalid data. One of the great features of modern machine learning techniques is that they can continue to learn from new data. Models can be fit on a historical dataset, but then updated on live data. This allows the model to make use of the largest set of information, but leaves it open to bias and over-fitting if the new data has problems that you couldn’t detect in the historical data. An extreme example might be if a bug in the data collection process resulted in a single user’s record being repeated multiple times. If the model is refit on that data, it would begin to fit to the peculiarities of that user and perform poorly on other records.
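The repeated-record example above can be turned into a check that runs before any refit. This is only a minimal sketch: the `live_data` frame, its column names, the helper names, and the 5% cutoff are all made up for illustration, and a real pipeline would need a definition of "duplicate" that suits its own data.

```r
library(testthat)

# Hypothetical live data: `user_id` and `amount` are made-up column names.
# A collection bug has repeated user 3's record several times.
live_data <- data.frame(
  user_id = c(1, 2, 3, 3, 3, 3),
  amount  = c(10, 25, 40, 40, 40, 40)
)

# Share of rows that are exact copies of an earlier row
duplicate_share <- function(df) {
  mean(duplicated(df))
}

# Only allow a refit when few rows are duplicates.
# The 5% cutoff is an arbitrary placeholder, not a recommendation.
safe_to_refit <- function(df, max_duplicate_share = 0.05) {
  duplicate_share(df) <= max_duplicate_share
}

test_that("de-duplicated data passes the check", {
  expect_true(safe_to_refit(unique(live_data)))
})

test_that("heavily duplicated data is flagged", {
  expect_false(safe_to_refit(live_data))
})
```

Running a check like this immediately before the refit step means the problem surfaces as a failed test rather than as a quietly degraded model.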

What do tests offer?

The main benefit of testing software is that it helps clarify what exactly you expect from your code. Whenever we write a line of code we have some fuzzy sense of what that code should do but that sense isn’t super precise. We might look at the result of a function and say “that looks a bit fishy” or “that looks realistic” without stopping to clarify why it looks fishy or why it’s realistic. When you go to write a test you have to be more specific, which helps to make your thought process more precise. For instance, we might know that a person’s height can’t be too small, but in order to write a test you need to specify what “too small” means.
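As a small illustration of the height example, here is roughly what writing that expectation down might look like. The data and both bounds are assumptions chosen for the sketch; the point is only that "too small" has to become a number.

```r
library(testthat)

# Hypothetical adult heights in centimetres
heights_cm <- c(152, 168, 175, 181, 190)

# Assumed plausibility bounds; picking them is the useful part of the exercise
min_height_cm <- 50
max_height_cm <- 275

test_that("heights are physically plausible", {
  expect_true(all(heights_cm > min_height_cm))
  expect_true(all(heights_cm < max_height_cm))
})
```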

The same is true for modelling. There are lots of ways to get a sense of whether your model is fishy or not during the model development process. You might check how the model performs on new data, or check the collinearity of its features. You could check whether the model performs well on easy cases, or make sure that the predictions aren’t too good to be true. The key point about testing models is that you should clarify and write down your expectations of the model. How should it perform on these different metrics? When should you start worrying about it?
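One of the checks mentioned above, collinearity, could be written down along these lines. This is a rough sketch on simulated predictors: `max_abs_correlation` is a hypothetical helper, the 0.9 cutoff is arbitrary, and pairwise correlation is only a crude proxy for collinearity (it will not catch a feature that is a combination of several others).

```r
library(testthat)

set.seed(42)
# Simulated predictors; x3 is almost an exact copy of x1
x1 <- rnorm(200)
x2 <- rnorm(200)
x3 <- x1 + rnorm(200, sd = 0.01)

# Largest absolute pairwise correlation among the columns of a numeric data frame
max_abs_correlation <- function(df) {
  cors <- cor(df)
  max(abs(cors[upper.tri(cors)]))
}

test_that("unrelated predictors pass the collinearity check", {
  expect_lt(max_abs_correlation(data.frame(x1, x2)), 0.9)
})

test_that("near-duplicate predictors are flagged", {
  expect_gt(max_abs_correlation(data.frame(x1, x3)), 0.9)
})
```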

Questions to ask about your model

The answers to these questions are domain- and technique-specific. An analyst will ask very different questions depending on what they are modelling, the particular techniques they are using, and their background. Here are a few examples of tests, to hopefully get you started testing statistical models.

Let’s start with a simple model of house prices from Kaggle.

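library(tidyverse)
library(broom)
library(testthat)

# Load the data and fit a simple linear model of sale price
house <- read_csv("house_prices.csv")
model <- lm(SalePrice ~ LotArea + Neighborhood, house)
house$pred <- predict(model)

# Root mean squared error between predictions and observed values
rmse <- function(predicted, observed) {
  sqrt(mean((predicted - observed)^2))
}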

Is it performing well enough?

The most basic question you can ask about a model is whether it’s continuing to perform well as new data comes in. Typically when you develop a model for production you have made some kind of promise to your company or customer about how that model performs. For instance, it might be worth predicting a customer’s next purchase if the model is 50% accurate, but not worth doing so if it is only 2% accurate. A model can perform well on the training data, but if customer behavior changes in some way its accuracy can degrade. Instead of being notified of these changes by angry product managers, you can write a test that tells you when the model accuracy falls below a certain threshold. In this case we have a root mean squared error of about $51,891, so let’s say we want to worry about the model when that error gets bigger than $75,000. I’m going to generate new data by sampling my existing dataset, but in production you should use actual new data. The test would look something like this:

# Stand-in for new data: a bootstrap sample of the existing rows
new_data <- house[sample(nrow(house), replace = TRUE), ]
new_data$pred <- predict(model, newdata = new_data)

test_that("model rmse is below threshold", {
  expect_lt(rmse(new_data$pred, new_data$SalePrice), 75000)
})

It’s often helpful to check the other side of model performance as well to ensure that the model isn’t performing unnaturally well. This can happen in real life systems as users learn to game machine learning models; for example, if you have a model of student performance that rewards teachers based on how their students are doing, over time student and teacher behavior might start to shift to accord with the model. A good place to start with these kinds of tests is the intuition that your model performance should probably be worse in the wild than it was when you first developed it. If it starts improving you probably want to do some investigation to understand why.

test_that("model rmse isn't suspiciously low", {
  expect_gt(rmse(new_data$pred, new_data$SalePrice), 20000)
})

This approach is flexible in that you can test all of the model performance metrics which you used to develop your model. In fact you can write these tests out before you develop the model and use them to test the model as you develop it. This can be a helpful way to prevent yourself from cherry picking measures of performance that look good for your particular data.
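To make the "write the tests first" idea concrete, here is a hedged sketch in which the metric helpers and their targets exist before any model does. The `mae` and `r_squared` helpers, the toy vectors, and both thresholds are assumptions for illustration; in practice the vectors would be real predictions and sale prices.

```r
library(testthat)

# Metric helpers, defined before any modelling starts
mae <- function(predicted, observed) {
  mean(abs(predicted - observed))
}
r_squared <- function(predicted, observed) {
  1 - sum((observed - predicted)^2) / sum((observed - mean(observed))^2)
}

# Toy values standing in for real predictions and observed sale prices
observed  <- c(100000, 150000, 200000, 250000)
predicted <- c(110000, 140000, 210000, 240000)

# Placeholder targets, agreed on up front rather than after seeing results
test_that("model meets the agreed performance targets", {
  expect_lt(mae(predicted, observed), 20000)
  expect_gt(r_squared(predicted, observed), 0.8)
})
```

Fixing the metrics and thresholds in a test file before fitting anything makes it harder to quietly swap in whichever measure happens to flatter the final model.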

Are the predictions within a sensible range?

A common form of sanity check for a model is to generate some predictions and check that they are within a sensible range. For instance we might want to check that we’re not predicting negative house prices, or house prices which are extremely large.

test_that("model predictions are sensible", {
  expect_true(min(new_data$pred) > 0)
  expect_true(max(new_data$pred) < 2 * max(house$SalePrice))
})

Depending on your domain you might also want to check if the model is predicting outcomes that weren’t really part of the training set. For instance our model didn’t see very small or very large house prices so we might have some issues with those predictions. These tests check that most of the model’s predictions are within the range of our training data.

test_that("model predictions are within training range", {
  # Share of predictions below the smallest / above the largest training price
  small_predictions <- mean(new_data$pred < min(house$SalePrice))
  large_predictions <- mean(new_data$pred > max(house$SalePrice))
  expect_true(small_predictions < 0.01)
  expect_true(large_predictions < 0.01)
})

How does the refit model look?

Refitting models automatically in production is a great way of continuing to learn from new data, but you lose the human oversight of the model which can lead to over-fitting. In addition to the tests listed above you can do some basic checks on the model itself to make sure that it’s being generated properly.

library(broom)
new_model <- lm(SalePrice ~ LotArea + Neighborhood, new_data)

# Did this actually generate a model, and does it explain enough variance?
expect_s3_class(new_model, "lm")
expect_true(glance(new_model)$r.squared > 0.5)

# Check for outliers among the Neighborhood dummy variables
max_neighborhood_beta <- new_model %>%
  tidy() %>%
  filter(str_detect(term, "Neighborhood")) %>%
  pull(estimate) %>%
  abs() %>%
  max()
expect_true(max_neighborhood_beta < 2e5)

Again you can substitute whatever you want for measures of model performance. The important thing is that you clarify what you are expecting and have your computer check that those expectations are met every time the model is re-fit.
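One way to have the computer enforce those expectations on every refit is to wrap them in a gate that decides whether a candidate model replaces the deployed one. The sketch below uses only base R and simulated data; `refit_is_acceptable`, the toy data frame, and the 0.5 cutoff are hypothetical and stand in for whatever checks matter for your model.

```r
# Hypothetical gate: TRUE only if the refit model passes some basic checks
refit_is_acceptable <- function(candidate, min_r_squared = 0.5) {
  inherits(candidate, "lm") &&
    all(is.finite(coef(candidate))) &&
    summary(candidate)$r.squared > min_r_squared
}

set.seed(1)
# Simulated data: y depends strongly on x, noise does not depend on x at all
toy <- data.frame(x = 1:50)
toy$y <- 3 * toy$x + rnorm(50)
toy$noise <- rnorm(50)

good_fit <- lm(y ~ x, data = toy)      # should pass the gate
bad_fit  <- lm(noise ~ x, data = toy)  # should be rejected

# Only promote a candidate that passes; otherwise keep the current model
deployed_model <- good_fit
candidate <- bad_fit
if (refit_is_acceptable(candidate)) {
  deployed_model <- candidate
}
```

Here the rejected candidate never reaches production, and the previously deployed model stays in place until a refit passes the checks.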

Is it making illegal or unethical decisions?

Testing is a great way to check whether your model is making distinctions that it shouldn’t. I’m a lawyer by training so I’m aware of some of the problems companies can run into when they start making illegal distinctions between customers. For instance, if you implement a model that sentences people differently based on their race or gender, your company runs the risk of legal sanction. These sanctions can be ruinous for a company, so it’s definitely worth writing a test about.

A good way of evaluating whether your model is making these distinctions is to sample from various ethnic or gender groups in your data, run those samples through your data pipeline, and check that the predictions are within the same range. If they are not the same, it’s worth worrying about whether your model is discriminatory. In Canada there’s something called “adverse effect discrimination,” which means you can be liable for a model that discriminates against a particular group even if you don’t include information about group membership in your model. If you consult with a lawyer about what kind of distinctions are allowed, you can then build tests to ensure that you and future data scientists are not stepping over those lines.

To show how this test would work, let’s generate racial categories for this data.

# Randomly assign a made-up race label to each seller (illustration only)
house$seller_race <- sample(c("Aboriginal", "Black", "White"), nrow(house), TRUE)

# Mean predicted price per group, spread into one column per group
sub_groups <- house %>%
  group_by(seller_race) %>%
  nest() %>%
  mutate(preds = map(data, ~predict(model, .))) %>%
  mutate(mean_predicted = round(map_dbl(preds, mean))) %>%
  select(seller_race, mean_predicted) %>%
  ungroup() %>%
  spread(seller_race, mean_predicted)

# Note: these are one-sided comparisons of group means,
# not a statistical significance test
test_that("No significant difference between racial groups", {
  expect_false(sub_groups$Aboriginal[1] < sub_groups$White[1])
  expect_false(sub_groups$Black[1] < sub_groups$White[1])
})
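The idea of checking that predictions for each group fall "within the same range" could also be expressed with an explicit tolerance. This is a rough sketch on made-up numbers: the group labels, the `max_relative_gap` helper, and the 5% tolerance are all assumptions, and the threshold that actually matters legally is something to settle with a lawyer, not something this code can tell you.

```r
library(testthat)

# Hypothetical predictions with a made-up group label
preds <- data.frame(
  group     = c("A", "A", "B", "B", "C", "C"),
  predicted = c(100, 110, 104, 108, 101, 107)
)

# Largest relative gap between any group's mean prediction and the overall mean
max_relative_gap <- function(df) {
  overall_mean <- mean(df$predicted)
  group_means <- tapply(df$predicted, df$group, mean)
  max(abs(group_means - overall_mean) / overall_mean)
}

# The 5% tolerance is a placeholder, not a legal standard
test_that("group mean predictions are within 5% of the overall mean", {
  expect_lt(max_relative_gap(preds), 0.05)
})
```

Comparing means is a blunt instrument, so a check like this is a starting point for investigation, not proof that a model is fair.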

My guess is that a large proportion of data science models would fail checks like this because so many of the datasets on which those models are based are biased in one way or another. The result is that even if a model doesn’t explicitly rely on race or gender by including those features, it might still be systematically biased against minority groups. This should make everybody worried as these models get embedded into more aspects of our lives. For instance, a model which generates home loan interest rates might charge minority groups higher rates, or a sentencing model might prevent minority groups from qualifying for bail. These models are likely illegal, and certainly unethical, so it’s worth creating tests which check that you are not accidentally making those kinds of distinctions.
