Convolutional Neural Network (CNN) for Handwritten Digit Recognition

Beijing Institute of Technology | Ming-Jian Li

The following Python code accompanies the course Artificial Intelligence and Simulation Science. It builds a CNN model and trains it on the MNIST dataset to recognize handwritten digits.

The code downloads the MNIST handwritten digit dataset, trains the CNN model, and saves the trained parameters to a file named CNN_model_weight.pth in the current directory.
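
The steps described above can be sketched roughly as follows. This is a minimal illustration rather than the course's own script: the SimpleCNN architecture, the Adam optimizer, the learning rate, the batch size, and the ./data download directory are all assumptions. Only the output file name CNN_model_weight.pth comes from the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """Hypothetical two-convolution network for 28x28 grayscale digits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = torch.flatten(x, 1)                     # keep the batch dimension
        x = F.relu(self.fc1(x))
        return self.fc2(x)                          # raw logits, one per digit


def train(model, loader, epochs=1, lr=1e-3):
    """Train with Adam and cross-entropy loss; return the last batch loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    loss = None
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
    return loss.item() if loss is not None else None


def main():
    # torchvision is imported here because only the download step needs it.
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    # Assumed settings: data stored in ./data, batches of 64 images.
    train_set = datasets.MNIST(root="./data", train=True, download=True,
                               transform=transforms.ToTensor())
    loader = DataLoader(train_set, batch_size=64, shuffle=True)

    model = SimpleCNN()
    train(model, loader, epochs=1)

    # Save only the learned parameters, as described in the text.
    torch.save(model.state_dict(), "CNN_model_weight.pth")
```

Calling main() downloads the dataset on first use, runs one training epoch, and writes CNN_model_weight.pth. Saving the state_dict rather than the whole model object keeps the file independent of how the class is pickled, but the same class definition is then needed when the parameters are loaded.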

Suppose we want to predict the digit in the following image:

Save it as test2_.jpg in the current folder, then use the following code to load the parameters and predict the result.

The result is:

Prediction: The number is tensor([2]).