Convolutional Neural Network (CNN) for Handwritten Digit Recognition
Beijing Institute of Technology | Ming-Jian Li
The following Python code is companion code for the course on Artificial Intelligence and Simulation Science. It builds a CNN model and trains it on the downloaded MNIST dataset to recognize handwritten digits.
1import torch2from torch import nn, optim3from torch.utils.data import DataLoader4from torchvision import datasets, transforms5import os6
# Hyperparameters
batch_size = 200       # samples per mini-batch
learning_rate = 0.01   # step size handed to the optimizer
num_epochs = 20        # full passes over the training set


# Load the MNIST dataset
def _load_mnist(train):
    """Return the MNIST split selected by *train*, downloading it to ./data if needed."""
    return datasets.MNIST(
        root='./data',
        train=train,
        transform=transforms.ToTensor(),
        download=True,
    )


train_dataset = _load_mnist(True)
test_dataset = _load_mnist(False)
# Create data loaders
# The training loader reshuffles the samples every epoch so the optimizer does
# not see the same mini-batches in the same order on each pass.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the metrics, so the test set is left unshuffled.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Define the CNN model
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages, then three linear layers.

    Args:
        in_dim: number of channels in the input images.
        n_class: number of output classes (size of the logit vector).

    The first linear layer is sized for 16 * 5 * 5 features, which is what the
    convolutional stack produces for 28x28 inputs.
    """

    def __init__(self, in_dim, n_class):
        super().__init__()

        feature_layers = [
            # 3x3 convolution with padding 1 keeps the spatial size
            nn.Conv2d(in_dim, 6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 5x5 convolution without padding shrinks each side by 4
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # NOTE(review): the linear layers are stacked without activations in
        # between; the layout is kept unchanged so saved state_dict keys match.
        classifier_layers = [
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class),
        ]

        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return the class logits for the image batch *x*."""
        features = self.conv(x)
        # Flatten everything except the batch dimension for the fully connected layers
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
# Initialize the CNN model: one input channel, ten output classes
cnn = CNN(in_dim=1, n_class=10)
print(cnn)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=cnn.parameters(), lr=learning_rate)
# Train the model
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    loss_sum = 0.0    # sum of per-batch mean losses, weighted by batch size
    n_correct = 0.0   # number of correctly classified training samples

    for img, label in train_loader:
        # Forward pass
        out = cnn(img)
        loss = criterion(out, label)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics for this batch (computed from the pre-update forward pass)
        loss_sum += loss.item() * label.size(0)
        pred = out.argmax(dim=1)
        n_correct += (pred == label).sum().item()

    # Per-sample averages over the whole training set
    epoch_loss = loss_sum / len(train_dataset)
    epoch_acc = n_correct / len(train_dataset)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.6f}, Acc: {epoch_acc:.6f}')
# Evaluate the model
cnn.eval()
eval_loss = 0   # sum of per-batch mean losses, weighted by batch size
eval_acc = 0    # number of correctly classified test samples

# Gradients are not needed for evaluation
with torch.no_grad():
    for img, label in test_loader:
        out = cnn(img)
        batch_loss = criterion(out, label).item()
        eval_loss += batch_loss * label.size(0)
        pred = out.argmax(dim=1)
        eval_acc += (pred == label).sum().item()

# Per-sample averages over the whole test set
n_test = len(test_dataset)
test_loss = eval_loss / n_test
test_acc = eval_acc / n_test
print(f'Test Loss: {test_loss:.6f}, Acc: {test_acc:.6f}')
# Save the trained model
ckpt_dir = './'
save_path = os.path.join(ckpt_dir, 'CNN_model_weight.pth')
# The parameters are stored under the 'state_dict' key of the checkpoint dict
checkpoint = {'state_dict': cnn.state_dict()}
torch.save(checkpoint, save_path)

# This code will download the MNIST dataset of handwritten digits, train the CNN
# model, and then write a file named CNN_model_weight.pth in the current
# directory to save the trained parameters.
Suppose we want to predict the digit shown in the following image:

Please save it as test2_.jpg in the current folder, and use the following code to load the parameters and predict the result.
x
1import torch2from torch import nn, optim3from torch.utils.data import DataLoader4from torchvision import datasets, transforms5import os6import numpy as np7from PIL import Image8
# Define the CNN model
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages, then three linear layers.

    Args:
        in_dim: number of channels in the input images.
        n_class: number of output classes (size of the logit vector).

    The first linear layer is sized for 16 * 5 * 5 features, which is what the
    convolutional stack produces for 28x28 inputs.
    """

    def __init__(self, in_dim, n_class):
        super().__init__()

        feature_layers = [
            # 3x3 convolution with padding 1 keeps the spatial size
            nn.Conv2d(in_dim, 6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 5x5 convolution without padding shrinks each side by 4
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # NOTE(review): the linear layers are stacked without activations in
        # between; the layout is kept unchanged so saved state_dict keys match.
        classifier_layers = [
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class),
        ]

        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return the class logits for the image batch *x*."""
        features = self.conv(x)
        # Flatten everything except the batch dimension for the fully connected layers
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
# Rebuild the network with the same layout that was used when the weights were saved
cnn = CNN(in_dim=1, n_class=10)

# Load the model parameters with weights_only=True
ckpt = torch.load(
    './CNN_model_weight.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
# The parameters live under the 'state_dict' key; copy them into cnn
cnn.load_state_dict(ckpt['state_dict'])
# The image to be recognized
input_image = './test2_.jpg'

# Bring the image into the 28x28 single-channel format the network expects
im = Image.open(input_image).resize((28, 28))  # Get the image data
im = im.convert('L')  # Convert to grayscale
im_data = np.array(im)
im.save('./testgray.jpg')  # Keep a copy of the preprocessed image for inspection

# Scale pixel values from [0, 255] to [0, 1]. The training script feeds the
# network transforms.ToTensor() output, which uses that range; passing raw
# 0-255 values would put the input on a different scale than the model saw.
im_data = torch.from_numpy(im_data).float() / 255.0
# NOTE(review): MNIST digits are light strokes on a dark background; if the
# input image is dark-on-light it may need inverting — confirm with the image.

# Add batch and channel dimensions: (1, 1, 28, 28)
im_data = im_data.view(1, 1, 28, 28)

# Inference only: switch to evaluation mode and skip gradient tracking
cnn.eval()
with torch.no_grad():
    out = cnn(im_data)
_, pred = torch.max(out, 1)  # Index of the largest logit is the predicted digit

print('Prediction: The number is {}.'.format(pred))
# The result is:
Prediction: The number is tensor([2]).