SparkSQL实战小项目之热门商品top3

  • 一、说明及需求分析
  • 二、准备测试数据
  • 三、思路分析
  • 四、编码实现
  • 五、验证结果


一、说明及需求分析


  • 软件及环境centos7 + hive-2.3.3 + spark2.4.8 + idea
  • 需求分析
要求实现电商平台上各区域热门商品 Top3,这里的热门商品是从点击量的维度来看的。计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。要实现的效果如下图所示:

mysql查询价格商品销量总和前三_mysql查询价格商品销量总和前三

二、准备测试数据


本次案例的测试数据涉及到三张表的数据,即1 张用户行为表, 1 张城市表, 1 张产品表,三张表均需要在Hive中创建,并将其数据load到对应的各表中,测试数据

  • 说明,在Hive中创建表,目前有两种方式:
  • 一种是可以直接执行hive,然后在CLI中进行创建
  • 一种是利用sparksql的cli来创建,即hiveserver2 + beeline的方式创建(本案例使用这种
  • 前提
    a)先需要启动hiverserver2服务,在spark安装路径下执行类似如下的命令:
sbin/start-thriftserver.sh --master spark://niit01:7077  --hiveconf hive.server2.thrift.bind.host=niit01 --hiveconf hive.server2.thrift.port=10000
  • b)启动beeline,在spark安装路径下,执行如下类似的命令:
bin/beeline
  • 然后,再输入链接:
!connect jdbc:hive2://niit01:10000
  • 建表:在beeline命令行下输入创建表语句即可:
  • 用户行为表user_visit_action,建表语句如下:
CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';

截图所示仅供参考:

mysql查询价格商品销量总和前三_hive_02

  • 城市表city_info,建表语句如下:
CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';
  • 产品信息表product_info,建表语句如下:
CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';
  • 导入数据:将测试数据(上述有地址)先上传至虚拟机中
  • 将数据导入user_visit_action表,执行如下命令:
load data local inpath '/root/testdatas/user_visit_action.txt' into table user_visit_action;
  • 将数据导入city_info表,执行如下命令:
load data local inpath '/root/testdatas/city_info.txt' into table city_info;
  • 将数据导入product_info表,执行如下命令:
load data local inpath '/root/testdatas/product_info.txt' into table product_info;

三、思路分析


  • 查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区与 Product_info 表连接得到产品名称
  • 按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
  • 每个地区内按照点击次数降序排列
  • 只取前三名. 并把结果保存在数据库中
  • 城市备注需要自定义 UDAF 函数

四、编码实现


  • 创建自定义函数udf:AreaClickUDAF
package com.niit.spark.sql

import java.text.DecimalFormat
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class AreaClickUDAF extends UserDefinedAggregateFunction {
  // 输入数据的类型:  北京  String
  override def inputSchema: StructType = {
    StructType(StructField("city_name", StringType) :: Nil)
    //        StructType(Array(StructField("city_name", StringType)))
  }

  // 缓存的数据的类型: 北京->1000, 天津->5000  Map,  总的点击量  1000/?
  override def bufferSchema: StructType = {
    // MapType(StringType, LongType) 还需要标注 map的key的类型和value的类型
    StructType(StructField("city_count", MapType(StringType, LongType)) :: StructField("total_count", LongType) :: Nil)
  }

  // 输出的数据类型  "北京21.2%,天津13.2%,其他65.6%"  String
  override def dataType: DataType = StringType

  // 相同的输入是否应用有相同的输出.
  override def deterministic: Boolean = true

  // 给存储数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map[String, Long]()
    // 初始化总的点击量
    buffer(1) = 0L
  }

  // 分区内合并 Map[城市名, 点击量]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 首先拿到城市名, 然后把成名作为key去查看map中是否存在, 如果存在就把对应的值 +1, 如果不存在, 则直径0+1
    val cityName = input.getString(0)
    //        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val map: Map[String, Long] = buffer.getAs[Map[String, Long]](0)
    buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
    // 碰到一个城市, 则总的点击量要+1
    buffer(1) = buffer.getLong(1) + 1L
  }

  // 分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    // 把map1的键值对与map2中的累积, 最后赋值给buffer1
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (k, v)) =>
        map + (k -> (map.getOrElse(k, 0L) + v))
    }

    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 最终的输出. "北京21.2%,天津13.2%,其他65.6%"
  override def evaluate(buffer: Row): Any = {
    val cityCountMap = buffer.getAs[Map[String, Long]](0)
    val totalCount = buffer.getLong(1)

    var citysRatio: List[CityRemark] = cityCountMap.toList.sortBy(-_._2).take(2).map {
      case (cityName, count) => {
        CityRemark(cityName, count.toDouble / totalCount)
      }
    }
    // 如果城市的个数超过2才显示其他
    if (cityCountMap.size > 2) {
      citysRatio = citysRatio :+ CityRemark("其他", citysRatio.foldLeft(1D)(_ - _.cityRatio))
    }
    citysRatio.mkString(", ")
  }
}


case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}
  • 创建object实现业务:AreaClickApp
import org.apache.spark.sql.SparkSession

object AreaClickApp {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local[2]")
      .appName("AreaClickApp")
      .enableHiveSupport()
      .getOrCreate()
    spark.sql("use default")
    // 0 注册自定义聚合函数
    spark.udf.register("city_remark", new AreaClickUDAF)
    // 1. 查询出所有的点击记录,并和城市表产品表做内连接
    spark.sql(
      """
        |select
        |    c.*,
        |    v.click_product_id,
        |    p.product_name
        |from user_visit_action v join city_info c join product_info p on v.city_id=c.city_id and v.click_product_id=p.product_id
        |where click_product_id>-1
            """.stripMargin).createOrReplaceTempView("t1")

    // 2. 计算每个区域, 每个产品的点击量
    spark.sql(
      """
        |select
        |    t1.area,
        |    t1.product_name,
        |    count(*) click_count,
        |    city_remark(t1.city_name)
        |from t1
        |group by t1.area, t1.product_name
            """.stripMargin).createOrReplaceTempView("t2")

    // 3. 对每个区域内产品的点击量进行倒序排列
    spark.sql(
      """
        |select
        |    *,
        |    rank() over(partition by t2.area order by t2.click_count desc) rank
        |from t2
            """.stripMargin).createOrReplaceTempView("t3")

    // 4. 每个区域取top3
    spark.sql(
      """
        |select
        |    *
        |from t3
        |where rank<=3
            """.stripMargin).show
  }
}

五、验证结果


  • 说明本次的项目参考了网络上其他资料,收集整理出来,供诸君参考!