• Deep Neural Networks

    Knowledge Distillation

    Knowledge distillation is a technique for model compression and transfer learning: the knowledge of a large, well-performing teacher model is transferred to a small student model, so that the student largely retains the teacher's performance while using far less compute and memory. Below is example code for knowledge distillation on the MNIST dataset using PyTorch:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    
    # Define the teacher model (a three-layer fully connected network)
    class TeacherModel(nn.Module):
        def __init__(self):
            super(TeacherModel, self).__init__()
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 64)
            self.fc3 = nn.Linear(64, 10)
    
        def forward(self, x):
            x = x.view(-1, 784)
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # Define the student model (a smaller, two-layer fully connected network)
    class StudentModel(nn.Module):
        def __init__(self):
            super(StudentModel, self).__init__()
            self.fc1 = nn.Linear(784, 32)
            self.fc2 = nn.Linear(32, 10)
    
        def forward(self, x):
            x = x.view(-1, 784)
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    
    # Data preprocessing: convert images to tensors and normalize with the MNIST mean and std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Load the MNIST dataset
    train_dataset = datasets.MNIST(root='./data', train=True,
                                   download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False,
                                  download=True, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # Train the teacher model
    teacher_model = TeacherModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
    
    for epoch in range(5):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = teacher_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')
    
    # Knowledge distillation: train the student against the teacher's soft targets
    student_model = StudentModel()
    optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)
    teacher_model.eval()  # the teacher is only used for inference from this point on
    temperature = 2.0     # softens the output distributions
    alpha = 0.5           # weight of the distillation term in the total loss
    
    for epoch in range(5):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            optimizer_student.zero_grad()
    
            # Teacher model outputs (soft targets); no gradients are needed through the teacher
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            teacher_probs = torch.softmax(teacher_outputs / temperature, dim=1)
    
            # Student model outputs, softened with the same temperature
            student_outputs = student_model(images)
            student_log_probs = torch.log_softmax(student_outputs / temperature, dim=1)
    
            # Distillation loss: KL divergence between the softened distributions,
            # scaled by T^2 to keep its gradient magnitude comparable to the hard-label loss
            distillation_loss = nn.KLDivLoss(reduction='batchmean')(student_log_probs, teacher_probs) * (temperature ** 2)
    
            # Cross-entropy loss between the student outputs and the true labels
            student_loss = nn.CrossEntropyLoss()(student_outputs, labels)
    
            # Total loss: weighted combination of the distillation loss and the hard-label loss
            total_loss = alpha * distillation_loss + (1 - alpha) * student_loss
    
            total_loss.backward()
            optimizer_student.step()
            running_loss += total_loss.item()
    
        print(f'Epoch {epoch + 1}, Student Loss: {running_loss / len(train_loader)}')
    
    # Evaluate the student model on the test set
    student_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = student_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Accuracy of the student network on the 10000 test images: {100 * correct / total:.2f}%')

    Code walkthrough:

    1. Model definition: a simple three-layer fully connected teacher model and a two-layer fully connected student model.
    2. Data loading: the MNIST dataset is loaded with torchvision and preprocessed.
    3. Teacher training: the teacher model is trained with a cross-entropy loss and the Adam optimizer.
    4. Knowledge distillation: a KL-divergence loss is computed between the temperature-softened output distributions (soft labels) of the teacher and the student, and combined with the student's cross-entropy loss against the true labels to form the total loss; see the sketch after this list.
    5. Student evaluation: the student model's accuracy is measured on the test set.
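
    The combined objective from step 4 can be factored into a standalone helper. The sketch below is a minimal, self-contained version of that computation; the function name distillation_objective and its default arguments are illustrative and are not part of the code above.

    import torch.nn.functional as F

    def distillation_objective(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
        # Softened distributions: log-probabilities for the student, probabilities for the teacher
        soft_student = F.log_softmax(student_logits / temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / temperature, dim=1)

        # KL divergence between the softened distributions, scaled by T^2 so its
        # gradient magnitude stays comparable to the cross-entropy term
        kd_term = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * temperature ** 2

        # Standard cross-entropy against the ground-truth labels
        ce_term = F.cross_entropy(student_logits, labels)

        return alpha * kd_term + (1 - alpha) * ce_term

    In the training loop above, total_loss = distillation_objective(student_outputs, teacher_outputs, labels, temperature, alpha) would compute the same value as the two separate loss terms.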

    Through knowledge distillation, the student model learns not only from the hard labels but also from the teacher's soft targets, which typically lets it reach better performance than the same small model trained on the labels alone.
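
    To make the compression concrete, here is a quick sketch that compares the number of trainable parameters in the two models defined above (count_parameters is a hypothetical helper, not part of the example):

    def count_parameters(model):
        # Total number of trainable parameters in a model
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    teacher_params = count_parameters(TeacherModel())  # roughly 109k parameters
    student_params = count_parameters(StudentModel())  # roughly 25k parameters
    print(f'Teacher: {teacher_params:,}  Student: {student_params:,}  '
          f'ratio: {teacher_params / student_params:.1f}x')

    The student is roughly four times smaller than the teacher, which is where the savings in compute and memory come from.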