标签搜索

目 录CONTENT

文章目录

手写一个简单的卷积神经网络执行分类任务

陈铭
2021-10-24 / 0 评论 / 0 点赞 / 108 阅读 / 1,502 字 / 正在检测是否收录...

思路

使用卷积操作逐层提取图片的特征,在生成大量的特征图后将其展平并接入全连接层,输出各个类别的得分,取得分最高的类别作为分类结果。形式上类似于一个回归问题:网络直接回归出每个类别的得分。

网络部分

网络部分参考了ResNet,通过残差(短接)连接缓解训练时的梯度弥散(梯度消失)。不过我的网络结构并不深,一般也不会出现梯度弥散。

import torch
from torch import nn
from torch.nn import functional


class ResBlock_2(nn.Module):
    """
    Residual block that halves the spatial size.

    Main path:     conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> 2x2 max-pool
    Shortcut path: reshapes the input so it can be added to the main path
    Output:        ReLU(shortcut + main)

    Shape: [b, inputChannel, h, w] => [b, outputChannel, h // 2, w // 2]
    (h and w must be at least 2).
    """

    # inputChannel: number of channels of the incoming feature map
    # outputChannel: number of channels produced by the block
    def __init__(self, inputChannel, outputChannel):
        super(ResBlock_2, self).__init__()

        # [b, inputChannel, h, w] => [b, outputChannel, h, w]
        self.conv1 = nn.Conv2d(inputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(outputChannel)

        # [b, outputChannel, h, w] => [b, outputChannel, h, w]
        self.conv2 = nn.Conv2d(outputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outputChannel)

        # [b, outputChannel, h, w] => [b, outputChannel, h // 2, w // 2]
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Shortcut branch. It must always end up with the same shape as the
        # pooled main path, otherwise the addition in forward() fails (or
        # silently broadcasts for very small feature maps).
        if outputChannel != inputChannel:
            # [b, inputChannel, h, w] => [b, outputChannel, h // 2, w // 2]
            self.extra = nn.Sequential(
                nn.Conv2d(inputChannel, outputChannel, kernel_size=2, stride=2),
                nn.BatchNorm2d(outputChannel)
            )
        else:
            # Channels already match: only downsample, with no parameters.
            self.extra = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2)
            )

    def forward(self, x):
        # Main path
        out = functional.relu(self.bn1(self.conv1(x)))
        out = self.maxPool(self.bn2(self.conv2(out)))

        # Add the shortcut, then apply the final activation
        out = self.extra(x) + out
        out = functional.relu(out)
        return out


class ResBlock_5(nn.Module):
    """
    Residual block that divides the spatial size by 5.

    Main path:     conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> 5x5 max-pool
    Shortcut path: reshapes the input so it can be added to the main path
    Output:        ReLU(shortcut + main)

    Shape: [b, inputChannel, h, w] => [b, outputChannel, h // 5, w // 5]
    (h and w must be at least 5). A 5x5 input collapses to 1x1.
    """

    # inputChannel: number of channels of the incoming feature map
    # outputChannel: number of channels produced by the block
    def __init__(self, inputChannel, outputChannel):
        super(ResBlock_5, self).__init__()

        # [b, inputChannel, h, w] => [b, outputChannel, h, w]
        self.conv1 = nn.Conv2d(inputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(outputChannel)

        # [b, outputChannel, h, w] => [b, outputChannel, h, w]
        self.conv2 = nn.Conv2d(outputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outputChannel)

        # [b, outputChannel, h, w] => [b, outputChannel, h // 5, w // 5]
        # The stride equals the kernel size so the block really divides the
        # size by 5; with stride 1 it would only shrink h and w by 4. For a
        # 5x5 input both settings give the same 1x1 result.
        self.maxPool = nn.MaxPool2d(kernel_size=5, stride=5)

        # Shortcut branch. It must always end up with the same shape as the
        # pooled main path, otherwise the addition in forward() fails (or
        # silently broadcasts when the main path is 1x1).
        if outputChannel != inputChannel:
            # [b, inputChannel, h, w] => [b, outputChannel, h // 5, w // 5]
            self.extra = nn.Sequential(
                nn.Conv2d(inputChannel, outputChannel, kernel_size=5, stride=5),
                nn.BatchNorm2d(outputChannel)
            )
        else:
            # Channels already match: only downsample, with no parameters.
            self.extra = nn.Sequential(
                nn.MaxPool2d(kernel_size=5, stride=5)
            )

    def forward(self, x):
        # Main path
        out = functional.relu(self.bn1(self.conv1(x)))
        out = self.maxPool(self.bn2(self.conv2(out)))

        # Add the shortcut, then apply the final activation
        out = self.extra(x) + out
        out = functional.relu(out)
        return out


class ResNet(nn.Module):
    """
    Small ResNet-style classifier: a conv stem, five ResBlock_2 stages, one
    ResBlock_5 stage and a single linear layer producing ``classNum`` logits.

    The linear layer takes 1024 features, so the last stage has to end at
    1x1. The shape comments below follow a [b, 3, 160, 160] input.
    """

    def __init__(self, classNum):
        super(ResNet, self).__init__()

        # Stem: [b, 3, 160, 160] => [b, 16, 160, 160]
        stem = [
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
        ]
        self.conv1 = nn.Sequential(*stem)

        # Five halving stages, registered as blk1 .. blk5:
        # [b, 16, 160, 160] => [b, 32, 80, 80] => [b, 64, 40, 40]
        # => [b, 128, 20, 20] => [b, 256, 10, 10] => [b, 512, 5, 5]
        widths = (16, 32, 64, 128, 256, 512)
        for index, (width_in, width_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, "blk%d" % index, ResBlock_2(width_in, width_out))

        # [b, 512, 5, 5] => [b, 1024, 1, 1]
        self.blk6 = ResBlock_5(512, 1024)

        # [b, 1024] => [b, classNum]; raw logits, no activation
        self.fc = nn.Sequential(
            nn.Linear(1024 * 1 * 1, classNum),
        )

    def forward(self, x):
        features = functional.relu(self.conv1(x))

        # Each halving stage is followed by channel-wise dropout, which is
        # only active while the module is in training mode.
        halving_stages = (self.blk1, self.blk2, self.blk3, self.blk4, self.blk5)
        for stage in halving_stages:
            features = stage(features)
            features = functional.dropout2d(features, p=0.1, training=self.training)

        features = self.blk6(features)

        # Flatten everything but the batch dimension
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

训练脚本

训练脚本里用了visdom,它是一个很方便的loss可视化工具,具体用法见《pytorch实时可视化数据工具visdom》。

import torch
from torch import nn
import torch.optim as optim

from fire_split.fireData import FireData

from fire_split.utils.visdomUtil import *
from resNet import ResNet
from    torch.nn import functional
from    torch.utils.data import DataLoader
from    torchvision import datasets
from    torchvision import transforms
from flatten import Flatten


def test(testLoader,gpuDevice,resNet):
    """
    Evaluate ``resNet`` on ``testLoader`` and return the top-1 accuracy.

    testLoader: iterable yielding ``(inputs, targets)`` batches; the
                prediction is the argmax over dim 1 of the model output and
                is compared element-wise with ``targets``
    gpuDevice:  device every batch is moved to before the forward pass
    resNet:     model to evaluate; it is left in eval mode on return

    Returns the fraction of correctly classified samples, or 0.0 when the
    loader yields no samples. The accuracy is also printed.
    """
    correctNum = 0
    totalNum = 0

    # Switch to eval mode once, not once per batch.
    resNet.eval()

    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for images, target in testLoader:
            images, target = images.to(gpuDevice), target.to(gpuDevice)

            pred = resNet(images).argmax(dim=1)

            correctNum += torch.eq(pred, target).sum().item()
            totalNum += images.size(0)

    # An empty loader would otherwise raise ZeroDivisionError.
    accurancyRate = correctNum / totalNum if totalNum else 0.0
    print("测试集准确率:",accurancyRate)

    return accurancyRate


# Despite its name this is a helper, not a test case: opt out of pytest
# collection in case this module is ever imported from a test module.
test.__test__ = False

def main():
    """
    Train ResNet on the fire data set and keep the best-scoring model.

    Each epoch runs one pass over the training loader and then evaluates on
    the validation loader with ``test``. Whenever the validation accuracy
    improves, the whole model is saved to ``bestModel.pth``. The per-batch
    loss and per-epoch accuracy are sent to ``createLine`` / ``drawLine``
    (imported from fire_split.utils.visdomUtil; presumably visdom line plots).
    """
    # Plot names passed to createLine / drawLine
    trainLoss = "trainLoss"
    testAccurancy = "testAccurancy"

    trainFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\train\\160 x 160"
    valFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\val\\160 x 160"
    classNum = 2
    resize = 160
    # Number of images per training batch
    batchSize = 16
    epochs = 1000

    train_db = FireData(trainFirePath, resize, mode='train')
    # NOTE(review): the validation set is also built with mode='train';
    # confirm against FireData that this mode applies no train-only
    # processing (e.g. augmentation).
    val_db = FireData(valFirePath, resize, mode='train')
    train_loader = DataLoader(train_db, batch_size=batchSize, shuffle=True,
                              num_workers=0)
    # Accuracy does not depend on sample order, so validation is not shuffled.
    val_loader = DataLoader(val_db, batch_size=batchSize, num_workers=0, shuffle=False)

    # Create the plots
    createLine(trainLoss)
    createLine(testAccurancy)

    # Use the GPU when there is one, otherwise fall back to the CPU instead
    # of crashing on the first .to("cuda").
    gpuDevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    resNet = ResNet(classNum).to(gpuDevice)

    optimizer = optim.Adam(resNet.parameters(), lr=0.001)
    criteon = nn.CrossEntropyLoss().to(gpuDevice)

    # Global step counter, used as the x axis of the loss plot
    trainCount = 0
    bestAccurancyRate = 0

    for epoch in range(epochs):
        # test() leaves the model in eval mode, so re-enable training mode
        # at the start of every epoch.
        resNet.train()
        for images, target in train_loader:
            images = images.to(gpuDevice)
            target = target.to(gpuDevice)

            logits = resNet(images)
            loss = criteon(logits, target)

            drawLine(trainLoss, x=trainCount, y=loss.item())

            optimizer.zero_grad()
            loss.backward()
            # Guard against exploding gradients: rescale so the norm of all
            # gradients together is at most 10. One call over all parameters
            # with the in-place clip_grad_norm_ replaces the deprecated
            # clip_grad_norm, which was called once per parameter.
            torch.nn.utils.clip_grad_norm_(resNet.parameters(), 10)
            optimizer.step()

            trainCount += 1

        # Validate after every epoch and keep the best model so far
        accurancyRate = test(val_loader, gpuDevice, resNet)
        drawLine(testAccurancy, x=epoch, y=accurancyRate)
        if accurancyRate > bestAccurancyRate:
            bestAccurancyRate = accurancyRate
            # The whole module is pickled (not only its state_dict), so
            # loading the file needs the ResNet class to be importable.
            torch.save(resNet, "bestModel.pth")


if __name__ == '__main__':
    main()

测试训练模型

也用了visdom,可视化分类的结果

import time

import torch
from torch.utils.data import DataLoader
import visdom
from fire_split.fireData import FireData
from fire_split.utils.imgUtil import denormalize

# Visualisation script: run the saved model over a data set and show every
# batch of images together with the predicted class indices in visdom.

# Named `viz` so the instance does not shadow the imported `visdom` module.
viz = visdom.Visdom()
# Use the GPU when there is one, otherwise fall back to the CPU.
gpuDevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpuDevice = torch.device("cpu")

# The file holds a whole pickled module, so the ResNet class must be
# importable here. map_location lets a model saved on a GPU load on a
# CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects whole-module pickles. Pass
# weights_only=False there, and only for files you trust.
resNet = torch.load("bestModel_myBest.pth", map_location=gpuDevice)
resNet.to(gpuDevice)
# Inference only: switch to eval mode once, up front.
resNet.eval()

resize=160
batchSize=16

# Point this at another directory to visualise a different split.
testFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\val\\160 x 160"

test_db = FireData(testFirePath, resize, mode='test')
test_loader = DataLoader(test_db, batch_size=batchSize, num_workers=0,shuffle=True)

# No gradients are needed for inference.
with torch.no_grad():
    for images, target in test_loader:
        images = images.to(gpuDevice)

        pred = resNet(images).argmax(dim=1)

        # Show the batch and its predicted class indices, reusing the same
        # two windows for every batch.
        viz.images(images, nrow=4, win='img', opts=dict(title='img'))
        viz.text(str(pred), win='label', opts=dict(title='label'))

        # Short pause so each batch stays visible before it is replaced
        time.sleep(0.1)

0

评论区