Using Spark in a Spring (Java) Project
- Using Spark in a Spring (Java) Project
- Reposted sources
- The idea
- Overall approach
- Exporting the model
- Using the model in a Java environment
- Caveats
- Results
Using Spark in a Spring (Java) Project
Reposted sources
"Using a Spark MLlib-Trained Model in Java Web"
Author: xingoo
Source:
"Spark MLlib Decision Trees (DecisionTree)"
Author: caiandyong
Source:
The idea
During the last two days of the holiday, an idea struck me about the project I was working on. It is a Spring project, so could I plug in one of the Spark MLlib machine learning models from the course I took last semester? (It would add a highlight to the project while putting new knowledge to use.)
Overall approach
Looking back over the whole process, one article deserves special mention:
"Using a Spark MLlib-Trained Model in Java Web".
Starting from there, the approach I settled on was:
- train the model in Scala;
- export the trained model with PMML;
- load the exported model directly in the Java code that needs it.
Exporting the model
Since I had already trained models with Scala in IDEA, I went straight to figuring out how to export a trained model as PMML.
My plan was to export a Spark decision tree, following this example: "Spark MLlib Decision Trees (DecisionTree)".
The result was a stream of errors; this path turned out to be a dead end. (Most likely because MLlib's PMML export only covers k-means and the linear-model family — linear/ridge/lasso regression, linear SVM, and binary logistic regression — while DecisionTreeModel has no toPMML method.)
I wasted a lot of time on detours, with too little material available to learn from.
In the end I turned to the official API docs, where I discovered that the Spark distribution I had downloaded already ships with an example.
就在:examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala
The code is below (the sample data is also in Spark's data folder):
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
// (assumes an existing SparkContext `sc`, e.g. in spark-shell)
// Load and parse the data
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
// Cluster the data into two classes using KMeans
val numClusters = 2
val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)
// Export the model to a String in PMML format
println(s"PMML Model:\n ${clusters.toPMML}")
// Export the model to a local file in PMML format
clusters.toPMML("/tmp/kmeans.xml")
// Export the model to a directory on a distributed file system in PMML format
clusters.toPMML(sc, "/tmp/kmeans")
// Export the model to the OutputStream in PMML format
clusters.toPMML(System.out)
Using the model in a Java environment
This is where the article mentioned at the start comes in:
"Using a Spark MLlib-Trained Model in Java Web". Following the blogger's steps, the exported model can be used from a Java environment.
Caveats
Getting the dependencies right matters!
Pulling in too many dependencies and hitting version conflicts was the biggest problem of the whole process.
The dependencies I used when exporting the model were as follows (note the exclusion of org.jpmml:pmml-model from spark-mllib — Spark bundles its own copy, which presumably clashes with the JPMML artifacts pulled in below):
<properties>
    <spark.version>2.1.0</spark.version>
    <scala.version>2.11</scala.version>
</properties>

<repositories>
    <repository>
        <id>nexus-aliyun</id>
        <name>Nexus aliyun</name>
        <url>http://maven.aliyun.com/nexus/content/groups/public</url>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-streaming_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-hive_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_${scala.version}</artifactId>
        <version>${spark.version}</version>
        <exclusions>
            <exclusion>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-model</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.1.1</version>
    </dependency>
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>jpmml-evaluator-spark</artifactId>
        <version>1.2.2</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>${spark.version}</version>
        <!--<scope>runtime</scope>-->
    </dependency>
</dependencies>
The dependencies in the project that consumes the model were:
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator-extension -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.3</version>
</dependency>
(Since I could not pin down which dependency was conflicting, I split the work into two separate projects... make sure to keep the two dependency sets apart.)
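When it is unclear which jar a duplicated class is actually being loaded from, one quick check (plain standard Java, the class name below is just an illustrative argument) is to ask the JVM for the class's code source:

```java
// Print which jar on the classpath provides a given class.
// Handy for diagnosing "same class, two versions" conflicts, e.g. when
// org.dmg.pmml.PMML is supplied both by Spark and by the JPMML artifacts.
public class WhichJar {
    public static void main(String[] args) throws Exception {
        // JDK classes are loaded by the bootstrap classloader, so their
        // code source is null; a third-party class prints its jar path.
        Class<?> clazz = Class.forName(args.length > 0 ? args[0] : "java.lang.String");
        java.security.CodeSource source = clazz.getProtectionDomain().getCodeSource();
        System.out.println(clazz.getName() + " -> "
                + (source == null ? "bootstrap classloader (JDK)" : source.getLocation()));
    }
}
```

Run it on the project classpath with, e.g., `java WhichJar org.dmg.pmml.PMML`. Alternatively, `mvn dependency:tree -Dincludes=org.jpmml` should show which POM drags in each JPMML artifact.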
Results
Test data: 2.txt (each line is "label,f0 f1 f2 f3", the text format LabeledPoint.parse expects)
0,32 1 2 0
1,27 1 1 1
1,29 1 1 0
1,25 1 2 1
0,23 0 2 1
Training data: 3.txt
0,32 1 1 0
0,25 1 2 0
1,29 1 2 1
1,24 1 1 0
0,31 1 1 0
1,35 1 2 1
0,30 0 1 0
0,31 1 1 0
1,30 1 2 1
1,21 1 1 0
0,21 1 2 0
1,21 1 2 1
0,29 0 2 1
0,29 1 0 1
0,29 0 2 1
1,30 1 1 0
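To make the data format concrete: each line is a numeric class label, a comma, then space-separated feature values. A plain-Java sketch of that parsing (a hypothetical stand-in for what LabeledPoint.parse does with dense vectors):

```java
import java.util.Arrays;

// Hypothetical parser mirroring the "label,f0 f1 f2 f3" text format that
// org.apache.spark.mllib.regression.LabeledPoint.parse accepts for dense vectors.
public class ParseLine {
    // The label is everything before the first comma.
    static double label(String line) {
        return Double.parseDouble(line.split(",", 2)[0].trim());
    }

    // The features are the space-separated values after the comma.
    static double[] features(String line) {
        return Arrays.stream(line.split(",", 2)[1].trim().split("\\s+"))
                .mapToDouble(Double::parseDouble)
                .toArray();
    }

    public static void main(String[] args) {
        String line = "1,27 1 1 1"; // first row of 2.txt
        System.out.println(label(line));                     // 1.0
        System.out.println(Arrays.toString(features(line))); // [27.0, 1.0, 1.0, 1.0]
    }
}
```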
Model export code:
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint

/**
 * @Author Song
 * @Date 2021/3/4 9:13
 * @Version 1.0
 */
object demo4 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("demo3") // add .setMaster("local[*]") if running from the IDE
    val sc = new SparkContext(conf)

    // Load and parse the training data
    val data = sc.textFile("D:\\111TEST\\data\\3.txt")
    val parsedData = data.map(line => LabeledPoint.parse(line)) //.cache()

    // Load and parse the test data
    val testData = sc.textFile("D:\\111TEST\\data\\2.txt")
    val parsedData2 = testData.map(line => LabeledPoint.parse(line))

    // Train a binary logistic regression model with L-BFGS
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(parsedData)

    // Evaluate the model on the test set
    val predictionAndLabels = parsedData2.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val accuracy = metrics.accuracy
    println(s"Accuracy = $accuracy")

    // Export the model to a local file in PMML format
    model.toPMML("D:\\111TEST\\data2\\simple.xml")
    // Export the model to a directory on a distributed file system in PMML format
    model.toPMML(sc, "D:\\111TEST\\data2\\simple")
    // Export the model to the OutputStream in PMML format
    model.toPMML(System.out)

    sc.stop()
  }
}
Model import code:
/**
 * @Author Song
 * @Date 2021/3/4 9:25
 * @Version 1.0
 */
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class PMMLDemo2 {

    /** Load the exported PMML file and build a JPMML evaluator from it. */
    private Evaluator loadPmml() {
        PMML pmml1 = new PMML();
        try (InputStream inputStream = new FileInputStream("D:\\111TEST\\data2\\simple.xml")) {
            pmml1 = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (Exception e) {
            e.printStackTrace();
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml1);
    }

    private Object predict(Evaluator evaluator, int a, int b, int c, int d) {
        // Raw feature values keyed by the field names that Spark's PMML export generates
        Map<String, Integer> data = new HashMap<String, Integer>();
        data.put("field_0", a);
        data.put("field_1", b);
        data.put("field_2", c);
        data.put("field_3", d);

        // Prepare the raw feature values (in a real app, pulled from profile data) as model inputs
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        // Run the model and read the probability distribution of the target field
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);
        System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
        return target;
    }

    public static void main(String[] args) {
        PMMLDemo2 demo = new PMMLDemo2();
        Evaluator model = demo.loadPmml();
        demo.predict(model, 27, 1, 1, 1);
        demo.predict(model, 25, 1, 2, 1);
        demo.predict(model, 23, 0, 2, 1);
    }
}
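In a real Spring application the PMML file should be parsed once rather than on every request, since unmarshalling and evaluator construction are comparatively expensive, and JPMML evaluator instances are generally intended to be reused across threads (check the JPMML-Evaluator docs for your version). A plain-Java sketch of that caching idea follows; the class and method names are hypothetical, and in Spring you would typically get the same effect by exposing the Evaluator as a singleton @Bean or @Service:

```java
import java.util.function.Supplier;

// Hypothetical thread-safe lazy holder: build the model evaluator once on first
// use and hand the same instance to every caller, much like a Spring singleton bean.
final class ModelHolder<T> {
    private final Supplier<T> loader;
    private volatile T model; // volatile for safe publication across threads

    ModelHolder(Supplier<T> loader) {
        this.loader = loader;
    }

    T get() {
        T m = model;
        if (m == null) {                 // first check without locking
            synchronized (this) {
                m = model;
                if (m == null) {         // second check under the lock
                    m = loader.get();    // e.g. () -> new PMMLDemo2().loadPmml()
                    model = m;
                }
            }
        }
        return m;
    }
}
```

A `ModelHolder<Evaluator>` built with a loader like `() -> demo.loadPmml()` can then back a controller or service, so the expensive unmarshalling happens only on the first prediction.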