Troubleshooting Deep Learning models
Attention Conservation Notice: An old and barely edited Twitter thread (so please forgive choppy writing and any formatting glitches) that I rescued from the memory hole. It has my own tips and advice for troubleshooting deep learning models when they don’t want to train. As we’ve moved away from hand-written training loops and people making up new and weird architectures for their particular problem, toward just throwing LLMs/transformers at everything, these tips are probably less relevant.
Debugging deep learning models can be really tricky and frustrating, especially in the security space, where a lot of the time you’re not sure about your ground-truth labels. Here’s a thread with some tricks I’ve picked up; add your own if you’ve got them!
First thing: make sure the outputs change when you change the inputs. You want to confirm that information is actually flowing through your model and that your clever new architecture isn’t accidentally outputting a constant value for all inputs. If your model fails this test, it’s worth a) trying a few different initializations, and b) trying a constant initialization (i.e. a value of 1 for every weight) before you give up on it, especially if there’s any recurrence in the model. If you’re sure no information is getting from the input to the output, then it’s time to check layer by layer to see where the blockage is. Accidentally multiplying by zero or discarding intermediate values are common causes here.
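Here’s a minimal sketch of both checks in plain Python, so it runs without any framework. It assumes each layer is just a callable that maps a list of floats to a list of floats; `dense`, `relu`, `run`, `model_is_constant`, and `find_blockage` are hypothetical helpers for illustration. With a real framework you’d do the same thing on tensors, for example by calling your submodules one at a time on a small batch of varied inputs.

```python
def dense(weights, bias):
    # Hypothetical fully connected layer: y = W x + b on plain Python lists.
    def apply(x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(weights, bias)]
    return apply


def relu(x):
    return [max(0.0, v) for v in x]


def all_same(vectors, tol=1e-9):
    # True if every vector matches the first one to within tol.
    first = vectors[0]
    return all(all(abs(a - b) <= tol for a, b in zip(v, first))
               for v in vectors[1:])


def run(layers, x):
    for layer in layers:
        x = layer(x)
    return x


def model_is_constant(layers, inputs, tol=1e-9):
    # Check 1: do different inputs produce different outputs at all?
    return all_same([run(layers, x) for x in inputs], tol)


def find_blockage(layers, inputs, tol=1e-9):
    # Check 2: push the inputs through one layer at a time and return the index
    # of the first layer after which every input gives the same activations.
    acts = list(inputs)
    for i, layer in enumerate(layers):
        acts = [layer(a) for a in acts]
        if all_same(acts, tol):
            return i
    return None


if __name__ == "__main__":
    inputs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0], [0.5, -0.5]]
    first = dense([[1.0, 2.0], [3.0, -1.0]], [0.0, 0.0])
    last = dense([[1.0, 1.0]], [0.5])
    healthy = [first, relu, last]
    # Simulated bug: a layer with all-zero weights wipes out the input signal.
    broken = [first, dense([[0.0, 0.0], [0.0, 0.0]], [0.1, 0.1]), relu, last]
    print(model_is_constant(healthy, inputs), find_blockage(healthy, inputs))  # False None
    print(model_is_constant(broken, inputs), find_blockage(broken, inputs))    # True 1
```

Using a handful of deliberately different inputs (rather than one) matters: a single input can’t tell you whether the output is constant.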
With that sorted…
The second thing you probably want to try is to overfit a trivial dataset: can you take ~100 samples and drive the training loss to zero on just those samples? If the answer is no, that almost always indicates a problem. Common headaches include:
- Accidental non-differentiable operations: check the gradient of each of your layers. Fix the input, vary the target, compute the loss, take gradients, and see whether the gradients of every layer change.
- Bad optimizer set-up: I can’t tell you the number of times I’ve gotten mad because nothing was happening, only to discover that I’d set my optimizer to update weights on a different model than the one I was trying to train.
- A bad data loader: are you loading the right labels for the samples? Have you accidentally shuffled your samples differently from your labels? Have you accidentally set the weights of all of your samples to zero?
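A minimal, framework-free sketch of the overfitting test, with the gradient check from the first bullet and a crude guard for the optimizer bug folded in. Assumptions: the model is plain logistic regression trained by full-batch gradient descent, and the data is a hypothetical toy set of 100 points that is linearly separable with some margin, so the loss really can be driven toward zero. `make_tiny_dataset`, `loss_and_grads`, `gradients_respond_to_targets`, and `overfit_check` are illustrative names, not from any library. In a real project you’d swap in your own model, a ~100-sample slice of your data, and your framework’s autograd; the optimizer guard becomes “snapshot the weights of the model you think you’re training, take one optimizer step, and check that they moved.”

```python
import math
import random


def sigmoid(z):
    # Logistic function written to avoid overflow in math.exp.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def make_tiny_dataset(n=100, seed=0, margin=0.5):
    # Hypothetical toy data: label 1 when x0 + x1 > 0. Points closer than
    # `margin` to the boundary are skipped so the set is cleanly separable.
    rng = random.Random(seed)
    xs, ys = [], []
    while len(xs) < n:
        x0, x1 = rng.uniform(-1, 1), rng.uniform(-1, 1)
        if abs(x0 + x1) >= margin:
            xs.append((x0, x1))
            ys.append(1.0 if x0 + x1 > 0 else 0.0)
    return xs, ys


def loss_and_grads(params, xs, ys, eps=1e-12):
    # Mean binary cross-entropy of a logistic model and its gradient
    # with respect to params = [w0, w1, b].
    w0, w1, b = params
    loss, grads = 0.0, [0.0, 0.0, 0.0]
    for (x0, x1), y in zip(xs, ys):
        p = sigmoid(w0 * x0 + w1 * x1 + b)
        pc = min(max(p, eps), 1.0 - eps)  # epsilon nugget keeps log() finite
        loss -= y * math.log(pc) + (1.0 - y) * math.log(1.0 - pc)
        err = p - y
        grads[0] += err * x0
        grads[1] += err * x1
        grads[2] += err
    n = len(xs)
    return loss / n, [g / n for g in grads]


def gradients_respond_to_targets(params, xs, ys):
    # Fix the inputs, vary the targets, and check the gradients move. If they
    # don't, something between the loss and the weights isn't differentiable.
    _, g_orig = loss_and_grads(params, xs, ys)
    _, g_flipped = loss_and_grads(params, xs, [1.0 - y for y in ys])
    return any(abs(a - c) > 1e-9 for a, c in zip(g_orig, g_flipped))


def overfit_check(xs, ys, lr=1.0, steps=3000):
    # Plain full-batch gradient descent; returns the final training loss.
    params = [0.0, 0.0, 0.0]
    for step in range(steps):
        _, grads = loss_and_grads(params, xs, ys)
        before = list(params)
        params = [p - lr * g for p, g in zip(params, grads)]
        # The very first step should change at least one parameter; if it
        # doesn't, the update isn't reaching the weights you're training.
        if step == 0 and params == before:
            raise RuntimeError("first update left every parameter unchanged")
    return loss_and_grads(params, xs, ys)[0]


if __name__ == "__main__":
    xs, ys = make_tiny_dataset()
    print(gradients_respond_to_targets([0.0, 0.0, 0.0], xs, ys))
    print(overfit_check(xs, ys))  # should end up close to zero
```

If a model that clearly has enough capacity can’t get anywhere near zero on a set like this, the three bullets above are the first places to look.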
Alternatively: can you not train because your loss rockets off to the land of NaN? Common culprits here include exp() functions in losses or activations that should be handled via ln() trickery (working in log space, e.g. the log-sum-exp trick), ln() functions or divisions without an ‘epsilon nugget’ to avoid zeros, and fairly deep or recurrent networks with initialization values that lead to overflows or underflows. Secondary culprits are fancy pseudo-second-order optimizers that end up dramatically underestimating the gradient variance for some model weights (and so take huge steps), and hand-rolled optimizers and/or loss functions that are off by a negative sign, so that you’re accidentally doing gradient ascent instead of descent. That sign error is especially easy to make with maximum likelihood (where the natural objective is something you maximize, not minimize) or with some variety of adversarial loss function.
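Here’s a small sketch of the first two fixes, in plain Python so it runs anywhere. `naive_bce_with_logits` is the textbook sigmoid-then-log form; `stable_bce_with_logits` is the standard rearrangement `max(z, 0) - y*z + log(1 + exp(-|z|))`, which never passes a positive number to exp(). Binary cross-entropy is the negative log-likelihood, so minimizing it maximizes the likelihood; if a hand-rolled loss is written as a log-likelihood instead, that’s where the missing minus sign tends to hide. The function names and the `eps` value of 1e-12 are my own choices for illustration.

```python
import math


def naive_bce_with_logits(z, y):
    # Textbook form: squash the logit with a sigmoid, then take logs.
    # math.exp(-z) overflows for very negative z, and log(1 - p) hits log(0)
    # once p rounds to exactly 1.0 for large positive z.
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def stable_bce_with_logits(z, y):
    # Same loss, rearranged so exp() only ever sees a non-positive argument.
    return max(z, 0.0) - y * z + math.log1p(math.exp(-abs(z)))


def logsumexp(values):
    # log(sum(exp(v))): factor out the max so the largest exp() term is exp(0).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def safe_log(x, eps=1e-12):
    # Epsilon nugget: math.log(0) raises here; most frameworks silently return
    # -inf instead, which soon turns into NaN.
    return math.log(x + eps)


def safe_div(a, b, eps=1e-12):
    # Assumes b is non-negative (a count, a norm, a variance...).
    return a / (b + eps)


if __name__ == "__main__":
    print(stable_bce_with_logits(1000.0, 1.0))   # 0.0
    print(stable_bce_with_logits(-1000.0, 1.0))  # 1000.0
    try:
        naive_bce_with_logits(-1000.0, 1.0)
    except OverflowError as err:
        print("naive version failed:", err)
```

Most frameworks ship a fused “with logits” loss that does this rearrangement for you, which is usually a better choice than applying a sigmoid yourself and then taking the log.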
So if you’re through all that and still having problems, it’s generally either
a) it trains but the results are just bad, or
b) it trains OK for a while, and then a few epochs in the loss either starts whipsawing or shoots off to NaNville again.
The solution to ‘a’ is generally to check the data/features. If you’re doing ML in the security space, you’re probably time-splitting your train and test data (or at least you should be); try a randomized shuffle/split instead and see if the model can learn on that. If it can, then you might be looking at distributional drift in your data. That’s a tough one, and tackling it is pretty problem-specific. Otherwise, it might be time to revisit your features and model architecture.
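A quick sketch of the two splits, assuming each record is a hypothetical `(timestamp, sample)` pair; `time_split` and `random_split` are illustrative names. The idea is to train and evaluate the same model on both: if it does well on the shuffled split but poorly on the time split, drift is a likely suspect; if it does poorly on both, look at the features and architecture instead.

```python
import random


def _n_test(n, test_fraction):
    # Number of records that go to the test side.
    return int(round(n * test_fraction))


def time_split(records, test_fraction=0.2):
    # Train on the oldest records and test on the newest, mimicking deployment.
    ordered = sorted(records, key=lambda r: r[0])
    cut = len(ordered) - _n_test(len(ordered), test_fraction)
    return ordered[:cut], ordered[cut:]


def random_split(records, test_fraction=0.2, seed=0):
    # Shuffle-and-split baseline that ignores time entirely.
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    cut = len(shuffled) - _n_test(len(shuffled), test_fraction)
    return shuffled[:cut], shuffled[cut:]
```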
The root cause of ‘b’ is often a lot harder to pin down. Sometimes (especially with pseudo-second-order optimizers or BatchNorm) a bad-luck run of minibatches that are heavily imbalanced or very similar can mess up the running estimates and cause bad behavior. If you save the seed for your PRNG and track the epoch/minibatch, you can sometimes catch where it goes off the rails, then reset and fast-forward your iterator to take a closer look at the minibatches leading up to the error. TensorBoard or the like can also help: if you aggressively log model weights and parameters, you can often see when some parameter begins to go out of control. BatchNorm variance estimates are a good thing to check there. It’s also possible that your loss surface is just weird in some spots; regularization and BatchNorm can help smooth it a bit, as can adversarial training. I’ve heard of other people having success with things like mixup and soft targets, but I’ve never tried those myself.
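A minimal sketch of the seed-and-fast-forward idea, assuming the minibatch order is the only randomness you need to reproduce. The shuffle for each epoch is derived from `(seed, epoch)` alone, so you can regenerate exactly the batches that preceded a logged failure without re-running training. `epoch_batches`, `replay_window`, `positive_fraction`, and the `seed * 1_000_003 + epoch` mixing are hypothetical choices for illustration; augmentation, dropout, and your framework’s own data loader draw from their own random state, which would need seeding and logging too.

```python
import random


def epoch_batches(n_samples, batch_size, seed, epoch):
    # Minibatch order that depends only on (seed, epoch), so any epoch can be
    # regenerated later without re-running the ones before it.
    order = list(range(n_samples))
    random.Random(seed * 1_000_003 + epoch).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, n_samples, batch_size)]


def replay_window(n_samples, batch_size, seed, epoch, step, lookback=5):
    # "Fast-forward" to the (epoch, step) where things went wrong and return
    # (step, batch) pairs for the minibatches leading up to it, inclusive.
    batches = epoch_batches(n_samples, batch_size, seed, epoch)
    start = max(0, step - lookback)
    return [(s, batches[s]) for s in range(start, step + 1)]


def positive_fraction(batch, labels):
    # Cheap diagnostic: a run of very imbalanced batches can drag running
    # estimates (BatchNorm statistics, optimizer moments) off course.
    return sum(labels[i] for i in batch) / len(batch)


if __name__ == "__main__":
    labels = [1 if i % 10 == 0 else 0 for i in range(1000)]
    # Suppose the training loop logged a NaN loss at epoch 3, step 20.
    for step, batch in replay_window(1000, 32, seed=1234, epoch=3, step=20):
        print(step, round(positive_fraction(batch, labels), 3))
```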
A final thing to check, which theoretically shouldn’t make a difference but seems to in practice, is whether you have identical data points with different labels. Relabeling all of those with a single pseudolabel (the mode, the mean value, etc., depending on the loss) can sometimes help.
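A minimal sketch of that relabeling, assuming exact duplicates can be found by hashing the feature vector (so `xs` holds tuples; with float features you might want to round first). `relabel_conflicts` is a hypothetical helper. For ties in “mode” mode it keeps whichever label appears first, since `Counter.most_common` orders equal counts by first occurrence.

```python
from collections import Counter, defaultdict
from statistics import mean


def relabel_conflicts(xs, ys, reduce="mode"):
    # Give every copy of an identical input the same pseudolabel: the most
    # common label ("mode", for classification) or the average target
    # ("mean", for regression). xs must be hashable, e.g. tuples of features.
    if reduce not in ("mode", "mean"):
        raise ValueError("reduce must be 'mode' or 'mean'")
    groups = defaultdict(list)
    for i, x in enumerate(xs):
        groups[x].append(i)
    new_ys = list(ys)
    n_conflicting_groups = 0
    for idxs in groups.values():
        labels = [ys[i] for i in idxs]
        if len(set(labels)) <= 1:
            continue  # duplicates that already agree are left alone
        n_conflicting_groups += 1
        if reduce == "mode":
            pseudo = Counter(labels).most_common(1)[0][0]
        else:
            pseudo = mean(labels)
        for i in idxs:
            new_ys[i] = pseudo
    return new_ys, n_conflicting_groups
```

The returned count of conflicting groups is worth logging on its own: a large number is also a hint about how noisy your labels are.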
Specific to the security space, label noise is often a major headache. If your performance is in the ok-but-not-great range, it’s quite possible that you’ve hit the limit imposed by label noise and just can’t do any better. This is especially true when you’re building detection models, where you usually have a really strong constraint on what false positive rate the end user is willing to put up with. If (e.g.) 1 in 1000 labels are wrong, then that’s a hard floor on your assessed FPR: even a perfect detector will flag the malicious samples that were mislabeled as benign, and each of those counts as a false positive against your labels. The best option is to take some time to do manual labeling, if you have the ability/resources. Other “purely” ML-driven techniques that I’ve seen proposed include weak supervision (à la Snorkel) and semi-supervised learning. I personally haven’t had much luck with weak supervision (though see @phtully’s talk at @CamlisOrg 2019 for an example of it working well); semi-supervised learning can give massive improvements if your test set has really good labels.
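The arithmetic behind that floor, as a tiny simulation. It assumes the wrong labels are all malicious samples labeled benign (the direction that hurts FPR) and a hypothetical detector that is perfect with respect to the true labels; `measured_fpr`, `fpr_floor`, and the 100,000-sample size are illustrative.

```python
def measured_fpr(predictions, labels):
    # FPR as you'd measure it: the fraction of samples *labeled* benign (0)
    # that the detector flags (1).
    on_benign = [p for p, y in zip(predictions, labels) if y == 0]
    return sum(on_benign) / len(on_benign)


def fpr_floor(n_labeled_benign=100_000, mislabel_rate=1 / 1000):
    # Hypothetical setup: a small fraction of "benign" labels are really malicious.
    n_bad = round(n_labeled_benign * mislabel_rate)
    truth = [1] * n_bad + [0] * (n_labeled_benign - n_bad)  # 1 = truly malicious
    noisy_labels = [0] * n_labeled_benign                   # everything labeled benign
    perfect_detector = truth                                # flags exactly the malicious ones
    return measured_fpr(perfect_detector, noisy_labels)


if __name__ == "__main__":
    print(fpr_floor())  # 0.001
```

In other words, the measured FPR of a perfect detector equals the mislabeling rate among benign-labeled samples, so no amount of model tuning gets you below it.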
I’ve also seen active learning proposed to help clean up labels. We ran simulations and found that, at the scale of data we were working at, it wouldn’t be feasible to get enough manual labels to really help, but if you’re under 1,000,000 samples it might be worth trying. And in practice, the best ROI for doing ML on security problems is almost always going to come from improving the data somehow: better labels, more data, data that’s more like deployment data, or a narrower or different problem to solve. A careful model search can often get you that last 2% bump in detection at the deployment FPR, but you always want to repeat the experiment a few times to make sure it’s a real improvement and not just a random fluke.
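A minimal sketch of that sanity check: run each configuration with several seeds and compare the difference in means to the run-to-run spread. Welch’s t statistic is one reasonable choice for this, not the only one; with only a handful of runs per side it has to be read against a t distribution with few degrees of freedom, so a value around 2 is weak evidence. `welch_t` and `train_and_evaluate` are hypothetical names.

```python
import math
from statistics import mean, variance


def welch_t(baseline, candidate):
    # Welch's t statistic for the difference in mean score between two sets of
    # repeated runs (e.g. detection rate at the deployment FPR, one per seed).
    # Needs at least two runs per side (statistics.variance requires it).
    se = math.sqrt(variance(baseline) / len(baseline)
                   + variance(candidate) / len(candidate))
    if se == 0.0:
        raise ValueError("no run-to-run variation; cannot compare")
    return (mean(candidate) - mean(baseline)) / se


# Hypothetical usage: train_and_evaluate() stands in for your own function
# that trains one model with the given seed and returns its score.
# baseline = [train_and_evaluate(old_config, seed) for seed in range(5)]
# candidate = [train_and_evaluate(new_config, seed) for seed in range(5)]
# print(welch_t(baseline, candidate))
```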
So that’s a non-exhaustive list of things to try or think about if your deep learning model isn’t working; if you’ve got a favorite trick or technique I didn’t mention, I’d love to hear about it!