flask搭建服务
参考微博:https://blog.keras.io/building-a-simple-keras-deep-learning-rest-api.html 建议英文好的同学直接去看参考微博,写的很清晰。 当我们用tensorflow训练好了模型,这只是完成了模型训练的部分。在实际的生产环境中,还需要将模型部署到服务器当中,这样才能接受不同的客户端来调用它。
1、部署好开发环境
也就是把该安装的库都安装好
$pip install flask gevent requests
2、准备好服务端脚本
这个脚本我们起名run_sever.py 需要定义实现3个函数:
- load_model: 用于载入模型
- infer: 模型的前向推理
- predict:这个函数作用是server端的数据转换,并把结果返回给client
废话不多说,我们直接看代码
#coding=utf-8
import flask
import tensorflow as tf
import os
app = flask.Flask(__name__)
#从硬盘中载入模型的结构和权重
def load_model(saved_model_path):
sess = tf.Session(graph=tf.Graph())
tf.saved_model.loader.load(sess, ['serve'], saved_model_path)
graph = tf.get_default_graph()
return sess
def model_infer(image_contents):
global sess
out = sess.run(['prediction_ori:0','probability_ori:0'], feed_dict={'input_image_as_bytes:0':image_contents})
return out
@app.route('/predict',methods=['POST'])
def predict():
data = {'success':False}
if flask.request.method == 'POST':
if flask.request.files.get('image'):
image_path = flask.request.form.get('image_path')
data = {'success':True}
image_url = flask.request.files.get('image')
image = image_url.read()
if image:
out = model_infer([image])
label,prob = out
data['label'] = str(label)
data['prob'] = str(prob)
return flask.jsonify(data)
else:
return 'There is no image'
if __name__ == '__main__':
saved_model_path = 'saved_model_path'
sess = load_model(saved_model_path)
app.run(host='0.0.0.0',port='8000')
之后运行此脚本,会出现以下的信息。表示这个flask服务已经在你当前的服务器端部署好了,地址是0.0.0.0:8000
$python run_server.py
* Serving Flask app "server" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment.
Use a production WSGI server instead.
* Debug mode: off
* Running on http://0.0.0.0:8000/ (Press CTRL+C to quit)
3、利用curl来测试我们刚刚搭建的服务端
curl是什么?
curl是一个利用URL语法在命令行下工作的文件传输工具,就把理解成linux的一个命令和(cd、ls一样),可以给服务端传输指令文件等。
$curl -X POST -F image=@test.jpg 'http://localhost:8000/predict'
这里表示我们向http://localhost:8000/predict网址发送了POST请求,并传输了一张图片test.jpg 如果你上一步在服务器上搭建的server,这里的地址要变成你服务器的地址,比如:192.168.11.11(具体要看你部署的机器端口号) 正常情况下会返回predict函数里返回的json文本。
4、编写脚本来调用线上服务
这个脚本可以理解为所谓的客户端,所以起名client.py
在第3步中,我们不是利用命令行中的curl调用了部署的服务并返回了我们想要的结果了吗?为什么我们还要编写脚本来调用线上的服务呢?
在命令行中使用curl命令是可以获得返回的json文本数据,看起来很方便,但是当你需要将结果转换成特定的格式,或者想要大批量的跑测试图片时,curl这种方式就看起来非常蛋疼了。我们是可以通过写shell脚本来批量的执行curl命令,并且把结果存取到文本中然后再进行格式转换,但是你如果试过这种方式就知道多么难受了。非常耗时,而且容易出错。所以实际使用的时候,我们都会编写一个脚本来做这件事情。
直接看代码吧
#coding=utf-8
import requests
url = 'http://192.30.12.15/predict'
image_path = 'test.jpg'
#载入图像
image = open(image_path,'rb').read()
payload = {'image':image}
#提交请求
r = requests.post(url,files=payload).json()
print(r)
在命令行中运行python client.py应该可以看到返回结果,python脚本嘛结果就可以随意操作了嘛。