pytorch通过自己的数据集训练Unet网络架构

 更新时间:2022年12月08日 09:15:34   作者:专业女神杀手  
Unet是一个最近比较火的网络结构。它的理论已经有很多大佬在讨论了。本文主要从实际操作的层面,讲解如何使用pytorch实现unet图像分割

在图像分割这个问题上,主要有两个流派:Encoder-Decoder和Dialated Conv。本文介绍的是编解码网络中最为经典的U-Net。随着骨干网路的进化,很多相应衍生出来的网络大多都是对于Unet进行了改进但是本质上的思路还是没有太多的变化。比如结合DenseNet 和Unet的FCDenseNet, Unet++

一、Unet网络介绍

论文:https://arxiv.org/abs/1505.04597v1(2015)

UNet的设计就是应用与医学图像的分割。由于医学影像处理中,数据量较少,本文提出的方法有效提升了使用少量数据集训练检测的效果,提出了处理大尺寸图像的有效方法。

UNet的网络架构继承自FCN,并在此基础上做了些改变。提出了Encoder-Decoder概念,实际上就是FCN那个先卷积再上采样的思想。

上图是Unet的网络结构,从图中可以看出,

结构左边为Encoder,即下采样提取特征的过程。Encoder基本模块为双卷积形式,即输入经过两个

conu 3x3,使用的valid卷积,在代码实现时我们可以增加padding使用same卷积,来适应Skip Architecture。下采样采用的池化层直接缩小2倍。

结构右边是Decoder,即上采样恢复图像尺寸并预测的过程。Decoder一样采用双卷积的形式,其中上采样使用转置卷积实现,每次转置卷积放大2倍。

结构中间copy and crop是一个cat操作,即feature map的通道叠加。

二、VOC训练Unet

2.1 Unet代码实现

根据上面对于Unet网络结构的介绍,可见其结构非常对称简单,代码Unet.py实现如下:

from turtle import forward
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.output = nn.Conv2d(64, out_ch, 1)
    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        up6 = self.up6(conv5)
        meger6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(meger6)
        up7 = self.up7(conv6)
        meger7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(meger7)
        up8 = self.up8(conv7)
        meger8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(meger8)
        up9 = self.up9(conv8)
        meger9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(meger9)
        out = self.output(conv9)
        return out
if __name__=="__main__":
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)

2.2 数据集处理

数据来源于kaggle,下载地址我忘了。包含2个类别,1个车,还有1个背景类,共有5k+的数据,按照比例分为训练集和验证集即可。具体见carnava.py

from PIL import Image
from requests import check_compatibility
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt
class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)
        if train:
            self.imgs_list = train_path_list
        else:
            self.imgs_list = val_path_list
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transforms = T.Compose([
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                normalize
            ])
        self.transforms_val = T.Compose([
            T.Resize(256),
            T.CenterCrop(256)
        ])
        self.color_map = [[0, 0, 0], [255, 255, 255]]
    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        (filepath, filename) = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path +"/"+filename+"_mask.gif").convert("RGB")
        label = self.transforms_val(label)
        cm2lb=np.zeros(256**3)
        for i,cm in enumerate(self.color_map):
            cm2lb[(cm[0]*256+cm[1])*256+cm[2]]=i
        image=np.array(label,dtype=np.int64)
        idx=(image[:,:,0]*256+image[:,:,1])*256+image[:,:,2]
        label=np.array(cm2lb[idx],dtype=np.int64)
        label=torch.from_numpy(label).long()
        return data, label
    def label2img(self, label):
        cmap = self.color_map
        cmap = np.array(cmap).astype(np.uint8)
        pred = cmap[label]
        return pred
    def __len__(self):
        return len(self.imgs_list)
    def _split_data_set(self, img_path_list):
        val_path_list = img_path_list[::8]
        train_path_list = []
        for item in img_path_list:
            if item not in val_path_list:
                train_path_list.append(item)
        return train_path_list, val_path_list
if __name__=="__main__":
    root = "../dataset/carvana"
    car_train = Car(root,train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))
    print(len(train_dataloader))
    # for data, label in car_train:
    #     print(data.shape)
    #     print(label.shape)
    #     break
    (data, label) = car_train[190]
    label_np = label.data.numpy()
    label_im = car_train.label2img(label_np)
    plt.figure()
    plt.imshow(label_im)
    plt.show()

2.3 训练过程

分割其实就是给每个像素分类而已,所以损失函数依旧是交叉熵函数,正确率为分类正确的像素点个数/全部的像素点个数

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from voc import VOC
from carnava import Car
from unet import Unet
import os
import numpy as np
from torch import optim
import torch.nn as nn
import util
# 计算混淆矩阵
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    return acc, acc_cls, mean_iu
out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2
train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)
net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
def train_model():
    best_score = 0.0
    for e in range(epochs):
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)
            pred = output.argmax(dim=1).squeeze().data.cpu()
            real = label.data.cpu()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss+=loss.cpu().item()
            label_true = torch.cat((label_true,real),dim=0)
            label_pred = torch.cat((label_pred,pred),dim=0)
        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e+1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e+1,train_loss,acc, acc_cls, mean_iu))
        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).squeeze().data.cpu()
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)
            val_loss/=len(val_dataloader)
            val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(),
                                                                    val_label_pred.numpy(),numclasses)
        print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e+1,val_loss,val_acc, val_acc_cls, val_mean_iu))
        score = (val_acc_cls+val_mean_iu)/2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)
def evaluate():
    import util
    import random
    import matplotlib.pyplot as plt
    net.load_state_dict(torch.load(model_path))
    index = random.randint(0, len(val_data)-1)
    val_image, val_label = val_data[index]
    out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)
    temp = val_image.numpy()
    temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255
    fig, ax = plt.subplots(1,3)
    ax[0].imshow(temp.transpose(1,2,0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()
if __name__=="__main__":
    # train_model()
    evaluate()

最终训练结果是:

由于数据比较简单,训练到epoch为5时,mIOU就已经达到0.97了。

最后测试一下效果:

从左到右分别是:原图、真实label、预测label

备注:

其实最开始使用voc数据集训练的,但效果极差,也没发现哪里有问题。换个数据集效果就好了,可能有两个原因:

1. voc数据我在处理数据时出错了,没检查出来

2. 这个数据集比较简单,容易学习,所以效果差不多。

到此这篇关于pytorch通过自己的数据集训练Unet网络架构的文章就介绍到这了,更多相关pytorch Unet内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python一秒搭建FTP服务器

    python一秒搭建FTP服务器

    今天给大家分享一篇教程关于python一秒搭建FTP服务器的教程,在搭建过程中需要用到pyftpdlib模块,对python FTP服务器搭建过程感兴趣的朋友跟随小编一起看看吧
    2021-05-05
  • python中turtle库的简单使用教程

    python中turtle库的简单使用教程

    这篇文章主要给大家介绍了关于python中turtle库的简单使用教程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python函数sort()与sorted()的区别及key=lambda x:x[]的理解

    Python函数sort()与sorted()的区别及key=lambda x:x[]的理解

    这篇文章主要介绍了Python函数sort()与sorted()的区别及key=lambda x:x[]的理解方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • Python编程之字符串模板(Template)用法实例分析

    Python编程之字符串模板(Template)用法实例分析

    这篇文章主要介绍了Python编程之字符串模板(Template)用法,结合具体实例形式分析了Python字符串模板的功能、定义与使用方法,需要的朋友可以参考下
    2017-07-07
  • pyqt 多窗口之间的相互调用方法

    pyqt 多窗口之间的相互调用方法

    今天小编就为大家分享一篇pyqt 多窗口之间的相互调用方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python函数的嵌套详解

    Python函数的嵌套详解

    这篇文章主要为大家介绍了Python函数的嵌套,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-01-01
  • python图形界面开发之wxPython树控件使用方法详解

    python图形界面开发之wxPython树控件使用方法详解

    这篇文章主要介绍了python图形界面开发之wxPython树控件使用方法详解,需要的朋友可以参考下
    2020-02-02
  • numpy.linspace函数具体使用详解

    numpy.linspace函数具体使用详解

    这篇文章主要介绍了numpy.linspace具体使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • 使用python画个小猪佩奇的示例代码

    使用python画个小猪佩奇的示例代码

    本文给大家较详细的介绍了使用python画个小猪佩奇的示例代码,感兴趣的朋友一起看看吧
    2018-06-06
  • Django之form组件自动校验数据实现

    Django之form组件自动校验数据实现

    这篇文章主要介绍了Django之form组件自动校验数据实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-01-01

最新评论