Windows 利用tensorflow2 object-detection api训练自己的模型
预备工作:
windows tensorflow2 object-detection api 下载安装
根据这个文档一步步走
训练自定义目标检测器 — TensorFlow 2 Object Detection API 教程文档 (tensorflow-object-detection-api-tutorial.readthedocs.io)
1.准备数据集
需将数据集分为训练集和测试集两部分,使用labelImg工具打标签生成.xml文件,然后将.xml文件转成.record文件。
2.训练模型
models/tf2_detection_zoo.md at master · tensorflow/models · GitHub 在这里下载需要的模型,配置config文件时最好将batch_size的值改为1,以防显存不足,当然也可以自己设置一个合适的值。
可能遇到的报错:
self._read_buf = _pywrap_file_io.BufferedInputStream(
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xd5 in position 114: invalid continuation byte
可能的解决方法:
配置模型config文件时,将路径中的'\'改成'/',如:'...../training_demo/annotations/label_map.pbtxt'
路径中含有中文,需将路径改为全英文
AssertionError: Found 260 Python objects that were not bound to checkpointed values, likely due to changes in the Python program.
解决方法:将 fine_tune_checkpoint_type: "classification" 改为 fine_tune_checkpoint_type: "detection"
3.导出.pb模型
用文档的方法即可
4.测试训练好的模型
测试模型的代码转自:
原文链接:https://blog.csdn.net/weixin_48672949/article/details/118808852
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Run a trained TF2 Object Detection API model over a folder of images.

The script rebuilds the detection model from ``pipeline.config``, restores
the exported checkpoint, runs inference on every ``.jpg``/``.png`` file in
``image_dir``, draws the detections, saves the annotated image into
``PATH_TO_RESULT`` and displays it with matplotlib.

Created on 2020.10.6
@author: Jacklee
"""
import os
import cv2
import numpy as np
from PIL import Image
import tkinter
import matplotlib
matplotlib.use('TkAgg')  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
import time
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from six import BytesIO


def load_image_into_numpy_array(path):
    """Load an image from file into a numpy array.

    The image is converted to RGB, so the result always has three channels
    in (height, width, channels) layout, ready to feed into the model.

    Args:
        path: the file path to the image

    Returns:
        uint8 numpy array with shape (img_height, img_width, 3)
    """
    # Use a context manager so the file handle is closed deterministically.
    with tf.io.gfile.GFile(path, 'rb') as image_file:
        img_data = image_file.read()
    image = Image.open(BytesIO(img_data)).convert('RGB')
    # np.array() on an RGB PIL image already yields an (H, W, 3) uint8 array;
    # this avoids the slow per-pixel getdata() + reshape round trip.
    return np.array(image, dtype=np.uint8)


# Build detection model and load trained_model weights.
pipeline_config = os.path.join(
    'D:/tensorflow-model/workspace/training_demo/exported-models/my_fast',
    'pipeline.config')
model_dir = 'D:/tensorflow-model/workspace/training_demo/exported-models/my_fast/checkpoint/'
print(pipeline_config)
print(model_dir)

# Load pipeline config and build a detection model.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
    model_config=model_config, is_training=False)

# Restore checkpoint. expect_partial() silences warnings about checkpoint
# values that are not used by the inference-only model object.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()


def get_model_detection_function(model):
    """Get a tf.function for detection.

    Args:
        model: the detection model built by ``model_builder.build``.

    Returns:
        A ``tf.function`` taking a batched float32 image tensor and returning
        ``(detections, prediction_dict, flattened_shapes)``.
    """

    @tf.function
    def detect_fn(image):
        """Detect objects in image."""
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)
        return detections, prediction_dict, tf.reshape(shapes, [-1])

    return detect_fn


detect_fn = get_model_detection_function(detection_model)

# Load label_map data.
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
label_map_dict = label_map_util.get_label_map_dict(
    label_map, use_display_name=True)

image_dir = 'D:/tensorflow-model/workspace/training_demo/images/1/'
PATH_TO_RESULT = 'D:/tensorflow-model/workspace/training_demo/images/fast/'

# cv2.imwrite() fails silently (returns False) when the target directory
# does not exist, so make sure it is there before the loop starts.
os.makedirs(PATH_TO_RESULT, exist_ok=True)

start = time.time()

# Collect the full path of every image in the input directory.
IMAGE_PATH_CHAR = [
    os.path.join(image_dir, file_name)
    for file_name in os.listdir(image_dir)
    if file_name.endswith((".jpg", ".png"))
]

# The model predicts zero-based class ids; the label map starts at 1.
label_id_offset = 1

for image_path in IMAGE_PATH_CHAR:
    print("Running inference for {}...".format(image_path), end='')
    image_np = load_image_into_numpy_array(image_path)
    input_tensor = tf.convert_to_tensor(
        np.expand_dims(image_np, 0), dtype=tf.float32)
    detections, predictions_dict, shapes = detect_fn(input_tensor)

    image_np_with_detections = image_np.copy()
    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections['detection_boxes'][0].numpy(),
        (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
        detections['detection_scores'][0].numpy(),
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=.30,
        agnostic_mode=False)

    plt.figure(figsize=(12, 8))

    # os.path.basename handles both '/' and '\\' separators, so the script
    # keeps working when the directories are written Windows-style.
    image_name = os.path.basename(image_path)
    print(image_name)

    # The array is RGB (it came from PIL) but OpenCV writes BGR; convert
    # first, otherwise red and blue are swapped in the saved file.
    result_path = os.path.join(PATH_TO_RESULT, image_name)
    saved = cv2.imwrite(
        result_path,
        cv2.cvtColor(image_np_with_detections, cv2.COLOR_RGB2BGR))
    if not saved:
        print("Warning: could not write {}".format(result_path))

    plt.subplot(122)
    plt.imshow(image_np_with_detections)
    # NOTE(review): the original indentation was lost; showing each image
    # inside the loop is assumed. plt.show() blocks until the window closes.
    plt.show()

end = time.time()
print('Execution Time: ', end - start)
上述代码需更改四处文件目录:
pipeline.config文件位置:pipeline_config = os.path.join('training_mouse\\train_export', 'pipeline.config')
checkpoint文件位置(导出模型目录下的checkpoint文件夹):model_dir = 'training_mouse\\train_export\\checkpoint\\'
测试图片位置:image_dir = 'training_mouse\\test_image\\'
测试结果图:PATH_TO_RESULT = 'training_mouse\\test_image\\'
将该代码保存到training_demo文件夹下,cmd中cd到该目录下运行该.py文件。
5.参考博客
tensorflow2.0训练目标检测模型_weixin_48672949的博客-CSDN博客_tensorflow2.0 目标检测