Overfitting & regularization - Brothers in Arms
Over-fitting
One of the issues which coders generally face is ensuring that their model doesn't over-fit or under-fit. It's not an easy task, because there is no single, definitive answer to this question. However, there are multiple tools, like your own experience and the ML literature, which equip you to come up with a good answer (though not necessarily the best one!).
Let’s talk about a scenario with one feature variable, where we have used two models to predict the output.
Consider the two models below:
Model 1 – Y_Pred_Model1
Model 2 – Y_Pred_Model2
Which model do you think is better?
To answer this question, there are multiple concepts working in parallel. I can easily see that Model 1 shows zero loss and in fact coincides with the actual output (the red dots), whereas Model 2 is more spread out around the values. Model 1 seems to be the clear winner since it shows zero loss, but is that really the case? To be honest, the answer is not clear-cut yet, because we need more information. This is where the concept of regularization comes in, which is a way to ensure your model doesn’t over-fit. What do I mean by this?
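Since the plot behind Model 1 and Model 2 isn’t reproduced here, below is a minimal sketch of the same kind of comparison, assuming noisy data drawn from a sine curve (the data, polynomial degrees, and numbers are purely illustrative): a flexible model that hits every training point versus a simpler one that doesn’t.

```python
# Illustrative sketch only: the actual data behind the figure is not available,
# so noisy samples from a sine curve are assumed here.
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, size=x.shape)  # the "red dots"

# Model 1: degree-9 polynomial -- passes through every point, ~zero training loss
model1 = np.polynomial.Polynomial.fit(x, y, deg=9)

# Model 2: degree-3 polynomial -- misses some points, but is much smoother
model2 = np.polynomial.Polynomial.fit(x, y, deg=3)

# Zero training loss for Model 1 says nothing about new, unseen points
x_new = rng.uniform(0, 1, 100)
y_new = np.sin(2 * np.pi * x_new) + rng.normal(0, 0.2, size=x_new.shape)
mse1 = np.mean((model1(x_new) - y_new) ** 2)
mse2 = np.mean((model2(x_new) - y_new) ** 2)
print(f"out-of-sample MSE: model1={mse1:.3f}, model2={mse2:.3f}")
```

The flexible model fits the training points perfectly, yet typically does worse on points it has never seen, which is exactly the tension the rest of this post is about.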
Consider that you are teaching a 2-year-old, and you teach the baby how to identify a dog and a sparrow. However, the baby doesn’t know what a zebra looks like. So, when the baby sees a zebra, the baby might say it’s a dog, because the baby knows about four-legged animals but not about any other variant from the animal kingdom.
Regularization
The interesting part in the above example is that the baby has classified the data (an animal) even though the baby has never really seen that data point. This is generalization, which ML/DL algorithms find hard to achieve. We often include a regularization term in our loss function and find a trade-off between bias and variance. A few examples:
- In the case of tree-based algorithms, we prune the trees to avoid over-fitting, or perhaps tune the depth, etc.
- In Ridge and Lasso regression we use L2 and L1 regularization respectively to avoid over-fitting (see the sketch after this list).
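As a concrete sketch of the second bullet, here is how L2 (Ridge) and L1 (Lasso) penalties might be applied with scikit-learn; the toy feature matrix, target, and alpha values below are placeholders, not recommendations.

```python
# Sketch of L2 (Ridge) and L1 (Lasso) regularization with scikit-learn;
# X and y stand in for your own feature matrix and target.
import numpy as np
from sklearn.linear_model import Ridge, Lasso

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 20))             # toy feature matrix
y = X[:, 0] * 3.0 + rng.normal(size=200)   # only the first feature matters

# alpha is the regularization strength: larger alpha = stronger penalty,
# i.e. more bias, less variance.
ridge = Ridge(alpha=1.0).fit(X, y)   # L2: shrinks coefficients towards zero
lasso = Lasso(alpha=0.1).fit(X, y)   # L1: can set coefficients exactly to zero

print("ridge non-zero coefficients:", np.sum(ridge.coef_ != 0))
print("lasso non-zero coefficients:", np.sum(lasso.coef_ != 0))
```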
Ideally, whenever one models the data, one keeps different splits of the data set. Let’s say I have 1 million data points and I need to classify text into n categories. One would split the data into a training set and a test set, and then further split the training set to carve out a validation set (refer to cross-validation, which can help you choose parameters like C in SVM, an ideal split ratio, or the bandwidth parameter in locally weighted regression). One would train on the training data, check the accuracy on the test set, and then conclude whether the model over-fits or under-fits. A few ways which I use are listed below, followed by a short sketch of the workflow:
- Check the deviation between the training and test accuracy
- Check the loss curve, plotting both training and test loss for every epoch.
- Tweak hyper-parameters such as the dropout probability.
- Perform data distribution analysis of the training set and out-of-sample data sets.
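The sketch below shows the split-and-compare workflow described above, assuming a generic scikit-learn classifier; the synthetic data set and the model choice are placeholders for your own.

```python
# Sketch of train/validation/test splitting and accuracy comparison;
# the data set and classifier here are stand-ins for your own.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=10_000, n_features=30, random_state=0)

# Hold out a test set, then carve a validation set out of the training set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.25, random_state=0
)

model = RandomForestClassifier(random_state=0).fit(X_train, y_train)

train_acc = model.score(X_train, y_train)
val_acc = model.score(X_val, y_val)
test_acc = model.score(X_test, y_test)
print(f"train={train_acc:.3f}  val={val_acc:.3f}  test={test_acc:.3f}")
# A large gap between the training accuracy and the validation/test accuracy
# is a typical sign of over-fitting.
```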
Like I said in the beginning, one of the most important tools for coming up with a good model is your past experience across domains, which enables you to tweak your strategies and come up with better-performing models.
Caveats
Please note that the above strategy doesn’t ensure your model is good, even if you find that the model performs well on both the training and test data sets. Why? Because here we have one colossal assumption: that the data set of 1 million data points represents the overall distribution. Personally, I have seen situations where my model was working well on all the splits but was “bad” on out-of-sample data sets. I encountered these situations while I was working in the social media analytics space.
Hence, it is very important to analyze the data and perform strong EDA on the data set. One good way is to use a multi-gram approach, check the distributions of the data sets, and make sure they are not completely different. If they are different, then you need to get in touch with the data procurement stakeholders and understand the data. A rough sketch of this check is shown below.
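The following is one possible way to run such a multi-gram (here, bigram) distribution check; the two tiny text lists are placeholders for your training corpus and the out-of-sample data, and the overlap score is just one crude way to quantify how different the distributions are.

```python
# Sketch of comparing n-gram distributions between the training corpus and an
# out-of-sample corpus; the text lists below are placeholders for real data.
from collections import Counter

def ngram_counts(texts, n=2):
    """Count word n-grams across a list of documents."""
    counts = Counter()
    for text in texts:
        tokens = text.lower().split()
        counts.update(zip(*(tokens[i:] for i in range(n))))
    return counts

train_texts = ["the service was great", "loved the product"]          # placeholder
oos_texts = ["shipping was delayed again", "refund never arrived"]    # placeholder

train_ngrams = ngram_counts(train_texts)
oos_ngrams = ngram_counts(oos_texts)

# Overlap between the most frequent n-grams of the two corpora is a crude
# signal of distribution shift: low overlap suggests the out-of-sample data
# looks very different from what the model was trained on.
top_train = {g for g, _ in train_ngrams.most_common(100)}
top_oos = {g for g, _ in oos_ngrams.most_common(100)}
overlap = len(top_train & top_oos) / max(len(top_oos), 1)
print(f"share of frequent out-of-sample bigrams also frequent in training: {overlap:.2f}")
```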