8. A 5-Layer Fully Connected Network for MNIST (Handwritten Digit Recognition) in PyTorch
Published: 2019-05-26


A few points to note:

  • This example uses the MNIST dataset with a 5-layer fully connected network (counting the input layer).
  • The last layer's output needs no extra non-linear transform: torch.nn.CrossEntropyLoss takes the raw logits and applies log-softmax internally to compute the class probabilities.
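As a quick sanity check (a minimal sketch, not from the original post): CrossEntropyLoss applied to raw logits gives the same result as log-softmax followed by negative log-likelihood.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # raw scores for 4 samples, 10 classes
target = torch.tensor([1, 0, 4, 9])  # ground-truth class indices

ce = torch.nn.CrossEntropyLoss()(logits, target)
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(ce, manual)    # identical up to floating point
```

This is why the network below returns its last linear layer without an activation.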

1 prepare dataset

```python
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64
# Preprocessing for the input images:
# transforms.ToTensor() converts a PIL Image or ndarray to a tensor and
# scales it into [0, 1] (dividing by 255), changing the layout HWC -> CHW.
# transforms.Normalize() standardizes each channel: subtract the mean,
# then divide by the standard deviation.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./资料/data/mnist/', train=True,
                               download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='./资料/data/mnist/', train=False,
                              download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
```
```
D:\common_software\Anaconda\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
```

2 design model using class

```python
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # -1 lets PyTorch infer the mini-batch size
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # no activation on the last layer: return raw logits


model = Net()
```
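The x.view(-1, 784) flattening can be checked in isolation; this sketch uses a made-up batch and a single Linear layer standing in for the full network.

```python
import torch

batch = torch.randn(64, 1, 28, 28)       # a hypothetical MNIST mini-batch
flat = batch.view(-1, 784)               # -1 lets PyTorch infer the batch size
assert flat.shape == (64, 784)

logits = torch.nn.Linear(784, 10)(flat)  # maps the 784 features to 10 class scores
assert logits.shape == (64, 10)
```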

3 construct loss and optimizer

```python
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
```

4 training cycle forward, backward, update

```python
# training cycle: forward, backward, update
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # get one batch of inputs and labels
        inputs, target = data
        optimizer.zero_grad()

        # model predictions, shape (64, 10)
        outputs = model(inputs)

        # cross-entropy loss: outputs (64, 10), target (64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # dim=1 reduces across columns, i.e. picks the best class per row
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # elementwise tensor comparison
    print('accuracy on test set: %d %% ' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
```
```
[1,   300] loss: 0.342
[1,   600] loss: 0.300
[1,   900] loss: 0.252
accuracy on test set: 93 %
[2,   300] loss: 0.212
[2,   600] loss: 0.187
[2,   900] loss: 0.165
accuracy on test set: 95 %
[3,   300] loss: 0.139
[3,   600] loss: 0.134
[3,   900] loss: 0.128
accuracy on test set: 95 %
[4,   300] loss: 0.106
[4,   600] loss: 0.099
[4,   900] loss: 0.098
accuracy on test set: 96 %
[5,   300] loss: 0.078
[5,   600] loss: 0.083
[5,   900] loss: 0.082
accuracy on test set: 96 %
[6,   300] loss: 0.062
[6,   600] loss: 0.072
[6,   900] loss: 0.063
accuracy on test set: 97 %
[7,   300] loss: 0.050
[7,   600] loss: 0.056
[7,   900] loss: 0.053
accuracy on test set: 97 %
[8,   300] loss: 0.041
[8,   600] loss: 0.045
[8,   900] loss: 0.043
accuracy on test set: 97 %
[9,   300] loss: 0.033
[9,   600] loss: 0.035
[9,   900] loss: 0.038
accuracy on test set: 97 %
[10,   300] loss: 0.024
[10,   600] loss: 0.030
[10,   900] loss: 0.030
accuracy on test set: 97 %
```
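The torch.max(..., dim=1) call used in test() above reduces along dimension 1, returning both the best score and its index for each row. A small illustration with made-up scores:

```python
import torch

outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [1.5, 0.3,  0.2]])
values, predicted = torch.max(outputs, dim=1)  # per-row max score and its index
assert predicted.tolist() == [1, 0]
assert values.tolist() == [2.0, 1.5]
```

The predicted indices are what get compared against the labels to count correct classifications.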

Reposted from: http://kfali.baihongyu.com/
