pytorch版本PSEnet训练并部署方式

 更新时间:2023年05月10日 08:36:39   作者:__JDM__  
这篇文章主要介绍了pytorch版本PSEnet训练并部署方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

概述

源码地址

torch版本

训练环境没有按照torch的readme一样的环境,自己部署环境为:

torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh

制作数据集

1、训练的数据集

采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。

转换代码:

import os
from lxml import etree
import numpy as np
import math
src_xml = "ANN"
txt_dir = "gt"
xml_listdir = os.listdir(src_xml)
xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
def xml_out(xml_path):
    gt_lines = []
    ET = etree.parse(xml_path)
    objs = ET.findall("object")
    for ix,obj in enumerate(objs):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        cx = int(float(robox.find("cx").text))
        cy = int(float(robox.find("cy").text))
        w = int(float(robox.find("w").text))
        h = int(float(robox.find("h").text))
        angle = float(robox.find("angle").text)
        # angle = math.degrees(angle1)
        wx1 = cx - int(0.5 * w)
        wy1 = cy - int(0.5 * h)
        wx2 = cx + int(0.5 * w)
        wy2 = cy - int(0.5 * h)
        wx3 = cx - int(0.5 * w)
        wy3 = cy + int(0.5 * h)
        wx4 = cx + int(0.5 * w)
        wy4 = cy + int(0.5 * h)
        x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
        y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
        x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
        y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
        x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
        y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
        x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
        y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
        lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
                str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
        gt_lines.append(lines)
        return gt_lines
def main():
    count = 0
    for xml_dir in xml_listdir:
        gt_lines = xml_out(os.path.join(src_xml,xml_dir))
        txt_path = "gt_" + xml_dir[:-4] + ".txt"
        with open(os.path.join(txt_dir,txt_path),"a+") as fd:
            fd.writelines(gt_lines)
        count +=1
        print("Write file %s" % str(count))
if __name__ == "__main__":
    main()

rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。

转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes",此处classes为检测的类别,如果是模糊训练的话,classes为“###”。

但是重点,这个源代码对于模糊训练,loss一直为1。

2、将数据集分成训练集和测试集

数据集

这里可以按照源码路径存放数据集,也可以修改源码存放位置。

PSENet-python3\dataset\psenet\psenet_ic15.py

修改下述代码为自己文件夹

3、训练

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

其中根据源码中的readme,

可以根据自己的需要,自行选择配置文件。

4、部署测试

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.
    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """
    #assert os.path.exists(img), 'file is not exists'
    #img = cv2.imread(img)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # h, w = image.shape[:2]
    # scale = long_size / max(h, w)
    img = cv2.resize(img, target_size)
    # 将图片由(w,h)变为(1,img_channel,h,w)
    tensor = transforms.ToTensor()(img)
    tensor = tensor.unsqueeze_(0)
    tensor = tensor.to(torch.device("cuda:0"))
    return tensor
def report_speed(outputs, speed_meters):
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))
    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
    if checkpoint is not None:
        if os.path.isfile(checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
            sys.stdout.flush()
            checkpoint = torch.load(checkpoint)
            d = dict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)
        else:
            print("No checkpoint found at")
            raise
        # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )
    model = load_model(cfg)
    model.eval()
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(src_dir + img_name)
        tensor = prepare_image(img, target_size=(1376, 1024))
        data = dict()
        img_metas = dict()
        data['imgs'] = tensor
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
        img_metas['img_size'] = torch.tensor([[1376, 1024]])
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))
        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)
        for bboxes in outputs['bboxes']:
            x1 = bboxes[0]
            y1 = bboxes[1]
            x2 = bboxes[4]
            y2 = bboxes[5]
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        count = count + 1
        cv2.imwrite(save_dir + img_name, img)
        print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

训练代码里含有。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。 

相关文章

  • Python实现的基于优先等级分配糖果问题算法示例

    Python实现的基于优先等级分配糖果问题算法示例

    这篇文章主要介绍了Python实现的基于优先等级分配糖果问题算法,涉及Python针对列表的遍历、判断、计算等相关操作技巧,需要的朋友可以参考下
    2018-04-04
  • Python如何建立多个值和单个键的映射

    Python如何建立多个值和单个键的映射

    在Python中,常见的字典只能映射单个键到单个值,若需映射单个键到多值,可以通过将值存储于列表或集合中实现,使用列表可以保持元素插入顺序,而使用集合则可以去重,collections模块的defaultdict类简化了此类多值字典的创建过程
    2024-09-09
  • 详解Pytorch中的tensor数据结构

    详解Pytorch中的tensor数据结构

    torch.Tensor 是一种包含单一数据类型元素的多维矩阵,类似于 numpy 的 array,这篇文章主要介绍了Pytorch中的tensor数据结构,需要的朋友可以参考下
    2022-09-09
  • python tkinter 做个简单的计算器的方法

    python tkinter 做个简单的计算器的方法

    这篇文章主要介绍了python tkinter 做个简单的计算器的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python使用PDFMiner解析PDF代码实例

    Python使用PDFMiner解析PDF代码实例

    本篇文章主要介绍了Python使用PDFMiner解析PDF代码实例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-03-03
  • python 自动化办公之批量修改文件名实操

    python 自动化办公之批量修改文件名实操

    这篇文章主要介绍了python 自动化办公之批量修改文件名实操,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-07-07
  • keras得到每层的系数方式

    keras得到每层的系数方式

    这篇文章主要介绍了keras得到每层的系数方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • python解决报错ImportError: Bad git executable.问题

    python解决报错ImportError: Bad git executable.问题

    这篇文章主要介绍了python解决报错ImportError: Bad git executable.问题。具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-06-06
  • python爬虫爬取图片的简单代码

    python爬虫爬取图片的简单代码

    在本篇文章里小编给大家整理的是一篇关于python爬虫爬取图片的简单代码内容,有兴趣的朋友们可以测试下。
    2021-01-01
  • Python自动化完成tb喵币任务的操作方法

    Python自动化完成tb喵币任务的操作方法

    2019双十一,tb推出了新的活动,商店喵币,看了一下每天都有几个任务来领取喵币,从而升级店铺赚钱,然而我既想赚红包又不想干苦力,遂使用python来进行手机自动化操作,需要的朋友跟随小编一起看看吧
    2019-10-10

最新评论