思路
使用卷积操作不断提取图片的特征,在生成大量的特征图后接全连接层,输出分类结果。形式上与回归问题类似,只是网络输出的是各类别的得分(logits),并使用交叉熵损失进行训练。
网络部分
网络部分参考了 ResNet 的残差结构,以避免训练时的梯度弥散(梯度消失)。不过我的网络并不深,一般也不会出现梯度弥散。
import torch
from torch import nn
from torch.nn import functional
class ResBlock_2(nn.Module):
    """Residual block that halves the spatial resolution.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> 2x2 max-pool.
    Shortcut path: down-sampled by the same factor so both paths can be
    summed; a final ReLU is applied to the sum.

    Shape: ``[b, inputChannel, h, w] -> [b, outputChannel, h // 2, w // 2]``.

    Args:
        inputChannel: number of channels of the incoming feature map.
        outputChannel: number of channels produced by the block.
    """

    def __init__(self, inputChannel, outputChannel):
        super(ResBlock_2, self).__init__()
        # [b, inputChannel, h, w] => [b, outputChannel, h, w]
        self.conv1 = nn.Conv2d(inputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(outputChannel)
        # [b, outputChannel, h, w] => [b, outputChannel, h, w]
        self.conv2 = nn.Conv2d(outputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outputChannel)
        # [b, outputChannel, h, w] => [b, outputChannel, h/2, w/2]
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Shortcut branch. It must always shrink the map exactly like the
        # main path, otherwise the residual addition fails on a shape mismatch.
        if outputChannel != inputChannel:
            # Learned projection: changes the channel count and halves h, w.
            self.extra = nn.Sequential(
                nn.Conv2d(inputChannel, outputChannel, kernel_size=2, stride=2),
                nn.BatchNorm2d(outputChannel),
            )
        else:
            # Channels already match: a parameter-free pooling is enough
            # (a plain identity would keep h x w and break the addition).
            self.extra = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        # Main (residual) path.
        out = functional.relu(self.bn1(self.conv1(x)))
        out = self.maxPool(self.bn2(self.conv2(out)))
        # Add the down-sampled shortcut, then apply the final activation.
        out = self.extra(x) + out
        out = functional.relu(out)
        return out
class ResBlock_5(nn.Module):
    """Residual block that shrinks the spatial resolution by a factor of 5.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> 5x5 max-pool.
    Shortcut path: down-sampled by the same factor so both paths can be
    summed; a final ReLU is applied to the sum.

    Shape: ``[b, inputChannel, h, w] -> [b, outputChannel, h // 5, w // 5]``
    (a 5x5 feature map collapses to 1x1).

    Args:
        inputChannel: number of channels of the incoming feature map.
        outputChannel: number of channels produced by the block.
    """

    def __init__(self, inputChannel, outputChannel):
        super(ResBlock_5, self).__init__()
        # [b, inputChannel, h, w] => [b, outputChannel, h, w]
        self.conv1 = nn.Conv2d(inputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(outputChannel)
        # [b, outputChannel, h, w] => [b, outputChannel, h, w]
        self.conv2 = nn.Conv2d(outputChannel, outputChannel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outputChannel)
        # [b, outputChannel, h, w] => [b, outputChannel, h/5, w/5]
        # stride must equal the kernel size to get the documented /5
        # reduction; with stride=1 the map would only shrink by 4 pixels.
        # For a 5x5 input both settings yield the same 1x1 output.
        self.maxPool = nn.MaxPool2d(kernel_size=5, stride=5)

        # Shortcut branch. It must always shrink the map exactly like the
        # main path, otherwise the residual addition fails on a shape mismatch.
        if outputChannel != inputChannel:
            # Learned projection: changes the channel count and divides h, w by 5.
            self.extra = nn.Sequential(
                nn.Conv2d(inputChannel, outputChannel, kernel_size=5, stride=5),
                nn.BatchNorm2d(outputChannel),
            )
        else:
            # Channels already match: a parameter-free pooling is enough
            # (a plain identity would keep h x w and break the addition).
            self.extra = nn.Sequential(
                nn.MaxPool2d(kernel_size=5, stride=5),
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        # Main (residual) path.
        out = functional.relu(self.bn1(self.conv1(x)))
        out = self.maxPool(self.bn2(self.conv2(out)))
        # Add the down-sampled shortcut, then apply the final activation.
        out = self.extra(x) + out
        out = functional.relu(out)
        return out
class ResNet(nn.Module):
    """Small ResNet-style image classifier.

    Layout: a 3->16 conv stem, five ``ResBlock_2`` stages (16->32->64->128->
    256->512 channels, each followed by ``dropout2d`` with p=0.1), one
    ``ResBlock_5`` stage (512->1024) and a single linear layer.

    The linear layer takes exactly 1024 features, so the map leaving the last
    stage has to be 1x1; per the shape notes in this file that corresponds to
    ``[b, 3, 160, 160]`` inputs.

    Args:
        classNum: number of output classes (size of the returned logits).
    """

    def __init__(self, classNum):
        super(ResNet, self).__init__()
        # Stem: [b, 3, 160, 160] => [b, 16, 160, 160]
        stem_conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        stem_norm = nn.BatchNorm2d(16)
        self.conv1 = nn.Sequential(stem_conv, stem_norm)

        # Five halving stages registered as blk1 .. blk5:
        #   [b, 16, 160, 160] => [b, 32, 80, 80] => [b, 64, 40, 40]
        #   => [b, 128, 20, 20] => [b, 256, 10, 10] => [b, 512, 5, 5]
        widths = (16, 32, 64, 128, 256, 512)
        for index, (cin, cout) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "blk{}".format(index), ResBlock_2(cin, cout))

        # Last stage: [b, 512, 5, 5] => [b, 1024, 1, 1]
        self.blk6 = ResBlock_5(512, 1024)

        # Classifier head: [b, 1024] => [b, classNum] (raw logits, no activation)
        self.fc = nn.Sequential(
            nn.Linear(1024, classNum),
        )

    def forward(self, x):
        """Map an image batch to class logits of shape ``[b, classNum]``."""
        features = functional.relu(self.conv1(x))

        # Each halving stage is followed by channel-wise dropout, which is
        # only active while the module is in training mode.
        halving_stages = (self.blk1, self.blk2, self.blk3, self.blk4, self.blk5)
        for stage in halving_stages:
            features = stage(features)
            features = functional.dropout2d(features, p=0.1, training=self.training)

        features = self.blk6(features)

        # Flatten everything but the batch dimension before the linear head.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
训练脚本
训练脚本里用到了 visdom,它是一个很方便的 loss 可视化工具,具体用法见《pytorch实时可视化数据工具visdom》。
import torch
from torch import nn
import torch.optim as optim
from fire_split.fireData import FireData
from fire_split.utils.visdomUtil import *
from resNet import ResNet
from torch.nn import functional
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from flatten import Flatten
def test(testLoader, gpuDevice, resNet):
    """Evaluate ``resNet`` on ``testLoader`` and return its top-1 accuracy.

    The model is switched to evaluation mode (and left in it; the caller is
    responsible for calling ``train()`` again) and the forward passes run
    without gradient tracking.

    Args:
        testLoader: iterable yielding ``(input, target)`` batches, where
            ``target`` holds the class index of every sample.
        gpuDevice: device the batches are moved to before the forward pass.
        resNet: model mapping an input batch to ``[batch, classes]`` logits.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``; ``0.0`` when
        the loader yields no samples.
    """
    correctNum = 0
    totalNum = 0
    # Switching modes once is enough; no need to repeat it for every batch.
    resNet.eval()
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for images, target in testLoader:
            images, target = images.to(gpuDevice), target.to(gpuDevice)
            pred = resNet(images).argmax(dim=1)
            correctNum += torch.eq(pred, target).sum().item()
            totalNum += images.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accurancyRate = correctNum / totalNum if totalNum else 0.0
    print("测试集准确率:", accurancyRate)
    return accurancyRate


# The helper's name matches pytest's discovery pattern; mark it so that it is
# never collected as a test case.
test.__test__ = False
def main():
    """Train ``ResNet`` on the fire data set and keep the best checkpoint.

    Every training batch plots its loss to the ``trainLoss`` curve. After
    each epoch the model is evaluated with ``test`` on the validation loader,
    the accuracy is plotted to the ``testAccurancy`` curve, and the whole
    model is saved to ``bestModel.pth`` whenever the accuracy improves.
    """
    trainLoss = "trainLoss"
    testAccurancy = "testAccurancy"
    trainFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\train\\160 x 160"
    valFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\val\\160 x 160"
    classNum = 2
    resize = 160
    # Number of images per training batch.
    batchSize = 16
    epochs = 1000

    train_db = FireData(trainFirePath, resize, mode='train')
    # NOTE(review): the validation set is also built with mode='train';
    # confirm against FireData that this is intended.
    val_db = FireData(valFirePath, resize, mode='train')
    train_loader = DataLoader(train_db, batch_size=batchSize, shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_db, batch_size=batchSize, num_workers=0, shuffle=True)

    # Create the curves that are updated during training.
    createLine(trainLoss)
    createLine(testAccurancy)

    # Prefer the GPU, but keep the script runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resNet = ResNet(classNum).to(device)
    optimizer = optim.Adam(resNet.parameters(), lr=0.001)
    # Alternatives that were tried:
    # optimizer = optim.SGD(resNet.parameters(), lr=0.05,weight_decay=0.01,momentum=0.9)
    # schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min",factor=0.1,patience=10,verbose=True)
    criteon = nn.CrossEntropyLoss().to(device)

    trainCount = 0
    bestAccurancyRate = 0
    for epoch in range(epochs):
        # test() leaves the model in eval mode, so re-enable training mode
        # at the start of every epoch.
        resNet.train()
        for images, target in train_loader:
            images = images.to(device)
            target = target.to(device)
            logits = resNet(images)
            loss = criteon(logits, target)
            drawLine(trainLoss, x=trainCount, y=loss.item())
            # schedule.step(loss)  # would adapt the learning rate dynamically
            optimizer.zero_grad()
            loss.backward()
            # Keep the overall gradient norm below 10 to avoid exploding
            # gradients. clip_grad_norm_ replaces the deprecated
            # clip_grad_norm and clips all parameters in a single call.
            torch.nn.utils.clip_grad_norm_(resNet.parameters(), 10)
            optimizer.step()
            trainCount += 1

        # Validate once per epoch and checkpoint on improvement.
        accurancyRate = test(val_loader, device, resNet)
        drawLine(testAccurancy, x=epoch, y=accurancyRate)
        if accurancyRate > bestAccurancyRate:
            bestAccurancyRate = accurancyRate
            torch.save(resNet, "bestModel.pth")


if __name__ == '__main__':
    main()
测试训练模型
也用了visdom,可视化分类的结果
import time
import torch
from torch.utils.data import DataLoader
import visdom
from fire_split.fireData import FireData
from fire_split.utils.imgUtil import denormalize
# Visualise the predictions of a saved model, batch by batch, in visdom.

# Use a separate name for the client so the `visdom` module is not shadowed.
viz = visdom.Visdom()
# Prefer the GPU, but keep the script runnable on CPU-only machines.
gpuDevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpuDevice = torch.device("cpu")
# NOTE: the checkpoint is a fully pickled model, and torch.load unpickles
# arbitrary objects, so only load checkpoint files you trust. On
# PyTorch >= 2.6 this call additionally needs weights_only=False.
# map_location lets a checkpoint saved on a GPU load on the chosen device.
resNet = torch.load("bestModel_myBest.pth", map_location=gpuDevice)
resNet.to(gpuDevice)
# Inference only: switch to evaluation mode once, before the loop.
resNet.eval()
resize = 160
batchSize = 16
# testFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\test\\160 x 160"
testFirePath = "D:\\Develop\\PycharmProjects\\StudyPytorch\\fire_split\\dataSet\\val\\160 x 160"
# testFirePath = "E:\\实验数据\\trainData\\10度_紧密_7根\\160 x 160"
test_db = FireData(testFirePath, resize, mode='test')
test_loader = DataLoader(test_db, batch_size=batchSize, num_workers=0, shuffle=True)
# No gradients are needed for visualisation; skip building the autograd graph.
with torch.no_grad():
    for images, _target in test_loader:
        images = images.to(gpuDevice)
        pred = resNet(images).argmax(dim=1)
        # NOTE(review): the batch is displayed exactly as loaded; the imported
        # `denormalize` helper is never applied - confirm whether it should be.
        viz.images(images, nrow=4, win='img', opts=dict(title='img'))
        viz.text(str(pred), win='label', opts=dict(title='label'))
        # Short pause so each batch stays visible in the visdom window.
        time.sleep(0.1)
评论区