PyTorch深度学习框架进阶——迁移学习

喜欢花科技君 2025-03-08 02:22:06

迁移学习(Transfer Learning)是一种利用预训练模型来解决新任务的技术。在深度学习中,迁移学习通常通过使用在大规模数据集上预训练的模型(如ImageNet上的ResNet、VGG等),并将其应用于新的、通常较小的数据集。PyTorch 提供了丰富的工具和预训练模型,使得迁移学习变得非常方便。

PyTorch 迁移学习的基本步骤加载预训练模型:PyTorch 提供了许多预训练模型,可以通过 torchvision.models 模块加载。修改模型结构:通常需要修改预训练模型的最后一层(全连接层),以适应新任务的类别数量。冻结模型参数:在训练初期,可以冻结预训练模型的参数,只训练新添加的层。这样可以避免破坏预训练模型已经学到的特征。训练模型:解冻部分或全部模型参数,进行微调(fine-tuning)。评估模型:在验证集或测试集上评估模型的性能。示例代码1. 导入必要的库import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms, models2. 数据预处理和加载

首先,需要对数据进行预处理,并使用 DataLoader 来加载数据。假设我们使用的是图像数据,可以使用 torchvision 提供的 transforms 进行数据增强和标准化。

# 数据预处理transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], # ImageNet 的均值 std=[0.229, 0.224, 0.225] # ImageNet 的标准差 )])# 加载训练集和验证集train_dataset = datasets.ImageFolder('path_to_train_data', transform=transform)valid_dataset = datasets.ImageFolder('path_to_valid_data', transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32)3. 加载预训练模型

选择合适的预训练模型,例如 resnet18,并将其加载到内存中。

model = models.resnet18(pretrained=True)4. 修改模型

根据具体任务,通常需要修改预训练模型的最后一层,以适应新的类别数。

num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, num_classes) # num_classes 是新任务的类别数5. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)6. 训练模型

以下是一个简单的训练循环示例:

num_epochs = 25for epoch in range(num_epochs): model.train() running_loss = 0.0 for inputs, labels in train_loader: # 将数据移动到GPU(如果可用) inputs = inputs.to(device) labels = labels.to(device) # 前向传播 outputs = model(inputs) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) epoch_loss = running_loss / len(train_dataset) # 验证 model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in valid_loader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}')7. 保存和加载模型

训练完成后,可以保存模型以备后用。

torch.save(model.state_dict(), 'model.pth')

加载模型时:

model = models.resnet18(pretrained=False)model.fc = nn.Linear(num_ftrs, num_classes)model.load_state_dict(torch.load('model.pth'))model.eval()关键点解释数据预处理:transforms 模块用于对图像进行预处理,包括裁剪、翻转、归一化等操作。加载预训练模型:models.resnet18(pretrained=True) 加载了在 ImageNet 上预训练的 ResNet18 模型。修改模型结构:model_ft.fc = nn.Linear(num_ftrs, 10) 修改了最后一层全连接层,使其输出类别数为 10。训练模型:train_model 函数实现了模型的训练和验证过程,并在每个 epoch 结束后保存最佳模型。保存模型:训练完成后,使用 torch.save 保存模型的权重。

0 阅读:0

喜欢花科技君

简介:感谢大家的关注