目录
- 什么是模型部署?
- 模型部署工具介绍
- MLflow
- MLeap
- PMML
- Pyspark
- Sklearn
- ONNX介绍
- TensorRT
- [TensorFlow Serving](https://github.com/tensorflow/serving)
- Web服务化部署
- Docker菜鸟教程
- 参考
什么是模型部署?
在典型的机器学习和深度学习项目中,建模的常规流程是定义问题、数据收集、数据理解、数据处理、构建模型。但是,如果我们想要将模型提供给最终用户,以便用户能够使用它,就需要进行模型部署,模型部署要做的工作就是如何将机器学习模型传递给客户/利益相关者。模型的部署大致分为以下三个步骤:
- 模型持久化;
持久化,通俗得讲,就是临时数据(比如内存中的数据,是不能永久保存的)持久化为持久数据(比如持久化至数据库中,能够长久保存)。那我们训练好的模型一般都是存储在内存中,这个时候就需要用到持久化方式,在 Python 中,常用的模型持久化方式一般都是以文件的方式持久化。 - 选择适合的服务器加载已经持久化的模型;
- 提高服务接口,拉通前后端数据交流;
模型部署工具介绍
MLflow
MLeap
PMML
依赖包:
- sklearn
- sklearn2pmml
将训练好的机器学习模型转化为PMML
格式,以供Java
调用。Python
代码如下
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml
if __name__ == '__main__':
iris = load_iris() # 经典的数据
X = iris.data # 样本特征
y = iris.target # 分类目标
pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier())]) # 用决策树分类
pipeline.fit(X, y) # 训练
sklearn2pmml(pipeline, "iris.pmml", with_repr=True) # 输出PMML文件
Java
读取模型文件并预测,具体代码如下:
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
public class TestPmml {
public static void main(String args[]) throws Exception {
String fp = "iris.pmml";
TestPmml obj = new TestPmml();
Evaluator model = obj.loadPmml(fp);
List<Map<String, Object>> inputs = new ArrayList<>();
inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2));
inputs.add(obj.getRawMap(4.9, 3, 1.4, 0.2));
for (int i = 0; i < inputs.size(); i++) {
Map<String, Object> output = obj.predict(model, inputs.get(i));
System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y"));
}
}
private Evaluator loadPmml(String fp) throws FileNotFoundException, JAXBException, SAXException {
InputStream is = new FileInputStream(fp);
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance();
return factory.newModelEvaluator(pmml);
}
private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
Map<String, Object> data = new HashMap<String, Object>();
data.put("x1", a);
data.put("x2", b);
data.put("x3", c);
data.put("x4", d);
return data;
}
/**
* 运行模型得到结果。
*/
private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
Map<FieldName, FieldValue> input = getFieldMap(evaluator, data);
Map<String, Object> output = evaluate(evaluator, input);
return output;
}
/**
* 把原始输入转换成PMML格式的输入。
*/
private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
List<InputField> inputFields = evaluator.getInputFields();
Map<FieldName, FieldValue> map = new LinkedHashMap<FieldName, FieldValue>();
for (InputField field : inputFields) {
FieldName fieldName = field.getName();
Object rawValue = input.get(fieldName.getValue());
FieldValue value = field.prepare(rawValue);
map.put(fieldName, value);
}
return map;
}
/**
* 运行模型得到结果。
*/
private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
Map<FieldName, ?> results = evaluator.evaluate(input);
List<TargetField> targetFields = evaluator.getTargetFields();
Map<String, Object> output = new LinkedHashMap<String, Object>();
for (int i = 0; i < targetFields.size(); i++) {
TargetField field = targetFields.get(i);
FieldName fieldName = field.getName();
Object value = results.get(fieldName);
if (value instanceof Computable) {
Computable computable = (Computable) value;
value = computable.getResult();
}
output.put(fieldName.getValue(), value);
}
return output;
}
}
Pyspark
Sklearn
ONNX介绍
可通过示例–python模型转换为ONNX格式了解简单的模型转换。详细内容可参考ONNX官方教程
TensorRT
TensorFlow Serving
Web服务化部署
该方法主要是通过一些Web
框架将预测模型打包成Web
服务接口的形式,是一种比较常见的线上部署方式。常用的Web
框架如下所示:
Docker菜鸟教程
参考