How to train with ImageNet pretrained models in PyTorch

Updated: 2023-09-09 11:27:08   Author: josenxiao
This article walks through how to use ImageNet pretrained models for training in PyTorch. It is offered as a practical reference; corrections for anything wrong or incomplete are welcome.

Using ImageNet pretrained models in PyTorch

1. Loading models

# Load a pretrained model, using resnet50 as the example
import torchvision
model = torchvision.models.resnet50(pretrained=True)
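
Note that pretrained=True is the older torchvision API. In torchvision 0.13 and newer it is deprecated in favor of an explicit weights argument; a rough modern equivalent looks like this:

from torchvision.models import resnet50, ResNet50_Weights

# Same ImageNet weights, requested through the newer weights enum
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)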

Once the model is loaded, the next question is how to make use of it; before doing so, you need to understand the structure of the model you just loaded.
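
A quick way to see that structure is to print the model, or to list its top-level submodules:

# Print the full module tree: conv1/bn1/relu/maxpool, layer1-layer4, avgpool, fc
print(model)

# Or list just the names of the top-level children
for name, child in model.named_children():
    print(name, type(child).__name__)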

2. Using the model for classification

If you want to use the model for classification:

you only need to replace the final fully connected layer so that it outputs the desired number of classes.

import torch.nn as nn

num_classes = 10  # set this to the number of classes in your dataset
model.fc = nn.Sequential(nn.Linear(2048, num_classes))
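
If you only want to fine-tune the new head, a common follow-up (a minimal sketch, not part of the original snippet) is to freeze the pretrained backbone and give the optimizer only the parameters of the new fc layer:

import torch

# Freeze every pretrained weight ...
for param in model.parameters():
    param.requires_grad = False
# ... then re-enable gradients for the newly added head only
for param in model.fc.parameters():
    param.requires_grad = True

optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)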

3. Using the model as a backbone

If you need the pretrained model as the backbone of another model, for example R-CNN or a semantic segmentation network, you load the pretrained model first and build the new model on top of it. Take the following FCN8-style semantic segmentation model as an example:

Here, model is the ResNet50 loaded above with ImageNet pretrained weights.

import torch.nn as nn

class FCN(nn.Module):
    def __init__(self, nClasses):
        super(FCN, self).__init__()
        # 1x1 convs that project backbone feature maps to per-class scores.
        # Channel counts match ResNet50: layer4 -> 2048, layer3 -> 1024, layer2 -> 512.
        self.score = nn.Conv2d(2048, nClasses, 1, stride=1, padding=0, bias=True)
        self.layer1 = nn.Conv2d(1024, nClasses, 1, stride=1, padding=0, bias=True)
        self.layer2 = nn.Conv2d(512, nClasses, 1, stride=1, padding=0, bias=True)
        # Transposed convs for upsampling: 2x, 2x, then 8x back to input size (FCN-8s)
        self.up4 = nn.ConvTranspose2d(nClasses, nClasses, 2, stride=2, padding=0, bias=True)
        self.trans = nn.ConvTranspose2d(nClasses, nClasses, 2, stride=2, padding=0, bias=True)
        self.up = nn.ConvTranspose2d(nClasses, nClasses, 8, stride=8, padding=0, bias=True)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                m.bias.detach().zero_()
    def forward(self, x, model):
        # Run the ResNet stem and stages, keeping the intermediate
        # feature maps needed for the skip connections
        x = model.conv1(x)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        x = model.layer1(x)
        x1 = model.layer2(x)   # 1/8 resolution
        x2 = model.layer3(x1)  # 1/16 resolution
        x = model.layer4(x2)   # 1/32 resolution
        # Score the deepest features and upsample 2x to layer3 resolution
        x = self.up4(self.score(x))
        skip = self.layer1(x2)
        y = skip + x
        c = self.trans(y)      # 2x up, to layer2 resolution
        v = self.layer2(x1)
        y = c + v
        x = self.up(y)         # 8x up, back to the input resolution
        return x
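
A minimal smoke test for the class above (assuming nClasses=21, as in PASCAL VOC, and a dummy 224x224 batch):

import torch
import torchvision

backbone = torchvision.models.resnet50(pretrained=True)
backbone.eval()  # keep BatchNorm in inference mode for the dummy forward pass

fcn = FCN(nClasses=21)
img = torch.randn(1, 3, 224, 224)  # dummy input batch
out = fcn(img, backbone)
print(out.shape)  # torch.Size([1, 21, 224, 224])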

Of course there are other ways to write this, for example pulling the backbone's layers out directly in the new class's constructor; once they are extracted, the rest is straightforward:

# Note: children(), not modules() - modules() would recurse into every
# submodule (and include the model itself), which nn.Sequential cannot use directly
layers = list(model.children())
# Drop avgpool and fc to keep just the convolutional feature extractor
backbone = nn.Sequential(*layers[:-2])
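
Following that idea, here is a sketch (the class name and head are illustrative) of wiring the extracted layers into a new model's constructor:

import torch.nn as nn
import torchvision

class ResNetClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ResNetClassifier, self).__init__()
        resnet = torchvision.models.resnet50(pretrained=True)
        # Everything up to (but not including) avgpool/fc becomes the feature extractor
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)
    def forward(self, x):
        x = self.features(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)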

PyTorch ImageNet example

import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
best_prec1 = 0
def main():
    global args, best_prec1
    args = parser.parse_args()
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
@torch.no_grad()
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
if __name__ == '__main__':
    main()
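
Assuming the script is saved as main.py and the data directory contains the ImageFolder-style train/ and val/ subfolders, typical invocations look like:

# fine-tune starting from ImageNet weights
python main.py --arch resnet50 --pretrained /path/to/imagenet

# evaluate a saved checkpoint on the validation set
python main.py --arch resnet50 --resume checkpoint.pth.tar --evaluate /path/to/imagenet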

Summary

The above is based on personal experience; hopefully it serves as a useful reference.
