import os
import sys
sys.path.append(xxxx) # 加入Mask_RCNN源码所在目录
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
from mrcnn.config import Config
from mrcnn import model as modellib,utils
from mrcnn import visualize
import yaml
from mrcnn.model import log
from PIL import Image
# All paths below are resolved against the current working directory.
ROOT_DIR = os.getcwd()
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")  # pre-trained COCO weights file
MODEL_DIR = os.path.join(ROOT_DIR, "models")  # passed to modellib.MaskRCNN as model_dir
iter_num = 0  # module-level counter, initialised to 0

# Fetch the pre-trained base model from the internet, but only when it is not on disk yet.
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)
# Configuration
class ShapesConfig(Config):
    """Training configuration for the background + banana dataset.

    Only the attributes listed here override the mrcnn ``Config`` base class;
    everything else keeps the base-class value.
    """

    NAME = "shapes"  # configuration name

    # A single GPU processing one image at a time.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Two classes in total: background plus banana.
    NUM_CLASSES = 2

    # Image size bounds.
    IMAGE_MIN_DIM = 320
    IMAGE_MAX_DIM = 384

    # Anchor scales: the base sizes 8, 16, 32, 64, 128 each multiplied by 6,
    # i.e. (48, 96, 192, 384, 768).
    RPN_ANCHOR_SCALES = tuple(6 * base for base in (8, 16, 32, 64, 128))

    TRAIN_ROIS_PER_IMAGE = 100  # Aim to allow ROI sampling to pick 33% positive ROIs

    STEPS_PER_EPOCH = 100
    VALIDATION_STEPS = 50


config = ShapesConfig()
config.display()
# Custom dataset overriding mrcnn's utils.Dataset.
class DrugDataset(utils.Dataset):
    """Dataset of labelme-annotated images with a single "banana" class.

    For every image file ``<name>.<ext>`` registered by ``load_shapes`` the
    following files are read:

    * ``<mask_floder>/<name>_json.png`` -- a label image in which pixel value
      ``k`` (k >= 1) marks object ``k`` and 0 marks no object;
    * ``<dataset_root_path>labelme_json/<name>_json/info.yaml`` -- whose
      ``label_names`` list has its first entry dropped (presumably the
      background) so that the remaining entries name objects 1, 2, ...;
    * ``<dataset_root_path>labelme_json/<name>_json/img.png`` -- read only to
      obtain the image width and height.
    """

    def get_obj_index(self, image):
        """Return the number of objects in a label image.

        The count is the largest pixel value, so object labels are assumed to
        be numbered contiguously from 1 -- TODO confirm for all mask files.
        """
        # int() so callers get a plain Python int rather than e.g. np.uint8,
        # whose arithmetic would wrap around at 255.
        return int(np.max(image))

    def from_yaml_get_class(self, image_id):
        """Return the object label names recorded for ``image_id``.

        Reads ``label_names`` from the image's ``info.yaml`` and drops the
        first entry, so ``labels[k]`` names the object whose mask pixels have
        value ``k + 1``.
        """
        info = self.image_info[image_id]
        with open(info['yaml_path']) as f:
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # PyYAML >= 6 and can construct arbitrary Python objects from file
            # content.
            temp = yaml.safe_load(f)
        labels = temp['label_names']
        return labels[1:]

    def draw_mask(self, num_obj, mask, image, image_id):
        """Set one binary channel per object in ``mask`` and return it.

        Channel ``index`` is set to 1 wherever the label image equals
        ``index + 1``; all other entries of ``mask`` are left untouched. Only
        the top-left ``height x width`` region recorded for ``image_id`` is
        read.

        NOTE(review): assumes a single-channel label image (mode "P"/"L"/"I");
        a multi-channel image would produce a 3-D array here -- confirm.
        """
        info = self.image_info[image_id]
        # Vectorised replacement for a per-pixel getpixel() loop, which made
        # O(width * height * num_obj) Python calls per image.
        label_map = np.asarray(image)[:info['height'], :info['width']]
        for index in range(num_obj):
            mask[label_map == index + 1, index] = 1
        return mask

    def load_shapes(self, count, img_floder, mask_floder, imglist, dataset_root_path):
        """Register the class and the first ``count`` images of ``imglist``.

        Args:
            count: number of entries of ``imglist`` to register (image ids
                0 .. count-1).
            img_floder: directory containing the raw images.
            mask_floder: directory containing the ``<name>_json.png`` masks.
            imglist: image file names inside ``img_floder``.
            dataset_root_path: dataset root; must end with a path separator,
                because paths are built by string concatenation.

        Raises:
            FileNotFoundError: if an image's ``img.png`` cannot be read.
        """
        self.add_class("shapes", 1, "banana")  # custom label
        for i in range(count):
            filestr = imglist[i].split(".")[0]
            json_dir = dataset_root_path + "labelme_json/" + filestr + "_json/"
            mask_path = mask_floder + "/" + filestr + "_json.png"
            yaml_path = json_dir + "info.yaml"
            size_path = json_dir + "img.png"
            cv_img = cv2.imread(size_path)
            if cv_img is None:
                # cv2.imread reports failure by returning None instead of
                # raising; fail here with the offending path.
                raise FileNotFoundError("cannot read image: " + size_path)
            self.add_image("shapes", image_id=i, path=img_floder + "/" + imglist[i],
                           width=cv_img.shape[1], height=cv_img.shape[0],
                           mask_path=mask_path, yaml_path=yaml_path)

    def load_mask(self, image_id):
        """Return the instance masks and class ids for ``image_id``.

        Returns:
            A ``(mask, class_ids)`` pair. ``mask`` is a uint8 array of shape
            (height, width, n) with one binary channel per kept object.
            ``class_ids`` is an int32 array of length n, where
            ``class_ids[k]`` is the class of ``mask[:, :, k]``. Only objects
            whose label name contains "banana" are kept; with none kept, n == 0.
        """
        info = self.image_info[image_id]
        with Image.open(info['mask_path']) as img:
            num_obj = self.get_obj_index(img)
            mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
            mask = self.draw_mask(num_obj, mask, img, image_id)
        # All channels come from one label image, so they are already disjoint
        # and need no occlusion handling.
        labels = self.from_yaml_get_class(image_id)
        # labels[k] names mask channel k. Select each channel together with its
        # own label so masks and class ids stay paired when a label is skipped.
        keep = [k for k in range(min(num_obj, len(labels)))
                if "banana" in labels[k]]  # custom label
        banana_id = self.class_names.index("banana")
        class_ids = np.full(len(keep), banana_id, dtype=np.int32)
        return mask[:, :, keep], class_ids
# Basic settings: dataset layout on disk.
dataset_root_path = "data/"
img_floder = dataset_root_path + "pic"  # directory of the raw images
mask_floder = dataset_root_path + "cv2_mask"  # directory of the mask images
# os.listdir returns entries in arbitrary order; sort them so image ids and
# the validation subset below are reproducible across runs and machines.
imglist = sorted(os.listdir(img_floder))
count = len(imglist)

# Training set: every image in img_floder.
dataset_train = DrugDataset()
dataset_train.load_shapes(count, img_floder, mask_floder, imglist, dataset_root_path)
dataset_train.prepare()

# Validation set: the first (up to) VAL_COUNT images. Capped at count so a
# small dataset does not run past the end of imglist.
# NOTE(review): these images are also in the training set, so validation
# metrics are not computed on held-out data.
VAL_COUNT = 7
dataset_val = DrugDataset()
dataset_val.load_shapes(min(VAL_COUNT, count), img_floder, mask_floder, imglist, dataset_root_path)
dataset_val.prepare()
# Build the model in training mode.
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)

# Which weights to start from: "imagenet", "coco", or "last".
init_with = "coco"
if init_with == "imagenet":
    model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
    # These layers are not loaded from the COCO checkpoint (presumably
    # because their shapes depend on NUM_CLASSES, which differs from COCO's).
    model.load_weights(COCO_MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                "mrcnn_bbox", "mrcnn_mask"])
elif init_with == "last":
    # Older mrcnn releases return (log_dir, checkpoint_path) from find_last(),
    # newer ones return the checkpoint path alone; accept either.
    last_checkpoint = model.find_last()
    if isinstance(last_checkpoint, (tuple, list)):
        last_checkpoint = last_checkpoint[1]
    model.load_weights(last_checkpoint, by_name=True)
else:
    # Fail loudly instead of silently training from random initialisation.
    raise ValueError("unknown init_with: %r" % (init_with,))
def _train_stage(learning_rate, epochs, layers):
    """Run one model.train() stage on dataset_train / dataset_val."""
    model.train(dataset_train, dataset_val,
                learning_rate=learning_rate,
                epochs=epochs,
                layers=layers)


# Stage 1: the 'heads' layers at the configured learning rate.
_train_stage(config.LEARNING_RATE, epochs=10, layers='heads')
# Stage 2: all layers at one tenth of the configured learning rate.
# NOTE(review): mrcnn's `epochs` is presumably the epoch to stop at rather than
# an additional count, so this stage would continue up to epoch 30 -- confirm.
_train_stage(config.LEARNING_RATE / 10, epochs=30, layers="all")