pytorch CNN实现图片分类
以下是一个使用 PyTorch 实现简单卷积神经网络(Convolutional Neural Network, CNN)的示例代码,用于对 MNIST 手写数字数据集进行分类。
代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义超参数
batch_size = 64
learning_rate = 0.001
epochs = 5
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.1307,), (0.3081,)) # 归一化处理
])
# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 定义 CNN 模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 第一层卷积层
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
# 第二层卷积层
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
# 池化层
self.pool = nn.MaxPool2d(2)
# 全连接层
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
# 第一层卷积和池化
x = self.pool(torch.relu(self.conv1(x)))
# 第二层卷积和池化
x = self.pool(torch.relu(self.conv2(x)))
# 展平张量
x = x.view(-1, 320)
# 全连接层
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(epochs):
model.train()
running_loss = 0.0
for i, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')
# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy on test set: {100 * correct / total}%')
代码解释
- 数据预处理:使用
transforms.Compose
定义了数据预处理的步骤,包括将图像转换为张量和归一化处理。 - 数据加载:使用
torchvision.datasets.MNIST
加载 MNIST 数据集,并使用DataLoader
创建数据加载器。 - 模型定义:定义了一个简单的 CNN 模型
SimpleCNN
,包含两层卷积层、两层全连接层和池化层。 - 训练过程:使用交叉熵损失函数
nn.CrossEntropyLoss
和 Adam 优化器optim.Adam
进行模型训练。 - 测试过程:在测试集上评估模型的准确率。
注意事项
- 代码中的超参数(如
batch_size
、learning_rate
和epochs
)可以根据需要进行调整。 - 代码运行前请确保已经安装了 PyTorch 和 torchvision 库。