Spark Machine Learning: Logistic Regression Classification Algorithm

Introduction to Logistic Regression

Logistic regression is a supervised learning algorithm used mainly for classification problems.

Although Logistic Regression is called a regression, it is in fact a classification model and is most commonly used for binary classification. It is popular in industry because it is simple, easy to parallelize, and highly interpretable.

The essence of logistic regression is to assume that the data follows the logistic distribution and then estimate the parameters by maximum likelihood.
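
Concretely, in the standard textbook formulation (nothing Spark-specific), for a feature vector x the model assumes

P(y = 1 \mid x) = \sigma(w^{\top} x + b) = \frac{1}{1 + e^{-(w^{\top} x + b)}}

and, given training samples (x_i, y_i), the parameters w and b are chosen to maximize the log-likelihood

\ell(w, b) = \sum_{i} \left[ y_i \log \sigma(w^{\top} x_i + b) + (1 - y_i) \log\left(1 - \sigma(w^{\top} x_i + b)\right) \right]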

Logistic Regression Example

Here we use a logistic regression model to build a binary classifier that predicts whether a student will pass or fail the next exam based on past exam scores.

scores.csv

(score on the first exam, score on the second exam, whether the student passes the third exam: 1 = pass, 0 = fail)

score1,score2,result
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
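
Note that the result column is imbalanced: result = 1 (pass) occurs more often than result = 0. This matters later, because StringIndexer orders values by frequency and therefore maps 1 to index 0.0, which is why the predicted labels end up inverted relative to the raw encoding. A quick way to check the balance (a sketch assuming a SparkSession named spark, as created in the listing further down):

// Count how often each raw result value occurs; the most frequent one
// will be assigned index 0.0 by StringIndexer further down.
spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("scores.csv")
  .groupBy("result")
  .count()
  .show()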

Maven Dependencies

<properties>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.2</spark.version>
    <hadoop.version>2.7.6</hadoop.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
</dependencies>

LogisticRegression.scala

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

/**
  * @Author Daniel
  * @Description Logistic regression algorithm
  *
  **/
object LogisticRegression {
  def main(args: Array[String]): Unit = {
    // Build the Spark session (programming entry point)
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("LogisticRegression")
      .getOrCreate()
    val (assembler, logisticRegressionModelLoaded) = getModel(spark)
    // Test data: the first and second exam scores of 5 students, used to predict the third exam result. Here 0 means pass and 1 means fail (the opposite of the raw data, because StringIndexer re-encoded the label)
    import spark.implicits._
    val df1 = Seq(
      (70.66150955499435, 92.92713789364831),
      (76.97878372747498, 47.57596364975532),
      (67.37202754570876, 42.83843832029179),
      (89.67677575072079, 65.79936592745237),
      (50.534788289883, 48.85581152764205)
    ).toDF("score1", "score2")
    // Transform the sample data, adding the features column
    val df2 = assembler.transform(df1)
    df2.show()
    // The final result predicts pass/fail on the third exam for the new students (0 = pass, 1 = fail)
    val df3 = logisticRegressionModelLoaded.transform(df2)
    df3.show()
  }

  // Build and return the trained model
  def getModel(spark: SparkSession): (VectorAssembler, LogisticRegressionModel) = {
    // Table schema
    val schema = StructType(
      StructField("score1", DoubleType, nullable = true) ::
        StructField("score2", DoubleType, nullable = true) ::
        // In the raw data, 0 means fail and 1 means pass
        StructField("result", IntegerType, nullable = true) ::
        Nil
    )

    // Read the data into a DataFrame
    val marksDf = spark.read.format("csv")
      .option("header", value = true)
      .option("delimiter", ",")
      .schema(schema)
      .load("scores.csv")
      // Cache the data
      .cache()

    // Columns to be assembled into the feature vector
    val cols = Array("score1", "score2")

    // Assemble them into a vector
    val assembler = new VectorAssembler()
      .setInputCols(cols)
      .setOutputCol("features")
    // DataFrame with the feature vector column
    val featureDf = assembler.transform(marksDf)

    // Create a label column from the result column
    val indexer = new StringIndexer()
      .setInputCol("result")
      .setOutputCol("label")
    val labelDf = indexer.fit(featureDf).transform(featureDf)

    val seed = 5043
    // 70% of the data for training, 30% for testing
    val Array(trainingData, testData) = labelDf.randomSplit(Array(0.7, 0.3), seed)
    // Build the logistic regression model and train it on the training set
    val logisticRegression = new LogisticRegression()
      .setMaxIter(100)
      .setRegParam(0.02)
      .setElasticNetParam(0.8)
    val logisticRegressionModel = logisticRegression.fit(trainingData)
    /*
    Predicting on the test set yields a DataFrame with three new columns:
    1. rawPrediction
      the raw prediction scores (for logistic regression, the class margins before applying the sigmoid)
    2. probability
      the conditional probability of each class
    3. prediction
      the predicted class, derived from rawPrediction (the class with the highest score)
     */
    val predictionDf = logisticRegressionModel.transform(testData)
    // Evaluator using area under the ROC curve
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("prediction")
      .setMetricName("areaUnderROC")
    // Compute the metric
    val accuracy = evaluator.evaluate(predictionDf)
    println("预测的精度为" + accuracy)
    // Save the model
    logisticRegressionModel.write.overwrite()
      .save("score-model")
    // Load the model back
    val logisticRegressionModelLoaded = LogisticRegressionModel
      .load("score-model")
    (assembler, logisticRegressionModelLoaded)
  }
}
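
The same workflow can also be expressed with Spark ML's Pipeline API, which chains the VectorAssembler, StringIndexer, and LogisticRegression stages into a single estimator. The following is a minimal sketch using the same data and parameters as above; LogisticRegressionPipeline is a name chosen here, not part of the original code.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession

object LogisticRegressionPipeline {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("LogisticRegressionPipeline")
      .getOrCreate()

    // Read the same scores.csv, letting Spark infer the column types
    val marksDf = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv("scores.csv")

    // Same three stages as in getModel, declared once and chained together
    val assembler = new VectorAssembler()
      .setInputCols(Array("score1", "score2"))
      .setOutputCol("features")
    val indexer = new StringIndexer()
      .setInputCol("result")
      .setOutputCol("label")
    val lr = new LogisticRegression()
      .setMaxIter(100)
      .setRegParam(0.02)
      .setElasticNetParam(0.8)
    val pipeline = new Pipeline().setStages(Array(assembler, indexer, lr))

    // A single fit() runs feature assembly, label indexing, and training
    val Array(trainingData, testData) = marksDf.randomSplit(Array(0.7, 0.3), 5043)
    val model = pipeline.fit(trainingData)
    model.transform(testData)
      .select("score1", "score2", "label", "prediction")
      .show()

    spark.stop()
  }
}

Functionally this mirrors getModel; the main difference is that the fitted PipelineModel carries the assembler and indexer with it, so they do not need to be passed around separately.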

Results

Area under ROC: 0.8928571428571429
+-----------------+-----------------+--------------------+
|           score1|           score2|            features|
+-----------------+-----------------+--------------------+
|70.66150955499435|92.92713789364831|[70.6615095549943...|
|76.97878372747498|47.57596364975532|[76.9787837274749...|
|67.37202754570876|42.83843832029179|[67.3720275457087...|
| 89.6767757507208|65.79936592745237|[89.6767757507208...|
|  50.534788289883|48.85581152764205|[50.534788289883,...|
+-----------------+-----------------+--------------------+

+-----------------+-----------------+--------------------+--------------------+--------------------+----------+
|           score1|           score2|            features|       rawPrediction|         probability|prediction|
+-----------------+-----------------+--------------------+--------------------+--------------------+----------+
|70.66150955499435|92.92713789364831|[70.6615095549943...|[4.42488938425420...|[0.98816618042094...|       0.0|
|76.97878372747498|47.57596364975532|[76.9787837274749...|[0.13401559021765...|[0.53345384278692...|       0.0|
|67.37202754570876|42.83843832029179|[67.3720275457087...|[-1.4919079280137...|[0.18363553054854...|       1.0|
| 89.6767757507208|65.79936592745237|[89.6767757507208...|[3.60597758013620...|[0.97355732638820...|       0.0|
|  50.534788289883|48.85581152764205|[50.534788289883,...|[-2.7578193413834...|[0.05964655921865...|       1.0|
+-----------------+-----------------+--------------------+--------------------+--------------------+----------+

That is, three of the five students (those with prediction 0.0) are predicted to pass the third exam.
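
Because StringIndexer assigns index 0.0 to the most frequent raw value (here result = 1, i.e. pass), the prediction column is inverted relative to the original encoding. If you want the output in the original 0/1 encoding again, Spark ML's IndexToString can map the indices back. The sketch below assumes the fitted StringIndexerModel from getModel is kept around under the hypothetical name indexerModel, and that df3 is the prediction DataFrame from main:

import org.apache.spark.ml.feature.IndexToString

// indexerModel is assumed to be the StringIndexerModel fitted in getModel (the original listing
// does not return it, so keeping a reference to it is a hypothetical refactor). Its labels array
// records which raw result value each index represents.
val converter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedResult")
  .setLabels(indexerModel.labels)

// predictedResult is back in the original encoding: "1" = pass, "0" = fail
converter.transform(df3).select("score1", "score2", "prediction", "predictedResult").show()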