How to Train a Neural Network

Let's first recall how we train a neural network for classification:
1. Define a network architecture. Typically this is a few convolutional layers, a pooling layer, a few more convolutional layers, another pooling layer, and then several fully connected layers. The last fully connected layer usually has as many neurons as there are classes, and a softmax produces a probability for each class; the class with the highest probability is the network's final prediction. After defining the network, initialize its parameters.
2. Choose a loss function. For classification, we usually pick the cross-entropy loss.
3. Run a batch of training data through a forward pass and compute the loss.
4. Backpropagate the loss to compute the gradient of every parameter.
5. Update the network parameters. The update rule is: weight = weight - learning_rate * gradient (see the sketch after this list).
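
As a minimal sketch of step 5, this is what plain SGD looks like written by hand in PyTorch (network and learning_rate are placeholders here; it assumes loss.backward() has already filled the gradients, and in practice an optimizer performs this update for us):

import torch

learning_rate = 0.001
with torch.no_grad():                        # parameter updates should not be tracked by autograd
    for param in network.parameters():       # 'network' is a placeholder model
        param -= learning_rate * param.grad  # step against the gradient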

How to Define a Neural Network in PyTorch

PyTorch provides the class torch.nn.Module, which greatly simplifies defining a network architecture, managing its parameters, backpropagation, and so on. We only need to define our own network class that inherits from torch.nn.Module and implements the __init__ and forward methods. In __init__ we define the network's layers; in forward we wire the layers together, pass the input through them, and return the network's output.


import torch.nn as nn
import torch.nn.functional as F

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)       # 3 input channels, 10 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 max pooling, reused after each conv
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 5 * 5, 120)  # 20 channels of 5x5 feature maps, flattened
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 output neurons, one per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(-1, 20 * 5 * 5)          # flatten to (batch, features) for the fc layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                        # raw class scores (logits)
        return x

First, look at the __init__ method. In our initializer, we first call the parent class nn.Module's initializer, and then define the network's layers: convolutional layer 1, a pooling layer, convolutional layer 2, and three fully connected layers.

Next, look at the forward method, which defines the forward pass. It takes an input x, applies the first convolution, a ReLU activation, then pooling; then a second round of convolution, ReLU, and pooling. Before feeding the result into the fully connected layers, we first flatten the tensor into shape (batch, features).
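
As a quick sanity check of these shapes, we can push a dummy batch through the network (this assumes CIFAR-10-sized 3x32x32 inputs, for which the 20*5*5 flattened size works out: 32 -> 28 -> 14 -> 10 -> 5 through the two conv/pool stages):

import torch

net = MyNetWork()
dummy = torch.randn(4, 3, 32, 32)  # a fake batch of 4 RGB 32x32 images
out = net(dummy)
print(out.shape)                   # torch.Size([4, 10]): one score per class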

Training the Network

There are a few more things we need to do before we can train the network.

Defining the Loss Function

We use the cross-entropy loss function:

criterion = nn.CrossEntropyLoss()
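
Note that nn.CrossEntropyLoss applies log-softmax internally, which is why the forward method above returns raw scores (logits) rather than probabilities. A small illustration with made-up values:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])  # hypothetical raw scores for 3 classes
target = torch.tensor([0])                 # index of the true class
print(criterion(logits, target))           # roughly tensor(0.2417)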

Defining the Optimizer

First, we instantiate the network we defined. Because it inherits from torch.nn.Module, we can get all of its parameters directly via parameters() and hand them to a stochastic gradient descent optimizer, which updates them on every iteration once backpropagation has computed the gradients. We only need to specify the learning rate and other hyperparameters when constructing the optimizer.

network = MyNetWork()
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)
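
Because MyNetWork inherits from nn.Module, every layer we assigned in __init__ is registered automatically. To see what parameters() exposes, we can list the parameter tensors by name (a quick sanity check):

for name, p in network.named_parameters():
    print(name, tuple(p.shape))
# conv1.weight (10, 3, 5, 5)
# conv1.bias   (10,)
# ... followed by conv2, fc1, fc2 and fc3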

Starting the Training

We train for 5 epochs; the batch size is set in the data loader.

for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(train_data_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()                 # clear gradients left over from the previous batch
        outputs = network(inputs)             # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                       # backpropagate to compute gradients
        running_loss += loss.item()
        optimizer.step()                      # update the parameters
        if i % 20 == 19:                      # report the average loss every 20 batches
            print('epoch:%d,batch:%5d,loss:%.5f' % (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0
print('Finished Training.')
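
Once training finishes, we will want to measure accuracy on held-out data. A minimal evaluation sketch (using the test_data_loader that the complete code below defines, with gradients disabled):

network.eval()                                 # switch layers such as dropout to eval mode
correct, total = 0, 0
with torch.no_grad():                          # no gradients needed for evaluation
    for images, labels in test_data_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = network(images)
        _, predicted = torch.max(outputs, 1)   # index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy: %.2f%%' % (100.0 * correct / total))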

Training on the GPU

We just need to move the network and the input data to the GPU:

network = MyNetWork()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network.to(device)

inputs, labels = data[0].to(device), data[1].to(device)
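
To confirm where everything ended up, a quick sanity check is to print the chosen device and the device of one of the network's parameters:

print(device)                              # cuda:0 if a GPU is available, otherwise cpu
print(next(network.parameters()).device)   # the parameters move along with the module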

Complete Code


import torch
import torchvision as tv
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_data_path = "E:/data/cifar10/train/"
train_data = tv.datasets.ImageFolder(train_data_path, transform=transform)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

test_data_path = "E:/data/cifar10/test/"
test_data = tv.datasets.ImageFolder(test_data_path, transform=transform)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False)

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(-1, 20 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

network = MyNetWork()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)

for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(train_data_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = network(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
        if i % 20 == 19:
            print('epoch:%d,batch:%5d,loss:%.5f' % (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0
print('Finished Training.')

Output (note that the loss hovers around ln(10) ≈ 2.30, the cross-entropy of a uniform guess over 10 classes, for the whole first epoch before it starts to fall):

epoch:1,batch:   20,loss:2.30407
epoch:1,batch:   40,loss:2.30514
epoch:1,batch:   60,loss:2.30424
epoch:1,batch:   80,loss:2.30183
epoch:1,batch:  100,loss:2.30199
epoch:1,batch:  120,loss:2.30430
epoch:1,batch:  140,loss:2.30368
epoch:1,batch:  160,loss:2.30484
epoch:1,batch:  180,loss:2.30222
epoch:1,batch:  200,loss:2.30332
epoch:1,batch:  220,loss:2.30390
epoch:1,batch:  240,loss:2.30387
epoch:1,batch:  260,loss:2.30224
epoch:1,batch:  280,loss:2.30129
epoch:1,batch:  300,loss:2.30354
epoch:1,batch:  320,loss:2.30145
epoch:1,batch:  340,loss:2.30055
epoch:1,batch:  360,loss:2.30392
epoch:1,batch:  380,loss:2.30170
epoch:1,batch:  400,loss:2.30211
epoch:1,batch:  420,loss:2.29953
epoch:1,batch:  440,loss:2.30088
epoch:1,batch:  460,loss:2.30293
epoch:1,batch:  480,loss:2.30295
epoch:1,batch:  500,loss:2.30138
epoch:1,batch:  520,loss:2.30210
epoch:1,batch:  540,loss:2.30024
epoch:1,batch:  560,loss:2.30039
epoch:1,batch:  580,loss:2.30072
epoch:1,batch:  600,loss:2.30042
epoch:1,batch:  620,loss:2.30036
epoch:1,batch:  640,loss:2.30065
epoch:1,batch:  660,loss:2.29986
epoch:1,batch:  680,loss:2.30053
epoch:1,batch:  700,loss:2.30006
epoch:1,batch:  720,loss:2.29997
epoch:1,batch:  740,loss:2.30031
epoch:1,batch:  760,loss:2.29934
epoch:1,batch:  780,loss:2.29972
epoch:2,batch:   20,loss:2.29991
epoch:2,batch:   40,loss:2.29845
epoch:2,batch:   60,loss:2.29885
epoch:2,batch:   80,loss:2.29831
epoch:2,batch:  100,loss:2.29750
epoch:2,batch:  120,loss:2.29753
epoch:2,batch:  140,loss:2.29753
epoch:2,batch:  160,loss:2.29778
epoch:2,batch:  180,loss:2.29660
epoch:2,batch:  200,loss:2.29734
epoch:2,batch:  220,loss:2.29647
epoch:2,batch:  240,loss:2.29604
epoch:2,batch:  260,loss:2.29590
epoch:2,batch:  280,loss:2.29516
epoch:2,batch:  300,loss:2.29531
epoch:2,batch:  320,loss:2.29341
epoch:2,batch:  340,loss:2.29357
epoch:2,batch:  360,loss:2.29263
epoch:2,batch:  380,loss:2.29335
epoch:2,batch:  400,loss:2.29060
epoch:2,batch:  420,loss:2.29001
epoch:2,batch:  440,loss:2.28912
epoch:2,batch:  460,loss:2.28862
epoch:2,batch:  480,loss:2.28826
epoch:2,batch:  500,loss:2.28609
epoch:2,batch:  520,loss:2.28272
epoch:2,batch:  540,loss:2.28185
epoch:2,batch:  560,loss:2.28122
epoch:2,batch:  580,loss:2.27785
epoch:2,batch:  600,loss:2.27587
epoch:2,batch:  620,loss:2.27441
epoch:2,batch:  640,loss:2.27153
epoch:2,batch:  660,loss:2.26470
epoch:2,batch:  680,loss:2.26031
epoch:2,batch:  700,loss:2.25865
epoch:2,batch:  720,loss:2.25509
epoch:2,batch:  740,loss:2.24848
epoch:2,batch:  760,loss:2.24276
epoch:2,batch:  780,loss:2.23898
epoch:3,batch:   20,loss:2.22724
epoch:3,batch:   40,loss:2.22337
epoch:3,batch:   60,loss:2.20429
epoch:3,batch:   80,loss:2.20147
epoch:3,batch:  100,loss:2.18487
epoch:3,batch:  120,loss:2.16920
epoch:3,batch:  140,loss:2.16260
epoch:3,batch:  160,loss:2.14475
epoch:3,batch:  180,loss:2.12718
epoch:3,batch:  200,loss:2.11580
epoch:3,batch:  220,loss:2.10174
epoch:3,batch:  240,loss:2.07965
epoch:3,batch:  260,loss:2.07347
epoch:3,batch:  280,loss:2.06264
epoch:3,batch:  300,loss:2.02216
epoch:3,batch:  320,loss:2.03486
epoch:3,batch:  340,loss:2.00265
epoch:3,batch:  360,loss:2.00933
epoch:3,batch:  380,loss:2.00810
epoch:3,batch:  400,loss:2.01137
epoch:3,batch:  420,loss:2.01366
epoch:3,batch:  440,loss:1.99376
epoch:3,batch:  460,loss:2.00583
epoch:3,batch:  480,loss:1.97421
epoch:3,batch:  500,loss:1.95560
epoch:3,batch:  520,loss:1.95393
epoch:3,batch:  540,loss:1.95929
epoch:3,batch:  560,loss:1.93034
epoch:3,batch:  580,loss:1.92830
epoch:3,batch:  600,loss:1.94590
epoch:3,batch:  620,loss:1.93020
epoch:3,batch:  640,loss:1.92495
epoch:3,batch:  660,loss:1.92019
epoch:3,batch:  680,loss:1.90083
epoch:3,batch:  700,loss:1.90906
epoch:3,batch:  720,loss:1.89305
epoch:3,batch:  740,loss:1.91752
epoch:3,batch:  760,loss:1.91681
epoch:3,batch:  780,loss:1.89553
epoch:4,batch:   20,loss:1.85242
epoch:4,batch:   40,loss:1.87267
epoch:4,batch:   60,loss:1.86945
epoch:4,batch:   80,loss:1.85676
epoch:4,batch:  100,loss:1.88468
epoch:4,batch:  120,loss:1.82760
epoch:4,batch:  140,loss:1.85370
epoch:4,batch:  160,loss:1.86612
epoch:4,batch:  180,loss:1.82057
epoch:4,batch:  200,loss:1.88566
epoch:4,batch:  220,loss:1.85049
epoch:4,batch:  240,loss:1.85989
epoch:4,batch:  260,loss:1.80388
epoch:4,batch:  280,loss:1.80448
epoch:4,batch:  300,loss:1.79293
epoch:4,batch:  320,loss:1.82317
epoch:4,batch:  340,loss:1.80807
epoch:4,batch:  360,loss:1.76351
epoch:4,batch:  380,loss:1.77514
epoch:4,batch:  400,loss:1.81420
epoch:4,batch:  420,loss:1.80197
epoch:4,batch:  440,loss:1.80661
epoch:4,batch:  460,loss:1.78744
epoch:4,batch:  480,loss:1.81156
epoch:4,batch:  500,loss:1.75178
epoch:4,batch:  520,loss:1.75818
epoch:4,batch:  540,loss:1.79016
epoch:4,batch:  560,loss:1.77460
epoch:4,batch:  580,loss:1.76466
epoch:4,batch:  600,loss:1.73503
epoch:4,batch:  620,loss:1.72737
epoch:4,batch:  640,loss:1.69971
epoch:4,batch:  660,loss:1.75526
epoch:4,batch:  680,loss:1.67342
epoch:4,batch:  700,loss:1.73263
epoch:4,batch:  720,loss:1.70854
epoch:4,batch:  740,loss:1.73153
epoch:4,batch:  760,loss:1.69152
epoch:4,batch:  780,loss:1.69597
epoch:5,batch:   20,loss:1.70627
epoch:5,batch:   40,loss:1.73348
epoch:5,batch:   60,loss:1.66977
epoch:5,batch:   80,loss:1.71927
epoch:5,batch:  100,loss:1.70653
epoch:5,batch:  120,loss:1.69865
epoch:5,batch:  140,loss:1.70063
epoch:5,batch:  160,loss:1.66975
epoch:5,batch:  180,loss:1.64562
epoch:5,batch:  200,loss:1.61968
epoch:5,batch:  220,loss:1.65880
epoch:5,batch:  240,loss:1.66435
epoch:5,batch:  260,loss:1.61658
epoch:5,batch:  280,loss:1.65945
epoch:5,batch:  300,loss:1.67617
epoch:5,batch:  320,loss:1.65040
epoch:5,batch:  340,loss:1.61006
epoch:5,batch:  360,loss:1.65418
epoch:5,batch:  380,loss:1.67202
epoch:5,batch:  400,loss:1.68170
epoch:5,batch:  420,loss:1.61025
epoch:5,batch:  440,loss:1.64626
epoch:5,batch:  460,loss:1.57466
epoch:5,batch:  480,loss:1.60090
epoch:5,batch:  500,loss:1.63489
epoch:5,batch:  520,loss:1.61065
epoch:5,batch:  540,loss:1.62483
epoch:5,batch:  560,loss:1.60688
epoch:5,batch:  580,loss:1.64252
epoch:5,batch:  600,loss:1.62119
epoch:5,batch:  620,loss:1.57755
epoch:5,batch:  640,loss:1.61463
epoch:5,batch:  660,loss:1.57873
epoch:5,batch:  680,loss:1.59183
epoch:5,batch:  700,loss:1.60572
epoch:5,batch:  720,loss:1.63272
epoch:5,batch:  740,loss:1.57111
epoch:5,batch:  760,loss:1.59886
epoch:5,batch:  780,loss:1.59199
Finished Training.
