import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, MapType, StringType, StructField, StructType}
import java.text.DecimalFormat
/**
 * Top-N analysis: for every area, list the 3 most-clicked products together
 * with a per-city click-share remark produced by [[CityClickUDAF]].
 */
object SparkSQL10_TopN {
  def main(args: Array[String]): Unit = {
    // Build the SparkSession; Hive support is needed for the warehouse tables.
    val session: SparkSession = SparkSession
      .builder()
      .enableHiveSupport()
      .master("local[*]")
      .appName("")
      .getOrCreate()

    // Work inside Hive's default database.
    session.sql("use default")

    // Register the custom aggregate that renders the city-share string.
    session.udf.register("remark", new CityClickUDAF)

    // Stage 1: per (area, product) click counts, city remark and in-area rank.
    val ranked = session.sql(
      """
        |select ci.area,
        | pi.product_name,
        | count(if(uva.click_product_id != -1, 1, null)) as sumClick,
        | remark(ci.city_name) as c_remark,
        | row_number() over (partition by ci.area order by count(if(uva.click_product_id != -1, 1, null)) desc ) as ranking
        | from user_visit_action uva
        | join product_info pi on uva.click_product_id = pi.product_id
        | join city_info ci on uva.city_id = ci.city_id
        | group by ci.area, pi.product_name
        |""".stripMargin)
    ranked.createOrReplaceTempView("t1")

    // Stage 2: keep only the top 3 products per area and print them in full.
    session.sql(
      """
        |select t1.area,
        | t1.product_name,
        | t1.sumClick,
        | t1.c_remark,
        | t1.ranking
        |from t1
        |where t1.ranking <= 3
        |""".stripMargin).show(false)

    // Release cluster resources.
    session.stop()
  }
}
/**
 * Custom UDAF that counts clicks per city and renders the result as a
 * share string, e.g. "Beijing:21.20%,Tianjin:13.20%,其它:65.60%".
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
 * in favour of org.apache.spark.sql.expressions.Aggregator; kept as-is here
 * because callers register it via spark.udf.register("remark", ...).
 */
class CityClickUDAF extends UserDefinedAggregateFunction {

  // Input: a single city-name column.
  override def inputSchema: StructType =
    StructType(Array(StructField("city_name", StringType)))

  // Buffer: per-city click counts plus the running total,
  // e.g. {Beijing -> 2, Tianjin -> 3} with total_count = 5.
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("city_count", MapType(StringType, LongType)),
      StructField("total_count", LongType)))

  // Output: the formatted share string.
  override def dataType: DataType = StringType

  // The aggregation IS deterministic: evaluate() sorts with an explicit
  // tie-break on the city name, so identical input always yields identical
  // output. (Was `false`, which needlessly disabled Spark optimisations.)
  override def deterministic: Boolean = true

  // Start with an empty city map and a zero total.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
    buffer(1) = 0L
  }

  // Fold one input row into the buffer: bump that city's count and the total.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val cityName = input.getAs[String](0)
    val cityCounts = buffer.getAs[Map[String, Long]](0)
    // Immutable map: rebuild with the incremented entry, then store it back.
    buffer(0) = cityCounts + (cityName -> (cityCounts.getOrElse(cityName, 0L) + 1L))
    buffer(1) = buffer.getAs[Long](1) + 1L
  }

  // Merge two partition-local buffers: add the city maps and the totals.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getAs[Map[String, Long]](0)
    val right = buffer2.getAs[Map[String, Long]](0)
    // (The original fold shadowed the outer map name — renamed for clarity.)
    buffer1(0) = left.foldLeft(right) { case (acc, (city, cnt)) =>
      acc + (city -> (acc.getOrElse(city, 0L) + cnt))
    }
    buffer1(1) = buffer1.getAs[Long](1) + buffer2.getAs[Long](1)
  }

  // Render the final string: top-2 cities by clicks plus an "其它" bucket.
  override def evaluate(buffer: Row): Any = {
    val cityCounts = buffer.getAs[Map[String, Long]](0)
    val totalCount = buffer.getAs[Long](1)
    // Guard: an empty group would otherwise divide by zero (NaN ratios).
    if (totalCount == 0L) "" else {
      // Top 2 cities by click count; ties broken by city name so the
      // output ordering is stable across runs.
      val topRemarks = cityCounts.toList
        .sortBy { case (city, cnt) => (-cnt, city) }
        .take(2)
        .map { case (city, cnt) => CityRemark(city, cnt.toDouble / totalCount) }
      // More than 2 cities: the remainder collapses into one "其它" entry
      // whose share is whatever the top 2 do not cover.
      val remarks =
        if (cityCounts.size > 2)
          topRemarks :+ CityRemark("其它", topRemarks.foldLeft(1D)(_ - _.cityRatio))
        else topRemarks
      remarks.mkString(",")
    }
  }
}
/**
 * One city's share of clicks, rendered as "name:xx.xx%".
 *
 * @param cityName  display name of the city (or the catch-all bucket)
 * @param cityRatio fraction of total clicks in [0, 1]
 */
case class CityRemark(cityName: String, cityRatio: Double) {
  // DecimalFormat pattern "0.00%" multiplies by 100 and appends '%'.
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = cityName + ":" + formatter.format(cityRatio)
}