

from __future__ import division

# Project-local star imports stay ahead of the explicit imports, as in the
# original file, so the explicit names below win if both export a name.
from models import *
from utils.utils import *
from utils.datasets import *

import argparse
import datetime
import os
import sys
import time

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib.ticker import NullLocator
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets


 from matplotlib.ticker import NullLocator def changeBGR2RGB(img):
     b = img[:, :, 0].copy()
     g = img[:, :, 1].copy()
     r = img[:, :, 2].copy()    # RGB > BGR
     img[:, :, 0] = r
     img[:, :, 1] = g
     img[:, :, 2] = b    return img
 def changeRGB2BGR(img):
     r = img[:, :, 0].copy()
     g = img[:, :, 1].copy()
     b = img[:, :, 2].copy()    # RGB > BGR
     img[:, :, 0] = b
     img[:, :, 1] = g
     img[:, :, 2] = r    return img
if __name__ == "__main__":
    # Run a YOLOv3 detector over a video file frame by frame, draw the
    # detections on each frame, show it in an OpenCV window, and print the
    # achieved frames-per-second when the video ends or 'q' is pressed.
    #
    # Flag names (including the "--vedio_file" spelling) and defaults are kept
    # unchanged so existing command lines keep working. Some flags are parsed
    # but not used by this script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/samples", help="path to dataset")
    parser.add_argument("--vedio_file", type=str, default="/home/fut/test2.asf", help="path to dataset")
    parser.add_argument("--model_def", type=str, default="myproj/config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--weights_path", type=str, default="trained-model/yolov3_ckpt_5.pth", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="myproj/config/obj.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.5, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=3, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")
    opt = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)

    # NOTE(review): both loading statements were missing from the original
    # (only their comments survived). They are restored on the assumption that
    # Darknet exposes load_darknet_weights() and that the checkpoint file is a
    # plain state dict -- confirm against models.py.
    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights; map_location lets a checkpoint saved on GPU
        # load on a CPU-only machine.
        model.load_state_dict(torch.load(opt.weights_path, map_location=device))
    model.eval()  # Set in evaluation mode

    classes = load_classes(opt.class_path)
    # One random colour per class, reused for every frame.
    colors = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")

    cap = cv2.VideoCapture(opt.vedio_file)
    frames_processed = 0
    time_begin = time.time()
    try:
        while cap.isOpened():
            ret, img = cap.read()
            if not ret:
                # End of stream or read failure: stop instead of resizing None.
                break
            img = cv2.resize(img, (1280, 960), interpolation=cv2.INTER_CUBIC)

            # OpenCV delivers frames as BGR; the original author's notes say
            # the PIL-based loading path this preprocessing was copied from
            # works on RGB images. The swap is done in place, so RGBimg and
            # img are the same array.
            # NOTE(review): this assignment was missing in the original;
            # RGBimg was used but never defined.
            RGBimg = changeBGR2RGB(img)

            # Same preprocessing steps as the project's data loader (per the
            # original author's notes): to tensor, pad to square, resize to
            # the network input size.
            imgTensor = transforms.ToTensor()(RGBimg)
            imgTensor, _ = pad_to_square(imgTensor, 0)
            # Use the configured size rather than a hard-coded 416 so that it
            # agrees with the model and with rescale_boxes below.
            imgTensor = resize(imgTensor, opt.img_size)
            # Add the leading batch dimension before the forward pass.
            imgTensor = imgTensor.unsqueeze(0).to(device)

            with torch.no_grad():
                detections = model(imgTensor)
                detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

            # The original iterated over the NMS result and skipped None
            # entries, so it is treated here as a sequence with one entry per
            # image.
            per_image = detections if detections is not None else []
            for image_detections in per_image:
                if image_detections is None:
                    continue
                # Map boxes from network-input coordinates back to the frame.
                image_detections = rescale_boxes(image_detections, opt.img_size, RGBimg.shape[:2])
                for x1, y1, x2, y2, conf, cls_conf, cls_pred in image_detections:
                    # OpenCV drawing calls need plain Python ints, not tensors.
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                    label = classes[int(cls_pred)]
                    color = [int(c) for c in colors[int(cls_pred)]]
                    cv2.rectangle(RGBimg, (x1, y1), (x2, y2), color, 2)
                    # Class name at the top-left corner, confidence at the
                    # top-right corner of the box.
                    cv2.putText(RGBimg, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    cv2.putText(RGBimg, str("%.2f" % float(conf)), (x2, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                color, 2)

            # imshow expects BGR, so swap the channels back before display.
            cv2.imshow('frame', changeRGB2BGR(RGBimg))
            frames_processed += 1
            # NOTE(review): the body of this check was missing in the
            # original; quitting on 'q' is the assumed intent.
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and close the window even if a frame fails.
        cap.release()
        cv2.destroyAllWindows()

    time_total = time.time() - time_begin
    # Frames per second over the frames actually processed. The container's
    # reported frame count is wrong after an early quit and may be unavailable
    # for streamed formats.
    print(frames_processed // time_total if time_total > 0 else 0.0)