30秒轻松实现TensorFlow物体检测

 更新时间:2018年03月14日 15:09:25   作者:wangli0519  
这篇文章主要为大家详细介绍了30秒轻松实现TensorFlow物体检测,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。

使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

首先我们载入一些会使用的库

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image 

接下来进行环境设置

%matplotlib inline 
sys.path.append("..") 

物体检测载入

from utils import label_map_util 
 
from utils import visualization_utils as vis_util 

准备模型

变量  任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90 

下载模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd()) 

将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='') 

载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories) 

辅助代码

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( 
   (im_height, im_width, 3)).astype(np.uint8) 

检测

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
[python] view plain copy
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 这个array在之后会被用来准备为图片加上框和标签 
   image_np = load_image_into_numpy_array(image) 
   # 扩展维度,应为模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每个框代表一个物体被侦测到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每个分值代表侦测到物体的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 执行侦测任务. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 图形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np) 

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 使用python装饰器计算函数运行时间的实例

    使用python装饰器计算函数运行时间的实例

    下面小编就为大家分享一篇使用python装饰器计算函数运行时间的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python使用pandas处理excel文件转为csv文件的方法示例

    python使用pandas处理excel文件转为csv文件的方法示例

    这篇文章主要介绍了python使用pandas处理excel文件转为csv文件的方法示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python读写二进制文件的实现

    Python读写二进制文件的实现

    本文主要介绍了Python读写二进制文件的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-04-04
  • VScode连接远程服务器上的jupyter notebook的实现

    VScode连接远程服务器上的jupyter notebook的实现

    这篇文章主要介绍了VScode连接远程服务器上的jupyter notebook的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • python列表使用实现名字管理系统

    python列表使用实现名字管理系统

    这篇文章主要为大家详细介绍了python列表使用实现名字管理系统,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-01-01
  • python修改字符串值的三种方法

    python修改字符串值的三种方法

    本文主要介绍了python修改字符串值的三种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2025-01-01
  • python使用pil生成缩略图的方法

    python使用pil生成缩略图的方法

    这篇文章主要介绍了python使用pil生成缩略图的方法,涉及Python使用pil模块操作图片的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-03-03
  • python3实现tailf命令的示例代码

    python3实现tailf命令的示例代码

    本文主要介绍了python3实现tailf命令的示例代码,tail -f 是一个linux的操作命令.其主要的是会把文件里的最尾部的内容显显示在屏幕上,并且不断刷新,只要文件有变动就可以看到最新的文件内容,感兴趣的可以了解一下
    2023-11-11
  • Python中time模块与datetime模块在使用中的不同之处

    Python中time模块与datetime模块在使用中的不同之处

    这篇文章主要介绍了Python中time模块与datetime模块在使用中的不同之处,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-11-11
  • Python中str字符串的内置方法详解

    Python中str字符串的内置方法详解

    这篇文章主要介绍了Python中str字符串的内置方法详解,在 python 中字符串有许多内置的方法,在日常编程中会经常使用到,熟练运用了能够在很多场景大大的提高我们的工作效率,需要的朋友可以参考下
    2023-08-08

最新评论