Halcon视觉之家 - 51Halcon专注于机器视觉技术

 找回密码
 会员注册

QQ登录

只需一步,快速开始

扫一扫,微信登录

查看: 263|回复: 1

[PyTorch] Pytorch构建网络模型四步骤:

[复制链接]
  • TA的每日心情
    慵懒
    2021-11-26 13:59
  • 签到天数: 4 天

    连续签到: 1 天

    [LV.2]偶尔看看I

    3

    主题

    4

    帖子

    24

    积分

    Rank: 1

    积分
    24

    切换助手验证会员

    QQ
    发表于 2021-8-30 13:38:00 | 显示全部楼层 |阅读模式

    51Halcon诚邀您的加入,专注于机器视觉开发与应用技术,我们一直都在努力!

    您需要 登录 才可以下载或查看,没有帐号?会员注册

    x
    本帖最后由 15732670364 于 2021-8-30 13:40 编辑


    ①利用pytorch进行深度学习

    • 准备数据集(Prepare dataset)
    • 设计用于计算最终结果的模型(Design model)
    • 构造损失函数及优化器(Construct loss and optimizer)
    • 设计循环周期(Training cycle)——前馈、反馈、更新
    1. import torch
    2. from torchvision import transforms
    3. from torchvision import datasets
    4. from torch.utils.data import DataLoader
    5. import torch.nn.functional as F
    6. import torch.optim as optim
    7. import matplotlib.pyplot as plt

    8. # prepare dataset

    9. batch_size = 64
    10. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    11. train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
    12. train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    13. test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
    14. test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


    15. # design model using class


    16. class Net(torch.nn.Module):
    17.     def __init__(self):
    18.         super(Net, self).__init__()
    19.         self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
    20.         self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
    21.         self.pooling = torch.nn.MaxPool2d(2)
    22.         self.fc = torch.nn.Linear(320, 10)

    23.     def forward(self, x):
    24.         # flatten data from (n,1,28,28) to (n, 784)

    25.         batch_size = x.size(0)
    26.         x = F.relu(self.pooling(self.conv1(x)))
    27.         x = F.relu(self.pooling(self.conv2(x)))
    28.         x = x.view(batch_size, -1)  # -1 此处自动算出的是320
    29.         # print("x.shape",x.shape)
    30.         x = self.fc(x)

    31.         return x


    32. model = Net()
    33. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    34. model.to(device)

    35. # construct loss and optimizer
    36. criterion = torch.nn.CrossEntropyLoss()
    37. optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


    38. # training cycle forward, backward, update


    39. def train(epoch):
    40.     running_loss = 0.0
    41.     for batch_idx, data in enumerate(train_loader, 0):
    42.         inputs, target = data
    43.         inputs, target = inputs.to(device), target.to(device)
    44.         optimizer.zero_grad()

    45.         outputs = model(inputs)
    46.         loss = criterion(outputs, target)
    47.         loss.backward()
    48.         optimizer.step()

    49.         running_loss += loss.item()
    50.         if batch_idx % 300 == 299:
    51.             print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
    52.             running_loss = 0.0


    53. def test():
    54.     correct = 0
    55.     total = 0
    56.     with torch.no_grad():
    57.         for data in test_loader:
    58.             images, labels = data
    59.             images, labels = images.to(device), labels.to(device)
    60.             outputs = model(images)
    61.             _, predicted = torch.max(outputs.data, dim=1)
    62.             total += labels.size(0)
    63.             correct += (predicted == labels).sum().item()
    64.     print('accuracy on test set: %d %% ' % (100 * correct / total))
    65.     return correct / total


    66. if __name__ == '__main__':
    67.     epoch_list = []
    68.     acc_list = []

    69.     for epoch in range(10):
    70.         train(epoch)
    71.         acc = test()
    72.         epoch_list.append(epoch)
    73.         acc_list.append(acc)

    74.     plt.plot(epoch_list, acc_list)
    75.     plt.ylabel('accuracy')
    76.     plt.xlabel('epoch')
    77.     plt.show()
    复制代码



  • TA的每日心情

    2021-1-15 01:04
  • 签到天数: 6 天

    连续签到: 1 天

    [LV.2]偶尔看看I

    0

    主题

    38

    帖子

    99

    积分

    Rank: 1

    积分
    99

    切换助手验证会员

    发表于 2021-9-2 15:45:55 | 显示全部楼层
    期待后续大作!
    您需要登录后才可以回帖 登录 | 会员注册

    本版积分规则

    视觉培训招生

    建议您使用Chrome、Firefox、Edge、IE10及以上版本和360等主流浏览器浏览本网站

    51Halcon会员技术交流会员技术交流 | 51Halcon官方客服咨询官方客服咨询 | Halcon切换助手使用反馈切换助手使用

    算子查询| 申请友链| 小黑屋| 手机版| Archiver|

    © 2015-2021 51Halcon机器视觉  X3.4  粤ICP备15095995号 粤公网安备44030602000670号

    快速回复 返回顶部 返回列表