Beijing Institute of Technology | Ming-Jian Li
The following Python code is companion code for the course on Artificial Intelligence and Simulation Science. It builds a CNN model and trains it on the downloaded MNIST dataset to recognize handwritten digits.
1import torch
2from torch import nn, optim
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import os
6
# Hyperparameters
batch_size = 200
learning_rate = 0.01
num_epochs = 20

# Load the MNIST dataset (downloaded into ./data if it is not there yet).
# ToTensor() turns every image into a float tensor with values in [0, 1].
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Create data loaders.
# The training set is reshuffled at every epoch so that the mini-batches
# differ from one epoch to the next; the test order is irrelevant because
# only aggregate metrics are computed from it.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
29
# Define the CNN model
class CNN(nn.Module):
    """Small LeNet-style convolutional classifier.

    Args:
        in_dim: Number of channels of the input images.
        n_class: Number of output classes (length of the score vector).

    The first fully connected layer is sized for 16 * 5 * 5 features, which
    is what the convolutional stack produces for 28 x 28 input images.
    """

    def __init__(self, in_dim, n_class):
        super().__init__()
        # Spatial size for a 28 x 28 input: 28 -> 28 -> 14 -> 10 -> 5.
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, 6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # NOTE: there is no activation between the linear layers, so the
        # three of them compose to a single affine map. The layout is kept
        # as is so that the state_dict keys (fc.0, fc.1, fc.2) and any
        # previously saved weights remain valid.
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class),
        )

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        out = self.conv(x)
        # Flatten the output for the fully connected layer
        out = out.view(out.size(0), -1)
        return self.fc(out)
54
# Initialize the CNN model: 1 input channel, 10 output classes
cnn = CNN(1, 10)
print(cnn)

# Define the loss function and optimizer.
# CrossEntropyLoss is applied directly to the raw scores returned by the
# model, so the network itself does not end with a softmax layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
62
# Train the model
cnn.train()  # make training mode explicit before the optimisation loop
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    running_loss = 0.0
    running_acc = 0.0
    for img, label in train_loader:
        out = cnn(img)
        loss = criterion(out, label)
        # The criterion returns the batch mean; weighting it by the batch
        # size makes the epoch figure a true per-sample average.
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        running_acc += (pred == label).sum().item()

        # Standard update step: clear old gradients, backpropagate, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_acc / len(train_dataset)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.6f}, Acc: {epoch_acc:.6f}')
83
# Evaluate the model
cnn.eval()
eval_loss = 0
eval_acc = 0
with torch.no_grad():  # No need to compute gradients during evaluation
    for img, label in test_loader:
        out = cnn(img)
        loss = criterion(out, label)
        # Weight the batch-mean loss by the batch size (see the division
        # by the dataset size below).
        eval_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        eval_acc += (pred == label).sum().item()

test_loss = eval_loss / len(test_dataset)
test_acc = eval_acc / len(test_dataset)
print(f'Test Loss: {test_loss:.6f}, Acc: {test_acc:.6f}')
100
# Save the trained model.
# Only the parameters are stored, wrapped in a dict under the key
# 'state_dict'; the loading side has to rebuild the CNN class itself and
# read the same key.
ckpt_dir = './'
save_path = os.path.join(ckpt_dir, 'CNN_model_weight.pth')
torch.save({'state_dict': cnn.state_dict()}, save_path)
This code downloads the MNIST dataset of handwritten digits, trains the CNN model, and then writes a file named CNN_model_weight.pth
to the current directory to save the trained parameters.
Suppose we want to predict the digit shown in the following image:
Please save it as test2_.jpg
in the current folder, and use the following code to load the parameters and predict the result.
x
1import torch
2from torch import nn, optim
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import os
6import numpy as np
7from PIL import Image
8
# Define the CNN model
class CNN(nn.Module):
    """Convolutional digit classifier used for inference.

    Args:
        in_dim: Number of channels of the input images.
        n_class: Number of output classes (length of the score vector).

    The layer layout determines the parameter names in the state_dict, so
    it must stay identical to the layout the saved weights were produced
    with. The first fully connected layer expects 16 * 5 * 5 features,
    i.e. 28 x 28 input images.
    """

    def __init__(self, in_dim, n_class):
        super().__init__()
        # Feature extractor; for a 28 x 28 input the spatial size goes
        # 28 -> 28 -> 14 -> 10 -> 5.
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, 6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head: three linear layers with no activation between
        # them (kept unchanged so that saved parameters still load).
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class),
        )

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        out = self.conv(x)
        # Flatten the output for the fully connected layer
        out = out.view(out.size(0), -1)
        return self.fc(out)
33
cnn = CNN(1, 10)

# Load the model parameters with weights_only=True
ckpt = torch.load('./CNN_model_weight.pth', map_location=torch.device('cpu'), weights_only=True)
cnn.load_state_dict(ckpt['state_dict'])  # Load the parameters into the specified model cnn
cnn.eval()  # inference only: switch the model to evaluation mode

# The image to be recognized
input_image = './test2_.jpg'

im = Image.open(input_image).resize((28, 28))  # Get the image data
im = im.convert('L')  # Convert to grayscale
im.save('./testgray.jpg')  # keep a copy of exactly what the network sees

# np.array(im) holds 8-bit grey levels (0..255), whereas the training data
# went through transforms.ToTensor(), which scales pixels to [0, 1].
# Apply the same scaling here so the input matches what the weights expect.
# NOTE(review): MNIST digits are presumably light strokes on a dark
# background -- confirm the test image uses the same polarity.
im_data = torch.from_numpy(np.array(im)).float() / 255.0

# Add the batch and channel dimensions: (1, 1, 28, 28)
im_data = im_data.view(1, 1, 28, 28)
with torch.no_grad():  # no gradients are needed for a prediction
    out = cnn(im_data)
_, pred = torch.max(out, 1)

print('Prediction: The number is {}.'.format(pred))
The result is:
Prediction: The number is tensor([2]).