A Quick Guide To Using Regression With Image Data In Fastai

Bilal Tahir
3 min read · Nov 24, 2018

I was working on a project with images that required me to use regression instead of classification. In this case, the labels were the ages of the subjects in the photos and, rather than classifying them, I wanted to use regression so that the difference between the predicted and true age could drive the loss function.

This makes sense if you think about it: with classification, an output of 32 against a ground truth age of 31 would be penalized just as heavily as a prediction of 99. We want the model to know that the closer it gets to the true age, the better.
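
To make this concrete, here is a tiny Pytorch sketch (not part of the app, just an illustration) of how mean squared error treats a near miss versus a wild guess against a true age of 31:

import torch
import torch.nn as nn

target = torch.tensor([31.0])
mse = nn.MSELoss()

print(mse(torch.tensor([32.0]), target))  # tensor(1.)    - one year off
print(mse(torch.tensor([99.0]), target))  # tensor(4624.) - 68 years off, penalized far more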

More details about my particular app can be found here: https://hackernoon.com/building-an-age-predictor-web-app-using-deep-learning-25f0190ea18f

During my project, I found it a little frustrating that there seemed to be no straightforward way to do this (at least none that I knew of at the time), and there was next to no guidance on the topic.

So here I am laying out the steps you need to take to switch your image classifier to a regressor. Hopefully it helps if you are working on a similar project!

Let us assume you have pre-processed your data into images with nice labels you want to regress on. Up to this point, the process is the same as for any other project.
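
In my case each label is a float (an age). The get_float_labels function used in the snippets below is whatever maps an image to its label; for example, if the age were encoded at the start of the filename (just an assumption for illustration, adapt it to however your labels are stored), it might look like this:

from pathlib import Path

# Hypothetical labelling function: assumes filenames like '31_0001.jpg',
# where the part before the first underscore is the age
def get_float_labels(fname):
    return float(Path(fname).name.split('_')[0])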

Setting Up The Data Bunch

The main difference is in the databunch. You need to tell fastai that the dataset you are going to feed the model is not a classification dataset. A databunch for an image classifier might look like this (using the data block API):

from fastai.vision import *   # brings in ImageItemList, get_transforms, imagenet_stats, etc.

data = (ImageItemList.from_folder(path_img)            # grab the images from path_img
        .random_split_by_pct()                         # hold out a random 20% for validation
        .label_from_func(get_float_labels)             # label each image with our function
        .transform(get_transforms(), size=224)         # standard augmentations, resize to 224
        .databunch())
data.normalize(imagenet_stats)

The only change we need to make is to switch the dataset to one used for regression. The default label class is CategoryList, but we can change it to FloatList in label_from_func:

data = (ImageItemList.from_folder(path_img)
        .random_split_by_pct()
        .label_from_func(get_float_labels, label_cls=FloatList)  # <-- the only change
        .transform(get_transforms(), size=224)
        .databunch())
data.normalize(imagenet_stats)

Now fastai knows that the dataset is a set of Floats and not Categories, and the databunch can be used for regression!
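
If you want to double-check (this assumes the usual fastai v1 data block attributes), you can peek at a training label and show a batch:

# The first training label should be a float, not a category
print(data.train_ds.y[0])
data.show_batch(rows=3, figsize=(7, 6))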

Setting Up The Loss Function

Since we are doing regression, we need to update the loss function of our model. Rather than a cross-entropy loss, we can use a whole host of regression losses. The most common ones are Mean Squared Error Loss (MSE Loss) and L1 Loss (see the Pytorch docs linked below for more details).

The default in fastai is MSE (Mean Squared Error). You can explicitly set this in your learner like this:

learn = create_cnn(data, models.resnet34)
learn.loss_func = MSELossFlat()  # note: the attribute is loss_func, and it takes an instance
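
In fact, because the labels are a FloatList, fastai should already pick a flattened MSE loss by default, so you can simply check what the learner is going to optimize:

# Sanity check: see which loss function the learner will use
print(learn.loss_func)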

And now you can run your model using MSE as the loss function. But let's say you want to use a different loss function (maybe L1 Loss?). Unfortunately, fastai currently does not provide a flattened version of it, so we will have to define our own. The good news is that this is super easy!

We can start by looking at how fastai defines MSELossFlat:

class MSELossFlat(nn.MSELoss):
    "Same as `nn.MSELoss`, but flattens input and target."
    def forward(self, input:Tensor, target:Tensor) -> Rank0Tensor:
        return super().forward(input.view(-1), target.view(-1))

It turns out MSELossFlat is just a thin wrapper around Pytorch's nn.MSELoss. Let's find nn.MSELoss in the Pytorch docs: https://pytorch.org/docs/stable/nn.html#loss-functions

That page lists all the loss functions Pytorch provides. So all we need to do to switch our loss function is to subclass a different Pytorch loss and apply the same flattening trick. Let's create our L1 loss function:

# Assumes fastai's usual imports are in scope (nn, Tensor, Rank0Tensor)
class L1LossFlat(nn.L1Loss):
    "Mean Absolute Error Loss, with input and target flattened."
    def forward(self, input:Tensor, target:Tensor) -> Rank0Tensor:
        return super().forward(input.view(-1), target.view(-1))

And now let’s pass this to our Learner:

learn = create_cnn(data, models.resnet34)
learn.loss_func = L1LossFlat()  # again, assign an instance to loss_func

And now we are using Mean Absolute Error instead of Mean Squared Error as our loss function. We can similarly create any other loss function from Pytorch’s list using this tweak.
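
For example, a flattened version of Pytorch's SmoothL1Loss (Huber loss) follows exactly the same pattern (SmoothL1LossFlat is just a name I am making up here for illustration):

class SmoothL1LossFlat(nn.SmoothL1Loss):
    "Same as `nn.SmoothL1Loss` (Huber loss), but flattens input and target."
    def forward(self, input:Tensor, target:Tensor) -> Rank0Tensor:
        return super().forward(input.view(-1), target.view(-1))

learn.loss_func = SmoothL1LossFlat()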

I hope this comes in handy as you work on your projects! Cheers. :)
