PyTorch Part 4: Building and Training a Model
How to Train a Neural Network
First, let's recall how a neural network for classification is trained:
1. Define a network architecture. Typically this is a few convolutional layers, a pooling layer, a few more convolutional layers, another pooling layer, and then several fully connected layers. The last fully connected layer usually has as many neurons as there are classes, and a softmax turns its outputs into per-class probabilities; the class with the highest probability is the network's prediction. After defining the network, initialize its parameters.
2. Choose a loss function. For classification, we usually use the cross-entropy loss.
3. Run a batch of training data through the network (the forward pass) and compute the loss.
4. Backpropagate the loss to compute the gradient of every parameter.
5. Update the network parameters with the rule: weight = weight - learning_rate * gradient (a minimal sketch of this update follows the list).
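As a minimal sketch of step 5, the update rule can be written directly in PyTorch. The tiny model net, the learning rate lr, and the dummy batch here are for illustration only; in practice an optimizer does this for us, as we will see below:

import torch

net = torch.nn.Linear(10, 2)          # a tiny stand-in model
lr = 0.001

x = torch.randn(4, 10)                # dummy batch
loss = net(x).sum()                   # dummy loss
loss.backward()                       # fills p.grad for every parameter

with torch.no_grad():                 # parameter updates must not be tracked by autograd
    for p in net.parameters():
        p -= lr * p.grad              # weight = weight - learning_rate * gradient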
How to Define a Neural Network in PyTorch
PyTorch provides the class torch.nn.Module, which greatly simplifies defining a network architecture and takes care of parameter management, backpropagation, and related chores. We only need to define our own network class that inherits from torch.nn.Module and implements the __init__ and forward methods. In __init__ we define the network's layers; in forward we wire the layers together so that the input propagates forward through them and the network's output is returned.
import torch.nn as nn
import torch.nn.functional as F

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        # Two convolutional layers sharing one 2x2 max-pooling layer.
        self.conv1 = nn.Conv2d(3, 10, 5)      # 3 input channels -> 10 feature maps, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        # Three fully connected layers; the last has 10 outputs, one per class.
        self.fc1 = nn.Linear(20 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(-1, 20 * 5 * 5)         # flatten the feature maps for the fc layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Look at the __init__ method first. Our initializer starts by calling the parent class nn.Module's initializer, and then defines the network's layers: convolutional layer 1, a pooling layer, convolutional layer 2, and three fully connected layers.
Next is the forward method, which defines the forward pass. It takes an input x, applies the first convolution, a ReLU activation, and pooling, then a second convolution, ReLU, and pooling. Before feeding the result into the fully connected layers, we reshape the tensor so that each sample's feature maps are flattened into a single vector.
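To see where the 20*5*5 in fc1 comes from, here is a quick shape check for a single 3x32x32 CIFAR-10 image. The dummy input and the prints are our own addition; MyNetWork is the class defined above:

import torch
import torch.nn.functional as F

net = MyNetWork()
x = torch.randn(1, 3, 32, 32)          # one dummy CIFAR-10 image
x = net.pool(F.relu(net.conv1(x)))
print(x.shape)                         # torch.Size([1, 10, 14, 14])
x = net.pool(F.relu(net.conv2(x)))
print(x.shape)                         # torch.Size([1, 20, 5, 5]) -> 20*5*5 = 500 features per image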
Training the Network
A few more pieces are needed before we can train the network.
Defining the loss function
We use the cross-entropy loss:
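criterion = nn.CrossEntropyLoss()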
Defining the optimizer
First we create an instance of the network we defined. Because it inherits from torch.nn.Module, we can get its parameters directly via parameters() and hand them to a stochastic gradient descent optimizer, which updates them in every iteration using the gradients computed by backpropagation. We only need to specify the learning rate and other settings when constructing the optimizer.
network = MyNetWork()
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)
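For example, we can list the parameters that nn.Module tracks for us; the shape in the comment is for conv1 in the network above:

for name, p in network.named_parameters():
    print(name, tuple(p.shape))        # e.g. conv1.weight (10, 3, 5, 5)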
Starting the training
We train for 5 epochs; the batch size is set when building the data loader.
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(train_data_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()                 # clear gradients from the previous step
        outputs = network(inputs)             # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                       # backpropagate to compute gradients
        running_loss += loss.item()
        optimizer.step()                      # update the parameters
        if i % 20 == 19:                      # print the average loss every 20 batches
            print('epoch:%d,batch:%5d,loss:%.5f' % (epoch+1, i+1, running_loss/20))
            running_loss = 0.0
print('Finished Training.')
Training on the GPU
We only need to move the network and the input data to the GPU:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network.to(device)                            # move the model's parameters to the GPU
inputs, labels = data[0].to(device), data[1].to(device)   # move each batch as well
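A quick, optional way to confirm where the model ended up is to check the device of any of its parameters:

print(next(network.parameters()).device)   # cuda:0 if a GPU was found, otherwise cpu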
Full Code
import torch
import torchvision as tv
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64

# Convert images to tensors and normalize each channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_data_path = "E:/data/cifar10/train/"
train_data = tv.datasets.ImageFolder(train_data_path, transform=transform)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

test_data_path = "E:/data/cifar10/test/"
test_data = tv.datasets.ImageFolder(test_data_path, transform=transform)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False)

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(-1, 20 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

network = MyNetWork()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)

for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(train_data_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = network(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
        if i % 20 == 19:
            print('epoch:%d,batch:%5d,loss:%.5f' % (epoch+1, i+1, running_loss/20))
            running_loss = 0.0

print('Finished Training.')
Output:
epoch:1,batch:   20,loss:2.30407
epoch:1,batch:   40,loss:2.30514
epoch:1,batch:   60,loss:2.30424
...
epoch:1,batch:  780,loss:2.29972
epoch:2,batch:   20,loss:2.29991
...
epoch:2,batch:  780,loss:2.23898
epoch:3,batch:   20,loss:2.22724
...
epoch:3,batch:  780,loss:1.89553
epoch:4,batch:   20,loss:1.85242
...
epoch:4,batch:  780,loss:1.69597
epoch:5,batch:   20,loss:1.70627
...
epoch:5,batch:  780,loss:1.59199
Finished Training.
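test_data_loader is created in the full code but not used in this part. As a sketch (the accuracy bookkeeping here is our own addition, not code from this article), the trained model could be evaluated on the test set like this:

network.eval()                                  # evaluation mode (a no-op for this model, but good practice)
correct, total = 0, 0
with torch.no_grad():                           # no gradients needed for evaluation
    for inputs, labels in test_data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = network(inputs)
        predicted = outputs.argmax(dim=1)       # class with the highest score
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print('Test accuracy: %.2f%%' % (100.0 * correct / total))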