
Four steps to building a network model in PyTorch:

Posted by 15732670364 on 2021-8-30 13:38:00; last edited by 15732670364 on 2021-8-30 13:40.


① Deep learning with PyTorch

  • Prepare the dataset
  • Design the model that computes the output
  • Construct the loss function and optimizer
  • Write the training cycle: forward, backward, update
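Before the full MNIST example, the four steps can be sketched on a toy linear-regression problem. This is a minimal illustration, not part of the original post; all names (`LinearModel`, the sample data) are made up for the sketch:

```python
import torch

# 1. Prepare the dataset: three points on the line y = 2x
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

# 2. Design the model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearModel()

# 3. Construct the loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Training cycle: forward, backward, update
for epoch in range(1000):
    y_pred = model(x)              # forward
    loss = criterion(y_pred, y)
    optimizer.zero_grad()          # clear old gradients
    loss.backward()                # backward
    optimizer.step()               # update

print(model(torch.tensor([[4.0]])).item())  # should be close to 8.0
```

The MNIST example below follows exactly this skeleton; only the dataset, the model, and the loss change.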
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

# prepare dataset
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# design model using class
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        # flatten the feature maps from (n, 20, 4, 4) to (n, 320);
        # -1 lets view() infer the 320 automatically
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


model = Net()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# training cycle: forward, backward, update
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %%' % (100 * correct / total))
    return correct / total


if __name__ == '__main__':
    epoch_list = []
    acc_list = []

    for epoch in range(10):
        train(epoch)
        acc = test()
        epoch_list.append(epoch)
        acc_list.append(acc)

    plt.plot(epoch_list, acc_list)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.show()
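The 320 in `torch.nn.Linear(320, 10)` comes from tracing shapes through the two conv/pool stages: each 5×5 convolution (no padding) shrinks the spatial size by 4, and each 2×2 max-pool halves it, so 28 → 24 → 12 → 8 → 4, giving 20 × 4 × 4 = 320 features. A quick way to check the arithmetic with a dummy input (this check is not in the original post):

```python
import torch

# Same layers as Net, applied to a dummy MNIST-sized batch.
conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # 28 -> 24 (no padding)
conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)  # 12 -> 8
pooling = torch.nn.MaxPool2d(2)                 # halves each spatial dim

x = torch.zeros(1, 1, 28, 28)   # one dummy 28x28 grayscale image
x = pooling(conv1(x))           # shape (1, 10, 12, 12)
x = pooling(conv2(x))           # shape (1, 20, 4, 4)
print(x.numel())                # prints 320
```

This is also why `x.view(batch_size, -1)` can use `-1`: PyTorch infers the 320 from the remaining dimensions.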


