MNIST Digit Predictor
Convolutional neural network trained from scratch with PyTorch to predict hand-drawn digits.
Image Gallery
Predicting Hand-Drawn Digits
The user draws a digit on the canvas, which is then converted to an image and sent to the server, where the model predicts which digit it is.
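The page doesn't describe how the drawing is turned into model input. The sketch below shows one plausible approach. It assumes a square grayscale canvas (e.g. 280×280) with dark strokes on a light background. The canvas is inverted to match MNIST's light-on-dark digits, block-averaged down to 28×28, and standardized with the commonly used MNIST mean and standard deviation (0.1307, 0.3081). The function name `preprocess_canvas` is hypothetical.

```python
import numpy as np

# Commonly used MNIST normalization constants (mean/std of training pixels scaled to [0, 1]).
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def preprocess_canvas(pixels: np.ndarray, dark_on_light: bool = True) -> np.ndarray:
    """Convert a square uint8 grayscale canvas (0-255) into a (1, 1, 28, 28) float32 input.

    Hypothetical helper: assumes the canvas side length is a multiple of 28 (e.g. 280).
    """
    if pixels.ndim != 2 or pixels.shape[0] != pixels.shape[1]:
        raise ValueError("expected a square 2-D grayscale array")
    side = pixels.shape[0]
    if side % 28 != 0:
        raise ValueError("canvas side must be a multiple of 28")

    img = pixels.astype(np.float32) / 255.0
    # MNIST digits are light strokes on a dark background, so invert a dark-on-light drawing.
    if dark_on_light:
        img = 1.0 - img

    # Downsample by averaging non-overlapping blocks (10x10 blocks for a 280x280 canvas).
    block = side // 28
    img = img.reshape(28, block, 28, block).mean(axis=(1, 3))

    # Standardize with dataset statistics and add batch and channel axes (N, C, H, W).
    img = (img - MNIST_MEAN) / MNIST_STD
    return img[np.newaxis, np.newaxis, :, :].astype(np.float32)
```

Block averaging acts as a simple anti-aliasing downsample. A real pipeline might also crop and center the digit, as MNIST does, before resizing.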
Grad-CAM Visualization
Grad-CAM (Gradient-weighted Class Activation Mapping) is a method for visualizing which regions of an input image most influence a convolutional neural network's prediction for a given class. In this image, each example has two rows: the original image on top and a very slightly transformed copy below, with the activation map shown for each output class. Even this small transformation sometimes produces very different activations. This insight helped me design a better training strategy that made the network more accurate.
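The core Grad-CAM computation is short. Given the feature maps of one convolutional layer and the gradients of a class score with respect to those maps, each channel is weighted by its spatially averaged gradient. The weighted channels are then summed, and a ReLU keeps only positive evidence. This is a minimal NumPy sketch that assumes the activations and gradients have already been captured, for example with PyTorch module hooks such as `register_forward_hook`. The function name is hypothetical and not the project's actual code.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute a Grad-CAM heatmap for one image and one target class.

    activations: feature maps A^k from a conv layer, shape (K, H, W).
    gradients:   d(class score)/dA^k (score taken before softmax), same shape.
    Returns an (H, W) map scaled to [0, 1] (all zeros if nothing is positive).
    """
    if activations.ndim != 3 or activations.shape != gradients.shape:
        raise ValueError("activations and gradients must both have shape (K, H, W)")
    # Channel importance: global-average-pool the gradients over the spatial axes.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

Comparing the maps for an image and its slightly shifted copy is what exposes the instability described above. A translation-robust model should produce a correspondingly shifted map, not a different one.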
Initial Confusion Matrix
Inspecting the confusion matrix after initial training shows that the model's predictions worsen significantly when digits are translated or sheared. The translation result was a big alarm bell: the features a convolutional network detects should be the same regardless of where in the image they appear.
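For reference, a confusion matrix only counts (true label, predicted label) pairs. This minimal NumPy sketch assumes rows are true labels and columns are predictions, and the helper names are hypothetical. To reproduce a comparison like the one above, you would build one matrix on the untouched test set and another on a translated or sheared copy of it.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int = 10) -> np.ndarray:
    """Count predictions per (true, predicted) pair; rows = true label, cols = prediction."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at increments correctly even when the same (true, pred) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_recall(cm: np.ndarray) -> np.ndarray:
    """Fraction of each true class predicted correctly (diagonal / row sum; 0 if class absent)."""
    totals = cm.sum(axis=1)
    return np.divide(np.diag(cm), totals, out=np.zeros(len(cm)), where=totals > 0)
```

Overall accuracy is `np.trace(cm) / cm.sum()`. A drop in that number, or mass moving off the diagonal, between the clean and transformed matrices is the signal described in the text.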
Initial Loss and Accuracy Graphs
Indeed, the validation loss plateaued at a floor while the training loss continued to decrease, which means the model was overfitting the training data. As further confirmation, notice how quickly the training accuracy climbs to nearly 100%, only to level off abruptly.
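The pattern described here, with validation loss stalling while training loss keeps falling, can be checked automatically from the loss histories. This is the same idea behind early stopping with a "patience" window. The function below is a hypothetical heuristic for illustration, not something the project is stated to use.

```python
def overfitting_onset(train_losses, val_losses, patience=3):
    """Return the epoch index where overfitting appears to begin, or None.

    Heuristic: find the epoch with the lowest validation loss. If at least
    `patience` later epochs failed to beat it while the training loss kept
    falling, the model is fitting the training set without generalizing better.
    """
    if len(train_losses) != len(val_losses) or not val_losses:
        raise ValueError("need two non-empty loss histories of equal length")
    # Best (lowest) validation epoch; later epochs cannot beat it by definition.
    best = min(range(len(val_losses)), key=val_losses.__getitem__)
    epochs_since_best = len(val_losses) - 1 - best
    train_still_falling = train_losses[-1] < train_losses[best]
    if epochs_since_best >= patience and train_still_falling:
        return best
    return None
```

For the improved model described below, where the validation loss tracks the training loss downward, the best validation epoch stays near the end of training and the function returns `None`.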
Improved Confusion Matrix
After training more aggressively on image transformations, the model had no problem identifying the correct digit, even with translation and shear.
Improved Loss and Accuracy Graphs
With the improved model, the validation loss closely tracks the training loss, meaning the model is no longer overfitting. Also notice that the training accuracy rises more slowly toward 100%, indicating that the model is learning harder patterns, which turned out to generalize better for digit prediction.
How It Works
The model that predicts digits is a convolutional neural network trained on the MNIST dataset, a collection of 28×28 grayscale images of handwritten digits. I trained the network from scratch using the PyTorch Python library.
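The page doesn't list the network's layers, so the stack below is purely hypothetical. It illustrates the output-size arithmetic that every convolutional or pooling layer applies to a 28×28 MNIST input: floor((size + 2·padding − kernel) / stride) + 1. This matches PyTorch's `Conv2d` and `MaxPool2d` with dilation 1 and default floor rounding, and it determines how many features reach the final classifier layer.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


# Hypothetical layer stack for a 28x28 input: (name, kernel, stride, padding).
LAYERS = [
    ("conv1 3x3, pad 1", 3, 1, 1),  # 28 -> 28
    ("maxpool 2x2", 2, 2, 0),       # 28 -> 14
    ("conv2 3x3, pad 1", 3, 1, 1),  # 14 -> 14
    ("maxpool 2x2", 2, 2, 0),       # 14 -> 7
]


def trace_shapes(size: int = 28) -> list:
    """Return the spatial size before the first layer and after each layer in LAYERS."""
    sizes = [size]
    for _name, kernel, stride, padding in LAYERS:
        size = conv_output_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes
```

With this stack, a 28×28 image shrinks to 7×7. If the last convolution had, say, 64 channels, the flattened vector feeding the fully connected head would have 64 × 7 × 7 = 3136 features.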
After training, I converted the neural network to ONNX. The exported model was small enough to deploy on AWS Lambda with reasonably fast inference. The frontend is a Vite app that compiles to a static website, which I deployed to Google Cloud Storage. As a result, the entire end-to-end application is hosted essentially for free.
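The serving code isn't shown on this page. Below is a minimal sketch of what a Python Lambda handler for this setup could look like. It assumes the frontend POSTs a JSON body such as `{"pixels": [...784 values...]}` through an API Gateway or function-URL proxy integration, which passes the request body to the handler as the string `event["body"]`. `run_model` is a hypothetical placeholder that returns dummy logits so the sketch runs without a model file. The export step itself would typically use `torch.onnx.export`, and inference would typically use an `onnxruntime.InferenceSession`, as noted in the comments.

```python
import json


def run_model(pixels):
    """Hypothetical stand-in for ONNX inference; returns dummy logits for 10 classes.

    A real deployment might instead do (with the session created once at module
    level so warm Lambda invocations reuse it):
        session = onnxruntime.InferenceSession("model.onnx")
        logits = session.run(None, {input_name: batch})[0]
    where the file was produced by torch.onnx.export(model, dummy_input, "model.onnx").
    """
    return [0.0] * 10


def handler(event, context):
    """AWS Lambda entry point, assuming a JSON body like {"pixels": [784 floats]}."""
    try:
        body = json.loads(event.get("body") or "{}")
        pixels = [float(v) for v in body["pixels"]]
    except (KeyError, TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return {"statusCode": 400, "body": json.dumps({"error": "expected JSON with 'pixels'"})}
    if len(pixels) != 28 * 28:
        return {"statusCode": 400, "body": json.dumps({"error": "expected 784 pixel values"})}

    logits = run_model(pixels)
    # Predicted digit is the index of the largest logit.
    digit = max(range(len(logits)), key=logits.__getitem__)
    return {"statusCode": 200, "body": json.dumps({"digit": digit, "logits": logits})}
```

Keeping the model session outside the handler matters on Lambda. Execution environments are reused between warm invocations, so the model is loaded once rather than on every request.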
This project was a valuable learning experience for me. I learned how to troubleshoot and solve problems in a data-heavy system. I could no longer rely on debuggers and code tracing, because the model could still underperform even when the code was correct. I had to add many skills to my toolbox, like visualizing data, setting up experiments to find root causes, and creatively searching for new solutions.
One of the main problems I had during development was that the model could not handle transformations of the image. For example, it could identify a digit drawn directly in the middle of the image, but after even a slight translation the model didn't know what to do. After much visualization and testing, I decided the best way forward was to train aggressively on transformed images from the start: in the training loop, random transformations were applied at full strength from the first epoch, with no gradual increase. This forced the model to learn actual features regardless of their scale or position, rather than memorizing the data distribution.
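Below is a minimal NumPy sketch of full-strength random affine augmentation (translation, shear, and scale), applied to every training image on every epoch with no ramp-up. The ranges are hypothetical, because the page doesn't state the ones actually used. In a PyTorch pipeline, the same idea is usually expressed with torchvision's `transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15), shear=15)`. The NumPy version just makes the mechanics explicit: each output pixel is mapped back through the inverse transform and sampled with nearest-neighbor interpolation.

```python
import numpy as np


def random_affine(img: np.ndarray, rng: np.random.Generator,
                  max_shift: float = 0.15, max_shear_deg: float = 15.0,
                  scale_range=(0.85, 1.15)) -> np.ndarray:
    """Randomly translate, shear (along x), and scale a 2-D image; pixels from outside become 0.

    Ranges are hypothetical defaults. Expects raw [0, 1] pixels (before normalization),
    so the zero fill matches MNIST's black background.
    """
    h, w = img.shape
    tx = rng.uniform(-max_shift, max_shift) * w
    ty = rng.uniform(-max_shift, max_shift) * h
    shear = np.tan(np.deg2rad(rng.uniform(-max_shear_deg, max_shear_deg)))
    scale = rng.uniform(*scale_range)

    # Forward map about the image center c: p_out = M @ (p_in - c) + c + t, with p = (x, y).
    M = scale * np.array([[1.0, shear], [0.0, 1.0]])
    M_inv = np.linalg.inv(M)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0

    # Inverse-map every output pixel back to its source location.
    ys, xs = np.mgrid[0:h, 0:w]
    offsets = np.stack([xs - cx - tx, ys - cy - ty])  # shape (2, h, w): x row, then y row
    src = np.tensordot(M_inv, offsets, axes=1)        # shape (2, h, w)
    src_x = np.rint(src[0] + cx).astype(int)
    src_y = np.rint(src[1] + cy).astype(int)

    # Nearest-neighbor sampling; source coordinates outside the image stay 0.
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[valid] = img[src_y[valid], src_x[valid]]
    return out


def augment_batch(batch: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Augment every image at full strength, on every epoch (no curriculum or gradual ramp-up)."""
    return np.stack([random_affine(img, rng) for img in batch])
```

Because the transformation strength is constant from the first batch, the network never gets an early phase of centered-only digits in which it could settle into position-dependent features.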