"""Simple image classification with a pre-trained Inception-v3 model.

Downloads the frozen Inception-v3 GraphDef (trained on ImageNet 2012),
runs inference on a JPEG image, and prints the top-k class predictions
with their softmax scores.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

# Populated by argparse in the __main__ block at the bottom of the file.
FLAGS = None

# Archive containing the frozen Inception-v3 GraphDef and the label maps.
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
class NodeLookup(object):
    """Converts integer node IDs output by the classifier to human-readable labels.

    The Inception graph's softmax output is indexed by integer node ID; this
    class maps each ID, via the ImageNet synset UID, to an English name.
    """

    def __init__(self, label_lookup_path=None, uid_lookup_path=None):
        # Default to the label-map files shipped inside the model archive.
        if not label_lookup_path:
            label_lookup_path = os.path.join(
                FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
            uid_lookup_path = os.path.join(
                FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    def load(self, label_lookup_path, uid_lookup_path):
        """Loads a human-readable English name for each softmax node.

        Args:
          label_lookup_path: file mapping string UID to integer node ID.
          uid_lookup_path: file mapping string UID to human-readable string.

        Returns:
          dict from integer node ID to human-readable string.
        """
        if not tf.gfile.Exists(uid_lookup_path):
            tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
            tf.logging.fatal('File does not exist %s', label_lookup_path)

        # Loads mapping from string UID to human-readable string.
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        uid_to_human = {}
        p = re.compile(r'[n\d]*[ \S,]*')
        for line in proto_as_ascii_lines:
            parsed_items = p.findall(line)
            uid = parsed_items[0]
            human_string = parsed_items[2]
            uid_to_human[uid] = human_string

        # Loads mapping from string UID to integer node ID.
        node_id_to_uid = {}
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        for line in proto_as_ascii:
            # NOTE(review): the pbtxt entries are indented with two spaces in
            # the canonical label-map file; the mangled source had collapsed
            # whitespace here — confirm against the shipped pbtxt.
            if line.startswith('  target_class:'):
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
                target_class_string = line.split(': ')[1]
                # Strip the surrounding quotes and the trailing newline.
                node_id_to_uid[target_class] = target_class_string[1:-2]

        # Loads the final mapping of integer node ID to human-readable string.
        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
            if val not in uid_to_human:
                tf.logging.fatal('Failed to locate: %s', val)
            name = uid_to_human[val]
            node_id_to_name[key] = name
        return node_id_to_name

    def id_to_string(self, node_id):
        """Returns the label for node_id, or the empty string if unknown."""
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]
def create_graph():
    """Deserializes the Inception-v3 model from its GraphDef protobuf.

    Imports the frozen graph into the default TensorFlow graph with an
    empty name scope so tensors keep their original names.
    """
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(os.path.join(
            FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
def run_inference_on_image(image):
    """Classifies an image with Inception-v3 and prints the top predictions.

    Args:
      image: path to a JPEG image file.

    Prints the top FLAGS.num_top_predictions labels and softmax scores.
    """
    if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf.gfile.FastGFile(image, 'rb').read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        # Some useful tensors:
        # 'softmax:0': A tensor containing the normalized prediction across
        #   1000 labels.
        # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
        #   float description of the image.
        # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
        #   encoding of the image.
        # Runs the softmax tensor by feeding the image_data as input to the graph.
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup()

        # Indices of the highest-scoring classes, best first.
        top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))
def maybe_download_and_extract():
    """Downloads the model tarball into FLAGS.model_dir and extracts it.

    Skips the download when the archive is already present on disk; the
    extraction is re-run either way (it is idempotent).
    """
    dest_directory = FLAGS.model_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):

        def _progress(count, block_size, total_size):
            # Single-line progress meter rewritten in place via '\r'.
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Use a context manager so the archive handle is closed deterministically
    # (the original leaked the open tarfile).
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dest_directory)
def main(_):
    """Entry point for tf.app.run: fetch the model, then classify one image.

    Falls back to the bundled panda sample when --image_file is not given.
    """
    maybe_download_and_extract()
    image = (FLAGS.image_file if FLAGS.image_file else
             os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
    run_inference_on_image(image)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # classify_image_graph_def.pb:
    #   Binary representation of the GraphDef protocol buffer.
    # imagenet_synset_to_human_label_map.txt:
    #   Map from synset ID to a human readable string.
    # imagenet_2012_challenge_label_map_proto.pbtxt:
    #   Text representation of a protocol buffer mapping a label to synset ID.
    parser.add_argument(
        '--model_dir',
        type=str,
        default='/tmp/imagenet',
        help="""\
        Path to classify_image_graph_def.pb,
        imagenet_synset_to_human_label_map.txt, and
        imagenet_2012_challenge_label_map_proto.pbtxt.\
        """
    )
    parser.add_argument(
        '--image_file',
        type=str,
        default='',
        help='Absolute path to image file.'
    )
    parser.add_argument(
        '--num_top_predictions',
        type=int,
        default=5,
        help='Display this many predictions.'
    )
    # Unknown flags are forwarded to tf.app.run rather than rejected.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)