最近项目需要在web界面调用深度学习训练模型,后台我采用的springboot框架,在这里记录下找到的解决办法。
1、采用Java语言重现模型接口
2、采用jar包
3、使用socket通信接口
对于1方法,由于算法这块不是我负责,要改写起来很麻烦。2方法没试过,本次采用第三种方法,这个也是借鉴该位博主的思路,采用这种方法避免了算法与后台的各种撕逼,也省的安装相应的环境。
编写Python socket服务端
import socket
import sys
import threading
import json
import numpy as np
# from tag import train2
# nn=network.getNetWork()
# cnn = conv.main(False)
# 深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中
# nnservice = train2.NNService(model='model/20180731.ckpt-1000')
def main():
# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# 获取本地主机名称
host = socket.gethostname()
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind((host,port))
# 设置监听最大连接数
serversocket.listen(5)
# 获取本地服务器的连接信息
myaddr = serversocket.getsockname()
print("服务器地址:%s"%str(myaddr))
# 循环等待接受客户端信息
while True:
# 获取一个客户端连接
clientsocket,addr = serversocket.accept()
print("连接地址:%s" % str(addr))
try:
t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程
t.start()
pass
except Exception as identifier:
print(identifier)
pass
pass
serversocket.close()
pass
class ServerThreading(threading.Thread):
# words = text2vec.load_lexicon()
def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"):
threading.Thread.__init__(self)
self._socket = clientsocket
self._recvsize = recvsize
self._encoding = encoding
pass
def run(self):
print("开启线程.....")
try:
#接受数据
msg = ''
while True:
# 读取recvsize个字节
rec = self._socket.recv(self._recvsize)
# 解码
msg += rec.decode(self._encoding)
# 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
# 所以需要自定义协议标志数据接受完毕
if msg.strip().endswith('over'):
msg=msg[:-4]
break
# 解析json格式的数据
re = json.loads(msg)
# 调用神经网络模型处理请求
# res = nnservice.hand(re['content'])
res = "123"
sendmsg = json.dumps(res)
# 发送数据
self._socket.send(("%s"%sendmsg).encode(self._encoding))
pass
except Exception as identifier:
self._socket.send("500".encode(self._encoding))
print(identifier)
pass
finally:
self._socket.close()
print("任务结束.....")
pass
def __del__(self):
pass
if __name__ == "__main__":
main()
编写java web socket客户端
public static Object remoteCall(String content){
JSONObject jsonObject = new JSONObject();
jsonObject.put("content", content);
String str = jsonObject.toJSONString();
// 访问服务进程的套接字
Socket socket = null;
// List<Question> questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);
try {
// 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
socket = new Socket("你的host",你的端口);
// 获取输出流对象
OutputStream os = socket.getOutputStream();
PrintStream out = new PrintStream(os);
// 发送内容
out.print(str);
// 告诉服务进程,内容发送完毕,可以开始处理
out.print("over");
// 获取服务进程的输入流
InputStream is = socket.getInputStream();
BufferedReader br = new BufferedReader(new InputStreamReader(is,"utf-8"));
String tmp = null;
StringBuilder sb = new StringBuilder();
// 读取内容
while((tmp=br.readLine())!=null)
sb.append(tmp).append('\n');
// 解析结果
JSONArray res = JSON.parseArray(sb.toString());
return res;
} catch (IOException e) {
e.printStackTrace();
} finally {
try {if(socket!=null) socket.close();} catch (IOException e) {}
// log.info("远程接口调用结束.");
}
return null;
}
Python这块传输图片会比较多,建议还是传输的时候直接传输图片路径。不要传输图片流,影像其速度