Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided by Spark SQL give the engine more information about both the structure of the data and the computation being performed, and Spark uses this extra information internally to optimize and adjust the job. There are several ways to interact with Spark SQL, including the Dataset API and SQL, and the two can be mixed freely. One use of Spark SQL is to execute SQL queries; it can also be used to read data from an existing Hive installation. When SQL is run from within another programming language the results are returned as a Dataset/DataFrame, and you can also interact with the SQL interface through the command line or over JDBC/ODBC.

A Dataset is a distributed collection of data and a new interface added in Spark 1.6. It provides the benefits of RDDs (strong typing, the ability to use powerful lambda functions) together with the benefits of Spark SQL's optimized execution engine. A Dataset can be constructed from JVM objects and then manipulated using functional transformations (map, flatMap, filter, and so on). The Dataset API is currently available in Scala and Java; Python support for the Dataset API is still incomplete.

A DataFrame is a Dataset organized into named columns; it is conceptually equivalent to a table in a relational database. DataFrames can be constructed from a wide range of sources, such as structured data files, tables in Hive, or external databases. A DataFrame is represented by a Dataset of Row objects, so a DataFrame can simply be understood as a Dataset[Row].
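
As a quick illustration of the relationship between the two abstractions, here is a minimal sketch (assuming a SparkSession named spark with import spark.implicits._ in scope):

import org.apache.spark.sql.{DataFrame, Dataset}

val ds: Dataset[(String, Long)]  = Seq(("zhangsan", 28L)).toDS() // a typed Dataset
val df: DataFrame                = ds.toDF("name", "age")        // a DataFrame is just a Dataset[Row]
val ds2: Dataset[(String, Long)] = df.as[(String, Long)]         // back to a typed Dataset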

SparkSession

The entry point to all functionality in Spark is the SparkSession class. To create a basic SparkSession, just use SparkSession.builder():

import org.apache.spark.sql.SparkSession

object SparkSessionTests {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
                      .master("local[5]")
                      .appName("spark session")
                      .getOrCreate()
    import spark.implicits._
    // todo your code here
    spark.stop()
  }
}

Dataset

A Dataset is similar to an RDD; however, instead of using Java serialization or Kryo, Datasets use a specialized Encoder to serialize objects for processing or for transmission over the network. While both Encoders and standard serialization are responsible for turning an object into bytes, Encoders are code generated dynamically and use a format that allows Spark to perform many operations such as filtering, sorting and hashing without deserializing the bytes back into an object.
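
Encoders for case classes and primitives can also be obtained explicitly instead of through spark.implicits._; a minimal sketch (the Person case class is the same one used in the examples below):

import org.apache.spark.sql.{Encoder, Encoders}

case class Person(name: String, age: Long)

val productEnc: Encoder[Person] = Encoders.product[Person] // generated encoder, keeps the name/age schema
val kryoEnc: Encoder[Person]    = Encoders.kryo[Person]    // generic encoder, stores objects as opaque bytes

val ds = spark.createDataset(Seq(Person("zhangsan", 28)))(productEnc)
ds.printSchema() // name: string, age: long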

  • From a case class collection
case class Person(name: String, age: Long)
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder()
  .master("local[5]")
  .appName("spark session")
  .getOrCreate()
  spark.sparkContext.setLogLevel("FATAL")
  import spark.implicits._
  // create a Dataset from a case class
  val caseClassDS = Seq(Person("zhangsan", 28)).toDS()
  caseClassDS.show()

  spark.stop()
}
+--------+---+
|    name|age|
+--------+---+
|zhangsan| 28|
+--------+---+
  • From a tuple collection
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder()
  .master("local[5]")
  .appName("spark session")
  .getOrCreate()
  spark.sparkContext.setLogLevel("FATAL")
  import spark.implicits._
  // create a Dataset directly from tuples
  val caseClassDS = Seq((1,"zhangsan",true),(2,"lisi",false)).toDS()
  caseClassDS.show()

  spark.stop()
}
+---+--------+-----+
| _1|      _2|   _3|
+---+--------+-----+
|  1|zhangsan| true|
|  2|    lisi|false|
+---+--------+-----+
  • Load JSON data
{"name":"张三","age":18}
{"name":"lisi","age":28}
{"name":"wangwu","age":38}
case class Person(name: String, age: Long)
def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder()
  .master("local[5]")
  .appName("spark session")
  .getOrCreate()
  spark.sparkContext.setLogLevel("FATAL")
  import spark.implicits._

  val dataset = spark.read.json("D:///Persion.json").as[Person]
  dataset.show()

  spark.stop()
}
+---+------+
|age|  name|
+---+------+
| 18|  张三|
| 28|  lisi|
| 38|wangwu|
+---+------+

DataFrame

As described above, a DataFrame is a Dataset organized into named columns, i.e. a Dataset[Row]. The examples below show several ways to construct one.

  • From a JSON file
val df = spark.read.json("file:///D:/Persion.json")
df.show()
  • From a case class
spark.sparkContext.parallelize(List("zhangsan,20","lisi,30"))
  .map(line=>Person(line.split(",")(0),line.split(",")(1).toInt))
  .toDF("uname","uage")
  .show()
  • From tuples
spark.sparkContext.parallelize(List("zhangsan,20","lisi,30"))
        .map(line=>(line.split(",")(0),line.split(",")(1).toInt))
        .toDF("uname","uage")
        .show()
  • With a custom schema
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val userRDD: RDD[Row] = spark.sparkContext.parallelize(List("zhangsan,20", "lisi,30"))
      .map(line => (line.split(",")(0), line.split(",")(1).toInt))
      .map(item => Row(item._1, item._2))

// define the fields
var fields = Array(StructField("name", StringType, true), StructField("age", IntegerType, true))
// build the schema
var schema = new StructType(fields)
spark.createDataFrame(userRDD, schema)
    .show()

DataFrame operators

Given a file people.txt with records in the following format:

Michael,29,2000,true
Andy,30,5000,true
Justin,19,1000,true
Kaine,20,5000,true
Lisa,19,1000,false

select

var rdd = spark.sparkContext.textFile("file:///D:/people.txt")
  .map(_.split(","))
  .map(arr => Row(arr(0), arr(1).trim().toInt, arr(2).trim().toDouble, arr(3).trim().toBoolean))

val fields = StructField("name", StringType, true) ::
             StructField("age", IntegerType, true) ::
             StructField("salary", DoubleType, true) ::
             StructField("sex", BooleanType, true) ::
             Nil

val schema = StructType(fields)
val userDF = spark.createDataFrame(rdd, schema).as("user")

 userDF.select($"name" , $"age",$"salary",$"salary" * 12 as "年薪" )
          .show()
+-------+---+------+-------+
|   name|age|salary|   年薪|
+-------+---+------+-------+
|Michael| 29|2000.0|24000.0|
|   Andy| 30|5000.0|60000.0|
| Justin| 19|1000.0|12000.0|
|  Kaine| 20|5000.0|60000.0|
|   Lisa| 19|1000.0|12000.0|
+-------+---+------+-------+

filter

userDF.select($"name" , $"age",$"salary",$"salary" * 12 as "年薪" )
          .filter($"name" === "Michael" or $"年薪" <= 50000)
          .show()
+-------+---+------+-------+
|   name|age|salary|   年薪|
+-------+---+------+-------+
|Michael| 29|2000.0|24000.0|
| Justin| 19|1000.0|12000.0|
|   Lisa| 19|1000.0|12000.0|
+-------+---+------+-------+

withColumn

userDF.select($"name" , $"age",$"salary",$"salary" * 12 as "年薪" )
      .filter($"name" === "Michael" or $"年薪" <= 50000)
      .withColumn("年终奖",$"salary"* 0.8)
      .show()
+-------+---+------+-------+------+
|   name|age|salary|   年薪|年终奖|
+-------+---+------+-------+------+
|Michael| 29|2000.0|24000.0|1600.0|
| Justin| 19|1000.0|12000.0| 800.0|
|   Lisa| 19|1000.0|12000.0| 800.0|
+-------+---+------+-------+------+

groupBy
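
The outputs of the groupBy, agg and join examples below, and of the SQL queries further down, correspond to an extended six-column dataset (name, age, salary, sex, job, deptno) like the one listed later in the window-function section, not to the four-column file above. A minimal sketch of rebuilding userDF from such an extended people.txt (an assumption based on those outputs):

var rdd = spark.sparkContext.textFile("file:///D:/people.txt")
  .map(_.split(","))
  .map(arr => Row(arr(0), arr(1).trim().toInt, arr(2).trim().toDouble,
                  arr(3).trim().toBoolean, arr(4).trim(), arr(5).trim().toInt))

val schema = StructType(
  StructField("name", StringType, true) ::
  StructField("age", IntegerType, true) ::
  StructField("salary", DoubleType, true) ::
  StructField("sex", BooleanType, true) ::
  StructField("job", StringType, true) ::
  StructField("deptno", IntegerType, true) :: Nil)

val userDF = spark.createDataFrame(rdd, schema).as("user")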

userDF.select($"deptno",$"salary" )
          .groupBy($"deptno")
          .sum("salary")
          .show()

agg

import org.apache.spark.sql.functions._ 

userDF.select($"deptno",$"salary" )
.groupBy($"deptno")
.agg(sum($"salary") as "总薪资",avg($"salary") as "平均值",max($"salary") as "最大值")
.show()
+------+-------+------------------+-------+
|deptno| 总薪资|            平均值| 最大值|
+------+-------+------------------+-------+
|     1|43000.0|14333.333333333334|20000.0|
|     2|38000.0|           19000.0|20000.0|
+------+-------+------------------+-------+

join

Prepare the following data in dept.txt:

1,销售部门
2,研发部门
3,媒体运营
4,后勤部门
var deptDF =  spark.sparkContext.textFile("D:/dept.txt")
      .map(line =>(line.split(",")(0).toInt,line.split(",")(1)))
      .toDF("deptno","deptname").as("dept")


userDF.select($"deptno",$"salary" )
.groupBy($"deptno")
.agg(sum($"salary") as "总薪资",avg($"salary") as "平均值",max($"salary") as "最大值")
.join(deptDF,$"dept.deptno" === $"user.deptno")
.show()
+------+-------+------------------+-------+------+--------+
|deptno| 总薪资|            平均值| 最大值|deptno|deptname|
+------+-------+------------------+-------+------+--------+
|     1|43000.0|14333.333333333334|20000.0|     1|销售部门|
|     2|38000.0|           19000.0|20000.0|     2|研发部门|
+------+-------+------------------+-------+------+--------+

drop

userDF.select($"deptno",$"salary" )
.groupBy($"deptno")
.agg(sum($"salary") as "总薪资",avg($"salary") as "平均值",max($"salary") as "最大值")
.join(deptDF,$"dept.deptno" === $"user.deptno")
.drop($"dept.deptno")
.show()
+------+-------+------------------+-------+--------+
|deptno| 总薪资|            平均值| 最大值|deptname|
+------+-------+------------------+-------+--------+
|     1|43000.0|14333.333333333334|20000.0|销售部门|
|     2|38000.0|           19000.0|20000.0|研发部门|
+------+-------+------------------+-------+--------+

orderBy

userDF.select($"deptno",$"salary" )
.groupBy($"deptno")
.agg(sum($"salary") as "总薪资",avg($"salary") as "平均值",max($"salary") as "最大值")
.join(deptDF,$"dept.deptno" === $"user.deptno")
.drop($"dept.deptno")
.orderBy($"总薪资" asc)
.show()
+------+-------+------------------+-------+--------+
|deptno| 总薪资|            平均值| 最大值|deptname|
+------+-------+------------------+-------+--------+
|     2|38000.0|           19000.0|20000.0|研发部门|
|     1|43000.0|14333.333333333334|20000.0|销售部门|
+------+-------+------------------+-------+--------+

map

userDF.map(row => (row.getString(0),row.getInt(1))).show()

The result is a Dataset of (String, Int) tuples holding each user's name and age; since the tuple fields are unnamed, the resulting columns are shown as _1 and _2.

By default, Spark SQL serializes the values it works with while executing a query. Encoders for the common types are provided by the framework; if a type has no built-in Encoder, you must declare an implicit Encoder yourself:

implicit val mapEncoder = org.apache.spark.sql.Encoders.kryo[Map[String, Any]]
userDF.map(row => row.getValuesMap[Any](List("name","age","salary")))
.foreach(map=>{
  var name=map.getOrElse("name","")
  var age=map.getOrElse("age",0)
  var salary=map.getOrElse("salary",0.0)
  println(name+" "+age+" "+salary)
})

flatMap

implicit val mapEncoder = org.apache.spark.sql.Encoders.kryo[Map[String, Any]]
userDF.flatMap(row => row.getValuesMap(List("name","age")))
    .map(item => item._1 +" -> "+item._2)
    .show()
+---------------+
|          value|
+---------------+
|name -> Michael|
|      age -> 29|
|   name -> Andy|
|      age -> 30|
| name -> Justin|
|      age -> 19|
|  name -> Kaine|
|      age -> 20|
|   name -> Lisa|
|      age -> 19|
+---------------+

Obtaining a DataFrame with SQL

Query records

userDF.createOrReplaceTempView("t_emp")
spark.sql("select * from t_emp where name= 'Michael' or salary >= 10000")
.show()
+-------+---+-------+-----+--------+------+
|   name|age| salary|  sex|     job|deptno|
+-------+---+-------+-----+--------+------+
|Michael| 29|20000.0| true| MANAGER|     1|
|   Andy| 30|15000.0| true|SALESMAN|     1|
|  Kaine| 20|20000.0| true| MANAGER|     2|
|   Lisa| 19|18000.0|false|SALESMAN|     2|
+-------+---+-------+-----+--------+------+

group by aggregation

userDF.createOrReplaceTempView("t_emp")
spark.sql("select deptno,sum(salary),max(age) from t_emp group by deptno")
.show()
+------+-----------+--------+
|deptno|sum(salary)|max(age)|
+------+-----------+--------+
|     1|    43000.0|      30|
|     2|    38000.0|      20|
+------+-----------+--------+

having clause

userDF.createOrReplaceTempView("t_emp")
spark.sql("select deptno,sum(salary) totalSalary ,max(age) maxAge from t_emp group by deptno having totalSalary > 40000  ")
.show()
+------+-----------+------+
|deptno|totalSalary|maxAge|
+------+-----------+------+
|     1|    43000.0|    30|
+------+-----------+------+

join query

userDF.createOrReplaceTempView("t_emp")
deptDF.createOrReplaceTempView("t_dept")
val sql="select e.*,t.deptname from t_emp as e left join t_dept as t on e.deptno = t.deptno "
spark.sql(sql)
.show()
+-------+---+-------+-----+--------+------+--------+
|   name|age| salary|  sex|     job|deptno|deptname|
+-------+---+-------+-----+--------+------+--------+
|Michael| 29|20000.0| true| MANAGER|     1|销售部门|
|   Andy| 30|15000.0| true|SALESMAN|     1|销售部门|
| Justin| 19| 8000.0| true|   CLERK|     1|销售部门|
|  Kaine| 20|20000.0| true| MANAGER|     2|研发部门|
|   Lisa| 19|18000.0|false|SALESMAN|     2|研发部门|
+-------+---+-------+-----+--------+------+--------+

limit clause

userDF.createOrReplaceTempView("t_emp")
deptDF.createOrReplaceTempView("t_dept")
spark.sql("select e.*,t.deptname from t_emp as e left outer join t_dept as t on e.deptno = t.deptno limit 5")
.show()
+-------+---+-------+-----+--------+------+--------+
|   name|age| salary|  sex|     job|deptno|deptname|
+-------+---+-------+-----+--------+------+--------+
|Michael| 29|20000.0| true| MANAGER|     1|销售部门|
|   Andy| 30|15000.0| true|SALESMAN|     1|销售部门|
| Justin| 19| 8000.0| true|   CLERK|     1|销售部门|
|  Kaine| 20|20000.0| true| MANAGER|     2|研发部门|
|   Lisa| 19|18000.0|false|SALESMAN|     2|研发部门|
+-------+---+-------+-----+--------+------+--------+

Window functions

In ordinary statistical analysis we usually rely on aggregate functions, whose characteristic is that they collapse n input rows into a single row. Databases also support another kind of analysis, window functions. A window function computes a value for every row over a group of related rows without collapsing them: if each record returned by a query is pictured as one floor of a tall building, a window function opens a window on every floor, letting that floor see the whole building or a part of it.

Query each department's employees and return the department's average salary alongside each employee:

Michael,29,20000,true,MANAGER,1
Jimi,25,20000,true,SALESMAN,1
Andy,30,15000,true,SALESMAN,1
Justin,19,8000,true,CLERK,1
Kaine,20,20000,true,MANAGER,2
Lisa,19,18000,false,SALESMAN,2
----
+-------+---+-------+-----+--------+------+
|   name|age| salary|  sex|     job|deptno|
+-------+---+-------+-----+--------+------+
|Michael| 29|20000.0| true| MANAGER|     1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|
|   Andy| 30|15000.0| true|SALESMAN|     1|
| Justin| 19| 8000.0| true|   CLERK|     1|
|  Kaine| 20|20000.0| true| MANAGER|     2|
|   Lisa| 19|18000.0|false|SALESMAN|     2|
+-------+---+-------+-----+--------+------+
val sql="select * ,avg(salary) over(partition by deptno) as avgSalary  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+---------+
|   name|age| salary|  sex|     job|deptno|avgSalary|
+-------+---+-------+-----+--------+------+---------+
|Michael| 29|20000.0| true| MANAGER|     1|  15750.0|
|   Jimi| 25|20000.0| true|SALESMAN|     1|  15750.0|
|   Andy| 30|15000.0| true|SALESMAN|     1|  15750.0|
| Justin| 19| 8000.0| true|   CLERK|     1|  15750.0|
|  Kaine| 20|20000.0| true| MANAGER|     2|  19000.0|
|   Lisa| 19|18000.0|false|SALESMAN|     2|  19000.0|
+-------+---+-------+-----+--------+------+---------+
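
The same per-department average can be computed with the DataFrame API instead of SQL; a minimal sketch using a Window specification (not part of the original example):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byDept = Window.partitionBy($"deptno")
userDF.withColumn("avgSalary", avg($"salary").over(byDept)).show()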

ROW_NUMBER()

Rank employees by salary within their department:

val sql="select * , ROW_NUMBER() over(partition by deptno order by salary DESC)  rank  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+----+
|   name|age| salary|  sex|     job|deptno|rank|
+-------+---+-------+-----+--------+------+----+
|Michael| 29|20000.0| true| MANAGER|     1|   1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|   2|
|   Andy| 30|15000.0| true|SALESMAN|     1|   3|
| Justin| 19| 8000.0| true|   CLERK|     1|   4|
|  Kaine| 20|20000.0| true| MANAGER|     2|   1|
|   Lisa| 19|18000.0|false|SALESMAN|     2|   2|
+-------+---+-------+-----+--------+------+----+

Rank employees by salary across the whole company:

val sql="select * , ROW_NUMBER() over(order by salary DESC)  rank  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+----+
|   name|age| salary|  sex|     job|deptno|rank|
+-------+---+-------+-----+--------+------+----+
|Michael| 29|20000.0| true| MANAGER|     1|   1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|   2|
|  Kaine| 20|20000.0| true| MANAGER|     2|   3|
|   Lisa| 19|18000.0|false|SALESMAN|     2|   4|
|   Andy| 30|15000.0| true|SALESMAN|     1|   5|
| Justin| 19| 8000.0| true|   CLERK|     1|   6|
+-------+---+-------+-----+--------+------+----+

As the result shows, ROW_NUMBER() only assigns each row its sequential position within the window; rows with equal salaries still get different numbers, so it does not produce a true ranking.

DENSE_RANK()

Company-wide salary ranking:

val sql="select * , DENSE_RANK() over(order by salary DESC)  rank  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+----+
|   name|age| salary|  sex|     job|deptno|rank|
+-------+---+-------+-----+--------+------+----+
|Michael| 29|20000.0| true| MANAGER|     1|   1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|   1|
|  Kaine| 20|20000.0| true| MANAGER|     2|   1|
|   Lisa| 19|18000.0|false|SALESMAN|     2|   2|
|   Andy| 30|15000.0| true|SALESMAN|     1|   3|
| Justin| 19| 8000.0| true|   CLERK|     1|   4|
+-------+---+-------+-----+--------+------+----+

Salary ranking within each department:

val sql="select * , DENSE_RANK() over(partition by deptno order by salary DESC)  rank  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+----+
|   name|age| salary|  sex|     job|deptno|rank|
+-------+---+-------+-----+--------+------+----+
|Michael| 29|20000.0| true| MANAGER|     1|   1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|   1|
|   Andy| 30|15000.0| true|SALESMAN|     1|   2|
| Justin| 19| 8000.0| true|   CLERK|     1|   3|
|  Kaine| 20|20000.0| true| MANAGER|     2|   1|
|   Lisa| 19|18000.0|false|SALESMAN|     2|   2|
+-------+---+-------+-----+--------+------+----+

RANK()

This function is similar to DENSE_RANK(), except that the ranking produced by RANK() is not consecutive: after a tie, the following rank numbers are skipped.

Salary ranking within each department:

val sql="select * , RANK() over(partition by deptno order by salary DESC)  rank  from t_emp"
spark.sql(sql).show()
+-------+---+-------+-----+--------+------+----+
|   name|age| salary|  sex|     job|deptno|rank|
+-------+---+-------+-----+--------+------+----+
|Michael| 29|20000.0| true| MANAGER|     1|   1|
|   Jimi| 25|20000.0| true|SALESMAN|     1|   1|
|   Andy| 30|15000.0| true|SALESMAN|     1|   3|
| Justin| 19| 8000.0| true|   CLERK|     1|   4|
|  Kaine| 20|20000.0| true| MANAGER|     2|   1|
|   Lisa| 19|18000.0|false|SALESMAN|     2|   2|
+-------+---+-------+-----+--------+------+----+

User-defined functions

Scalar (single-row) functions

spark.udf.register("deptFun",(v:String)=> v+"部门")
spark.udf.register("annualSalary",(s:Double)=> s * 12)
val sql="select deptFun(deptno) deptname ,name,annualSalary(salary) from t_emp"
spark.sql(sql).show()
+--------+-------+------------------------+
|deptname|   name|UDF:annualSalary(salary)|
+--------+-------+------------------------+
|   1部门|Michael|                240000.0|
|   1部门|   Jimi|                240000.0|
|   1部门|   Andy|                180000.0|
|   1部门| Justin|                 96000.0|
|   2部门|  Kaine|                240000.0|
|   2部门|   Lisa|                216000.0|
+--------+-------+------------------------+
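
The same logic can be expressed with the DataFrame API by wrapping the functions with functions.udf; a minimal sketch (deptno is cast to string explicitly here because the Scala function expects a String):

import org.apache.spark.sql.functions.udf

val deptFun      = udf((v: String) => v + "部门")
val annualSalary = udf((s: Double) => s * 12)

userDF.select(deptFun($"deptno".cast("string")) as "deptname",
              $"name",
              annualSalary($"salary") as "annual")
      .show()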

Aggregate functions

Untyped aggregate functions

To implement a custom untyped aggregate function, users must extend the UserDefinedAggregateFunction abstract class. The examples below compute the total cost (price * count) per user over the following order data, first with the built-in sum and then with a custom function:

1,苹果,4.5,2,001
2,橘子,2.5,5,001
3,机械键盘,800,1,002
case class OrderLog(price: Double, count: Int, userid: String)

var orderDF = spark.sparkContext.textFile("D:/order.log")
  .map(_.split(","))
  .map(arr => OrderLog(arr(2).toDouble, arr(3).toInt, arr(4)))
  .toDF()

orderDF.createOrReplaceTempView("t_order")

spark.sql("select userid,sum(price * count) cost from t_order group by userid").show()
+------+-----+
|userid| cost|
+------+-----+
|   001| 21.5|
|   002|800.0|
+------+-----+

Custom aggregate function

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MySumAggregateFunction extends UserDefinedAggregateFunction{
  /**
    * Declares the input field types; the names carry no meaning here, only the types matter
    * @return
    */
  override def inputSchema: StructType = {
    new StructType()
      .add("price",DoubleType)
      .add("count",IntegerType)
  }

  /**
    * Schema of the intermediate aggregation buffer
    * @return
    */
  override def bufferSchema: StructType = {
    new StructType()
      .add("cost",DoubleType)
  }
 /**
    * Data type of the final result
    * @return
    */
  override def dataType: DataType = DoubleType
  /**
    * Whether the function is deterministic; returning true is normally fine
    * @return
    */
  override def deterministic: Boolean = true

  /**
    * Initialize the aggregation buffer
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0.0
  }

  /**
    * Partial aggregation: fold one input row into the buffer
    * @param buffer
    * @param input
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // the input fields are positional: 0 = price, 1 = count (as declared in inputSchema)
    val cost = input.getDouble(0) * input.getInt(1)
    buffer(0)=buffer.getDouble(0)+cost
  }

  /**
    * Merge two partial aggregation buffers
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0)=buffer1.getDouble(0)+buffer2.getDouble(0)
  }

  /**
    * Compute the final result from the buffer
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = {
      buffer.getDouble(0)
  }
}
case class OrderLog(price: Double, count: Int,userid:String)

spark.udf.register("customsum",new MySumAggregateFunction())
var orderDF=  spark.sparkContext.textFile("D:/order.log")
.map(_.split(","))
.map(arr=> OrderLog(arr(2).toDouble,arr(3).toInt,arr(4)))
.toDF()

orderDF.createOrReplaceTempView("t_order")
spark.sql("select userid,customsum(price, count) cost from t_order group by userid").show()
+------+-----+
|userid| cost|
+------+-----+
|   001| 21.5|
|   002|800.0|
+------+-----+

Typed aggregate functions

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class. For example, a type-safe user-defined average might look like this:

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Average(var sum: Double, var count: Int)
object MyAggregator extends Aggregator[GenericRowWithSchema,Average,Double]{
  override def zero: Average = {
    Average(0,0)
  }

  override def reduce(b: Average, a: GenericRowWithSchema): Average = {
    val total=b.sum  + a.getAs[Int]("count")*a.getAs[Double]("price")
    val count= b.count + 1
    Average(total,count)
  }

  override def merge(b1: Average, b2: Average): Average = {
    b1.sum=b1.sum+b2.sum
    b1.count=b1.count+b2.count
    b1
  }

  override def finish(reduction: Average): Double = {
    reduction.sum/reduction.count
  }

  override def bufferEncoder: Encoder[Average] = {
    Encoders.product[Average]
  }

  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}
case class OrderLog(price: Double, count: Int,userid:String)

var orderDS=  spark.sparkContext.textFile("D:/order.log")
.map(_.split(","))
.map(arr=> OrderLog(arr(2).toDouble,arr(3).toInt,arr(4)))
.toDS()

val agg = MyAggregator.toColumn.name("avgSalary") // MyAggregator is an object, so no `new`

orderDS.groupBy("userid").agg(agg).show()
+------+---------+
|userid|avgSalary|
+------+---------+
|   001|    10.75|
|   002|    800.0|
+------+---------+
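
A more idiomatic typed variant (a sketch, not the original code) keeps OrderLog itself as the input type and uses groupByKey, so the Row-based GenericRowWithSchema input is not needed:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object AvgCost extends Aggregator[OrderLog, Average, Double] {
  override def zero: Average = Average(0, 0)
  override def reduce(b: Average, a: OrderLog): Average =
    Average(b.sum + a.price * a.count, b.count + 1)
  override def merge(b1: Average, b2: Average): Average =
    Average(b1.sum + b2.sum, b1.count + b2.count)
  override def finish(r: Average): Double = r.sum / r.count
  override def bufferEncoder: Encoder[Average] = Encoders.product[Average]
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

orderDS.groupByKey(_.userid).agg(AvgCost.toColumn.name("avgCost")).show()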

Load & save 函数

  • Load data from MySQL
// access a traditional relational database (the MySQL JDBC driver must be on the classpath)
val jdbcDF = spark.read
                  .format("jdbc")
                  .option("url", "jdbc:mysql://CentOS:3306/test")
                  .option("dbtable", "t_user")
                  .option("user", "root")
                  .option("password", "root")
                  .load()

jdbcDF.createTempView("t_user")
val dataFrame: DataFrame = spark.sql("select * from t_user")
  • Read CSV data
val frame: DataFrame = spark.read
			.option("inferSchema","true")
			.option("header", "true").csv("d:/csv")
    frame.show()
  • Write results to MySQL
// write the results to MySQL
val personRDD = spark.sparkContext.parallelize(Array("14 tom 1500", "15 jerry 20000", "16 kitty 26000")).map(_.split(" "))
// specify the schema of each field directly with StructType
val schema = StructType(
    List(
        StructField("id",IntegerType,true),
        StructField("name",StringType,true),
        StructField("salary",DoubleType,true)
    )
)
val rowRDD = personRDD.map(p => Row(p(0).toInt, p(1).trim, p(2).toDouble))
val personDataFrame = spark.createDataFrame(rowRDD,schema)
import java.util.Properties
val prop = new Properties()
prop.put("user", "root")
prop.put("password", "root")
// append the data to the t_user table
personDataFrame.write.mode("append").jdbc("jdbc:mysql://CentOS:3306/test",
                                          "t_user",prop)
  • Save data as JSON
val frame: DataFrame = spark.sparkContext.textFile("D:/order.log")
      .map(_.split(","))
      .map(x => (x(0).toInt, x(1), x(2).toDouble, x(3).toInt, x(4)))
      .map(x=>(x._5,x._3*x._4))
      .groupByKey()
      .map(x=>(x._1,x._2.sum))
      .toDF("uid", "total")
 frame.write.format("json").save("D:/a")
  • Save data as CSV
val frame: DataFrame = spark.sparkContext.textFile("D:/order.log")
      .map(_.split(","))
      .map(x => (x(0).toInt, x(1), x(2).toDouble, x(3).toInt, x(4)))
      .map(x=>(x._5,x._3*x._4))
      .groupByKey()
      .map(x=>(x._1,x._2.sum))
      .toDF("uid", "total")
frame.write.format("csv").option("header", "true").save("D:/csv")
  • Write Parquet files (this format stores the schema together with the data)
val frame: DataFrame = spark.sparkContext.textFile("D:/order.log")
      .map(_.split(","))
      .map(x => (x(0).toInt, x(1), x(2).toDouble, x(3).toInt, x(4)))
      .toDF("id", "name","price","count","uid")
 frame.write.parquet("D:/order.parquet")
  • Read Parquet files
val f: DataFrame = spark.read.parquet("D:/order.parquet")
f.createOrReplaceTempView("t_order")
val resultFrame = spark.sql("select * from t_order")
resultFrame.map(row => row.getAs[String]("name")).foreach(x => println(x))
// SQL can also be run directly against the Parquet file
val sqlDF = spark.sql("SELECT count(*) FROM parquet.`D:/order.parquet`")
sqlDF.show()
spark.stop()
  • Partitioned storage
val frame: DataFrame = spark.sparkContext.textFile("D:/order.log")
.map(_.split(","))
.map(x => (x(0).toInt, x(1), x(2).toDouble, x(3).toInt, x(4)))
.toDF("id", "name","price","count","uid")
frame.write.format("json").mode(SaveMode.Overwrite).partitionBy("uid").save("D:/res")
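
Reading the partitioned output back (a sketch): Spark's partition discovery reconstructs the uid column from the uid=... directory names:

val reloaded = spark.read.json("D:/res")
reloaded.printSchema() // includes uid, inferred from the partition directories
reloaded.show()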