flask搭建服务

参考微博:https://blog.keras.io/building-a-simple-keras-deep-learning-rest-api.html 建议英文好的同学直接去看参考微博,写的很清晰。 当我们用tensorflow训练好了模型,这只是完成了模型训练的部分。在实际的生产环境中,还需要将模型部署到服务器当中,这样才能接受不同的客户端来调用它。

1、部署好开发环境

也就是把该安装的库都安装好

$pip install flask gevent requests

2、准备好服务端脚本

这个脚本我们起名run_sever.py 需要定义实现3个函数:

  • load_model: 用于载入模型
  • infer: 模型的前向推理
  • predict:这个函数作用是server端的数据转换,并把结果返回给client

废话不多说,我们直接看代码

#coding=utf-8
import flask
import tensorflow as tf
import os

app = flask.Flask(__name__)
#从硬盘中载入模型的结构和权重
def load_model(saved_model_path):
    sess = tf.Session(graph=tf.Graph())
    
    tf.saved_model.loader.load(sess, ['serve'], saved_model_path)
    graph = tf.get_default_graph()
    return sess
    
def model_infer(image_contents):
    global sess
    out = sess.run(['prediction_ori:0','probability_ori:0'], feed_dict={'input_image_as_bytes:0':image_contents})
    return out


@app.route('/predict',methods=['POST'])
def predict():
    data = {'success':False}
    if flask.request.method == 'POST':
        if flask.request.files.get('image'):
            image_path = flask.request.form.get('image_path')
            data = {'success':True}
            image_url = flask.request.files.get('image')
            image = image_url.read()
            if image:
                out = model_infer([image])
                label,prob = out
                
                data['label'] = str(label)
                data['prob'] = str(prob)
                return flask.jsonify(data)
            else:
                return 'There is no image'
                
                
if __name__ == '__main__':
    
    saved_model_path = 'saved_model_path'
    sess = load_model(saved_model_path)
    app.run(host='0.0.0.0',port='8000')

之后运行此脚本,会出现以下的信息。表示这个flask服务已经在你当前的服务器端部署好了,地址是0.0.0.0:8000

$python run_server.py
 
* Serving Flask app "server" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://0.0.0.0:8000/ (Press CTRL+C to quit)

3、利用curl来测试我们刚刚搭建的服务端

curl是什么?

curl是一个利用URL语法在命令行下工作的文件传输工具,就把理解成linux的一个命令和(cd、ls一样),可以给服务端传输指令文件等。

$curl -X POST -F image=@test.jpg 'http://localhost:8000/predict'

这里表示我们向http://localhost:8000/predict网址发送了POST请求,并传输了一张图片test.jpg 如果你上一步在服务器上搭建的server,这里的地址要变成你服务器的地址,比如:192.168.11.11(具体要看你部署的机器端口号) 正常情况下会返回predict函数里返回的json文本。

4、编写脚本来调用线上服务

这个脚本可以理解为所谓的客户端,所以起名client.py

在第3步中,我们不是利用命令行中的curl调用了部署的服务并返回了我们想要的结果了吗?为什么我们还要编写脚本来调用线上的服务呢?

在命令行中使用curl命令是可以获得返回的json文本数据,看起来很方便,但是当你需要将结果转换成特定的格式,或者想要大批量的跑测试图片时,curl这种方式就看起来非常蛋疼了。我们是可以通过写shell脚本来批量的执行curl命令,并且把结果存取到文本中然后再进行格式转换,但是你如果试过这种方式就知道多么难受了。非常耗时,而且容易出错。所以实际使用的时候,我们都会编写一个脚本来做这件事情。

直接看代码吧

#coding=utf-8
import requests

url = 'http://192.30.12.15/predict'
image_path = 'test.jpg'

#载入图像
image = open(image_path,'rb').read()
payload = {'image':image}

#提交请求
r = requests.post(url,files=payload).json()

print(r)

在命令行中运行python client.py应该可以看到返回结果,python脚本嘛结果就可以随意操作了嘛。