How to train with ImageNet pretrained models in PyTorch

Updated: 2023-09-09 11:27:08   Author: josenxiao
This article explains how to train with ImageNet-pretrained models in PyTorch. It is offered as a reference in the hope that it helps; if there are errors or anything not fully considered, corrections are welcome.

Training with ImageNet pretrained models in PyTorch

1. Loading models

# load a pretrained model, taking resnet50 as the example
import torchvision
model = torchvision.models.resnet50(pretrained=True)

Once the model is loaded, the question is how to make use of it; but before that, you must understand the structure of the model you loaded.
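A quick way to inspect the structure is simply to print the model, or to list its top-level children:

print(model)  # prints the full module tree

# ResNet-50's top-level blocks: conv1, bn1, relu, maxpool, layer1-layer4, avgpool, fc
for name, child in model.named_children():
    print(name, type(child).__name__)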

2. Using the model for classification

If you are using it for a classification task, you only need to replace the final fully connected layer so that it outputs your number of classes:

import torch.nn as nn

# ResNet-50's final fc layer takes 2048 input features
model.fc = nn.Linear(2048, num_classes)
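A common follow-up is to freeze the pretrained weights and train only the new head. Here is a minimal sketch (the num_classes value and optimizer settings are illustrative):

import torch
import torch.nn as nn
import torchvision

num_classes = 10  # illustrative value
model = torchvision.models.resnet50(pretrained=True)
for p in model.parameters():
    p.requires_grad = False  # freeze the ImageNet-pretrained backbone
model.fc = nn.Linear(2048, num_classes)  # new head; trainable by default

# optimize only the new head's parameters
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)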

3. Using the model as a backbone

If you need it as the backbone of another model, such as R-CNN or a semantic segmentation network, you load the pretrained model first and feed it in. Take the following FCN-8s-style semantic segmentation network as an example:

Here `model` is the ResNet-50 from above, pretrained on the ImageNet dataset.

class FCN(nn.Module):
    """FCN-8s-style decoder on top of a pretrained ResNet-50 encoder."""
    def __init__(self, nClasses):
        super(FCN, self).__init__()
        # 1x1 convolutions score the backbone features to nClasses channels;
        # ResNet-50's layer4/layer3/layer2 outputs have 2048/1024/512 channels
        self.score4 = nn.Conv2d(2048, nClasses, 1, stride=1, padding=0, bias=True)
        self.score3 = nn.Conv2d(1024, nClasses, 1, stride=1, padding=0, bias=True)
        self.score2 = nn.Conv2d(512, nClasses, 1, stride=1, padding=0, bias=True)
        # transposed convolutions upsample 2x, 2x, then 8x back to the input size
        self.up2a = nn.ConvTranspose2d(nClasses, nClasses, 2, stride=2, padding=0, bias=True)
        self.up2b = nn.ConvTranspose2d(nClasses, nClasses, 2, stride=2, padding=0, bias=True)
        self.up8 = nn.ConvTranspose2d(nClasses, nClasses, 8, stride=8, padding=0, bias=True)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                m.bias.detach().zero_()
    def forward(self, x, model):
        # reuse the pretrained ResNet-50 stem and stages as the encoder
        x = model.conv1(x)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        x = model.layer1(x)    # stride 4
        x1 = model.layer2(x)   # stride 8
        x2 = model.layer3(x1)  # stride 16
        x3 = model.layer4(x2)  # stride 32
        # FCN-8s fusion: upsample the deepest prediction, add the skip predictions
        y = self.up2a(self.score4(x3)) + self.score3(x2)  # stride 16
        y = self.up2b(y) + self.score2(x1)                # stride 8
        return self.up8(y)                                # back to full resolution
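A minimal usage sketch (nClasses=21, as for PASCAL VOC, is illustrative; the backbone is frozen here only to show that the pretrained weights are reused, not retrained):

import torch
import torchvision

model = torchvision.models.resnet50(pretrained=True)
for p in model.parameters():
    p.requires_grad = False  # optionally freeze the pretrained encoder

fcn = FCN(nClasses=21)
out = fcn(torch.randn(1, 3, 224, 224), model)
print(out.shape)  # torch.Size([1, 21, 224, 224])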

Of course there are other ways to write this: for example, inside your class's constructor you can pull the backbone's submodules out first, after which the rest is very simple:

# collect the backbone's top-level submodules (children() does not recurse,
# unlike modules()) and reassemble them, dropping the avgpool and fc layers
values = list(model.children())
backbone = nn.Sequential(*values[:-2])
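The truncated trunk can then serve as a standalone feature extractor; for a standard 224x224 input:

import torch

feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048, 7, 7]) for ResNet-50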

PyTorch ImageNet example

The following is a complete end-to-end training script (argument parsing, data loading, training, validation, and checkpointing):

import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
best_prec1 = 0
def main():
    global args, best_prec1
    args = parser.parse_args()
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
   # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True)
        # compute output (Variable wrappers are unnecessary in PyTorch >= 0.4)
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        # compute output; torch.no_grad() replaces the old 'volatile' flag
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
if __name__ == '__main__':
    main()
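To run the script, point it at an ImageNet-style directory whose train/ and val/ folders each contain one subdirectory per class (the layout ImageFolder expects), for example: python main.py -a resnet50 --pretrained /path/to/imagenet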

Summary

The above is my personal experience; I hope it provides a useful reference, and I hope everyone will continue to support 脚本之家.
