Today I'd like to share how to read and write HBase from Spark SQL through the standard data source API, i.e. via

spark.write.format("hbase").save()

spark.read.format("hbase").load()

1. Add the Maven dependencies

Only the Spark SQL dependency and the hbase-mapreduce artifact are needed; that is enough to implement Spark SQL reads and writes against HBase.

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.11</artifactId>
    <scope>provided</scope>
    <version>2.4.0</version>
</dependency>
<!-- hbase -->
<dependency>
    <groupId>org.apache.hbase</groupId>
    <artifactId>hbase-mapreduce</artifactId>
    <version>2.1.0</version>
</dependency>

2. Create the package and classes as follows


2.1 Scala package object

This tells the Spark framework that the code under this package is our Spark source and sink: at runtime, when a user calls spark.read.format("hbase").load(), Spark locates the code under this package to perform the read or write.

package com.spark
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter}

/**
 * Custom HBase source and sink for Spark SQL
 * @author ket
 * @since 2021-11-23
 */
package object hbase {
  // spark.read.hbase  ==  spark.read.format("com.spark.hbase").load()
  implicit class HBaseDataFrameReader(reader: DataFrameReader) {
    def hbase: DataFrame = reader.format("com.spark.hbase").load()
  }

  // df.write.hbase  ==  df.write.format("com.spark.hbase").save()
  implicit class HBaseDataFrameWriter[T](writer: DataFrameWriter[T]) {
    def hbase: Unit = writer.format("com.spark.hbase").save()
  }
}

Replace com.spark.hbase in the code above with your own package path if it differs.
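To show how these implicits are meant to be used: once import com.spark.hbase._ is in scope, .hbase replaces the explicit format(...).load() / format(...).save() calls. A minimal sketch (the option keys are the ones defined in section 2.3, and the table, family and ZooKeeper values are simply the ones reused in the verification section):

import com.spark.hbase._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("hbase implicit demo").getOrCreate()

// Equivalent to spark.read.format("com.spark.hbase").<options>.load()
val df = spark.read
  .option("hbase.table", "student")
  .option("hbase.family", "info")
  .option("hbase.zookeeper.quorum", "dev0,dev1,dev2")
  .option("hbase.zookeeper.property.clientPort", "2181")
  .option("hbase.select.family.column", "name,id")
  .hbase

// Equivalent to df.write.format("com.spark.hbase").<options>.save()
df.write
  .option("hbase.table", "student")
  .option("hbase.family", "info")
  .option("hbase.zookeeper.quorum", "dev0,dev1,dev2")
  .option("hbase.zookeeper.property.clientPort", "2181")
  .option("hbase.rowkey.column.name", "rowKey")
  .hbase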

2.2 SerializableConfiguration (a serializable configuration wrapper, used to hold the custom settings for reading and writing HBase from Spark SQL)

package com.spark.hbase

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration

import scala.util.control.NonFatal

class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
    out.defaultWriteObject()
    value.write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
    value = new Configuration(false)
    value.readFields(in)
  }

  def tryOrIOException(block: => Unit): Unit = {
    try {
      block
    } catch {
      case e: IOException => throw e
      case NonFatal(t) => throw new IOException(t)
    }
  }
}
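For context: Hadoop's Configuration is not Serializable, so it cannot be captured directly by closures that Spark ships to the executors; this wrapper lets the driver-side settings travel with each task. A minimal sketch of the pattern, assuming a local SparkSession (HBaseRelation in the next section uses it the same way):

import com.spark.hbase.SerializableConfiguration
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("serializable conf demo").getOrCreate()

val hConf = HBaseConfiguration.create()
hConf.set("hbase.zookeeper.quorum", "dev0,dev1,dev2")
val wrappedConf = new SerializableConfiguration(hConf)

// The closure captures the serializable wrapper, not the raw Configuration,
// so each task can rebuild the settings on the executor via .value.
spark.sparkContext.parallelize(1 to 2).foreachPartition { _ =>
  println(wrappedConf.value.get("hbase.zookeeper.quorum"))
}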

2.3 HBaseRelation (the actual HBase reads and writes are implemented in this class)

It extends the following three Spark SQL types: org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}.

We also define custom configuration keys for the HBase connection; the corresponding values can later be passed in via spark.read.format("hbase").option(key, value):

// ZooKeeper quorum of the HBase cluster
val HBASE_ZK_QUORUM_KEY: String = "hbase.zookeeper.quorum"
// ZooKeeper client port
val HBASE_ZK_PORT_KEY: String = "hbase.zookeeper.property.clientPort"
// Name of the HBase table to read from or write to
val HBASE_TABLE: String = "hbase.table"
// HBase column family to read from or write to
val HBASE_TABLE_FAMILY: String = "hbase.family"
// Separator used when several column names are listed
val SPLIT: String = ","
// Read-only option: the column names to read
val HBASE_TABLE_SELECT_FIELDS: String = "hbase.select.family.column"
// Write-only option: which DataFrame column to use as the rowkey
val HBASE_TABLE_ROWKEY_NAME: String = "hbase.rowkey.column.name"

Because Spark SQL data types and HBase data types do not map one-to-one, all values are uniformly converted to String for both reading and writing.
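Concretely, every cell value is stored as the bytes of its String form via HBase's Bytes utility, and callers re-cast on the Spark side when they need the original type. A tiny standalone illustration (the age value is just an example, not a column from the demo table):

import org.apache.hadoop.hbase.util.Bytes

// Write side: a numeric value is rendered as a String and stored as its UTF-8 bytes.
val cellBytes: Array[Byte] = Bytes.toBytes(18.toString)

// Read side: the bytes come back as a String; re-cast in Spark SQL if needed,
// e.g. df.select(col("age").cast("int")).
val age: Int = Bytes.toString(cellBytes).toInt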

The class code is shown below:

package com.spark.hbase

import com.spark.hbase.HBaseRelation._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.{Put, Result, Scan}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{TableInputFormat, TableOutputFormat}
import org.apache.hadoop.hbase.protobuf.ProtobufUtil
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructType}

import java.util.Base64

/**
 * @author Ket
 * @Date 2021/11/22
 */
object HBaseRelation {
  // ZooKeeper quorum of the HBase cluster
  val HBASE_ZK_QUORUM_KEY: String = "hbase.zookeeper.quorum"
  // ZooKeeper client port
  val HBASE_ZK_PORT_KEY: String = "hbase.zookeeper.property.clientPort"
  // Name of the HBase table to read from or write to
  val HBASE_TABLE: String = "hbase.table"
  // HBase column family to read from or write to
  val HBASE_TABLE_FAMILY: String = "hbase.family"
  // Separator used when several column names are listed
  val SPLIT: String = ","
  // Read-only option: the column names to read
  val HBASE_TABLE_SELECT_FIELDS: String = "hbase.select.family.column"
  // Write-only option: which DataFrame column to use as the rowkey
  val HBASE_TABLE_ROWKEY_NAME: String = "hbase.rowkey.column.name"
}

case class HBaseRelation(
                          params: Map[String, String],
                          override val schema: StructType)
                        (@transient val sparkSession: SparkSession)
  extends BaseRelation
    with TableScan
    with InsertableRelation
    with Serializable {
  override def sqlContext: SQLContext = sparkSession.sqlContext


  private val wrappedConf = {
    val hConf = {
      val conf = HBaseConfiguration.create
      val hbaseParams = params.filterKeys(_.contains("hbase"))
      hbaseParams.foreach(f => conf.set(f._1, f._2))
      conf
    }
    new SerializableConfiguration(hConf)
  }

  def hbaseConf: Configuration = wrappedConf.value

  /**
   * Read the data from HBase and return it as an RDD[Row]
   * @return one Row per HBase row: the rowkey followed by the selected column values
   */
  override def buildScan(): RDD[Row] = {
    hbaseConf.set(TableInputFormat.INPUT_TABLE, params(HBASE_TABLE))
    val scan: Scan = new Scan()

    val cfBytes = Bytes.toBytes(params(HBASE_TABLE_FAMILY))
    scan.addFamily(cfBytes)

    val fields = params(HBASE_TABLE_SELECT_FIELDS).split(SPLIT)
    fields.foreach { field =>
      scan.addColumn(cfBytes, Bytes.toBytes(field))
    }

    // Serialize the Scan and pass it to TableInputFormat through the job configuration
    val scanToString = new String(Base64.getEncoder.encode(ProtobufUtil.toScan(scan).toByteArray))
    hbaseConf.set(TableInputFormat.SCAN, scanToString)

    val dataRdd: RDD[(ImmutableBytesWritable, Result)] = sqlContext.sparkContext
      .newAPIHadoopRDD(
        hbaseConf,
        classOf[TableInputFormat],
        classOf[ImmutableBytesWritable],
        classOf[Result]
      )

    val rowsRdd: RDD[Row] = dataRdd.mapPartitions { iter =>
      iter.map { case (_, result) =>
        val rowKey = Bytes.toString(result.getRow)
        val values: Seq[String] = fields.map { field =>
          val fieldBytes: Array[Byte] = result.getValue(cfBytes, Bytes.toBytes(field))
          Bytes.toString(fieldBytes)
        }
        val data = rowKey +: values
        Row.fromSeq(data)
      }
    }
    rowsRdd
  }

  /**
   * Write the DataFrame into HBase
   */
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val columns = data.columns
    val dataCols = columns.filter(!_.equals(params(HBASE_TABLE_ROWKEY_NAME)))
    val cfBytes = Bytes.toBytes(params(HBASE_TABLE_FAMILY))

    val stringTypeCols = columns.map(m => col(m).cast(StringType).as(m))
    val putsRdd: RDD[(ImmutableBytesWritable, Put)] = data.select(stringTypeCols: _*)
      .rdd.mapPartitions(iter => {
      iter.map(row => {
        val rowKeyValue = row.getAs[String](params(HBASE_TABLE_ROWKEY_NAME))
        val rowKey = new ImmutableBytesWritable(Bytes.toBytes(rowKeyValue))

        val put = new Put(rowKey.get())
        dataCols.foreach(column => {
          val columnValue = row.getAs[String](column)
          // only write non-null, non-empty values
          if (columnValue != null && columnValue.nonEmpty) {
            put.addColumn(cfBytes, Bytes.toBytes(column), Bytes.toBytes(columnValue))
          }
        })
        (rowKey, put)
      })
    })

    hbaseConf.set(TableOutputFormat.OUTPUT_TABLE, params(HBASE_TABLE))

    val job = Job.getInstance(hbaseConf)
    job.setOutputKeyClass(classOf[ImmutableBytesWritable])
    job.setOutputValueClass(classOf[Put])
    job.setOutputFormatClass(classOf[TableOutputFormat[ImmutableBytesWritable]])

    putsRdd.saveAsNewAPIHadoopDataset(job.getConfiguration)
  }
}

2.4 DefaultSource (default class name, must not be changed)

Spark SQL loads the data source for reads and writes by reflection: it takes the Scala package we specify in format(...) and looks for a class named DefaultSource underneath it, and that class name is hard-coded in the Spark source. So this class name must not be changed.

The class needs to implement the following three interfaces: org.apache.spark.sql.sources.{CreatableRelationProvider, DataSourceRegister, RelationProvider}.

After mixing them in, overriding the following method assigns the data source a short name that spark.read.format("hbase") can be matched against; here it is the same as the package's last segment.

override def shortName(): String = "hbase"
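One packaging caveat (an assumption about the build, not something shown above): Spark only resolves the short name "hbase" if DefaultSource is registered for Java's ServiceLoader, i.e. if the jar ships a resource file META-INF/services/org.apache.spark.sql.sources.DataSourceRegister containing the fully qualified class name, for example:

src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:

com.spark.hbase.DefaultSource

Without that file, simply pass the full package name, format("com.spark.hbase"), which is exactly what the verification code in section 3 does.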
The complete class code is as follows:

package com.spark.hbase

import com.spark.hbase.HBaseRelation.{HBASE_TABLE_SELECT_FIELDS, SPLIT}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
 * @author Ket
 * @Date 2021/11/23
 */
class DefaultSource extends RelationProvider
  with CreatableRelationProvider
  with DataSourceRegister
  with Serializable {

  /**
   * Reads ultimately go through this method
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    // The schema of the RDD[Row] that HBaseRelation produces from the scan is defined here
    val fields = parameters(HBASE_TABLE_SELECT_FIELDS)
      .split(SPLIT)
      .map { field =>
        StructField(field, StringType, nullable = true)
      }
    val schema: StructType = StructType(
      StructField("rowKey", StringType, nullable = false) +: fields)
    HBaseRelation(parameters, schema)(sqlContext.sparkSession)
  }

  /**
   * Writes ultimately go through this method
   */
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val relation = HBaseRelation(parameters, data.schema)(sqlContext.sparkSession)

    relation.insert(data, overwrite = false)
    relation
  }

  override def shortName(): String = "hbase"
}

Our custom Spark SQL data source is now finished; let's verify it.

3. Verification

3.1 Verify reading

The HBase table student contains the following data (original screenshot omitted).


Read the columns name and id; the code is shown below:

import com.spark.hbase.HBaseRelation
import org.apache.spark.sql.SparkSession

object Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("test hbase")
      .getOrCreate()

    val rawDF = spark.read.format("com.spark.hbase")
      .option(HBaseRelation.HBASE_TABLE, "student")
      .option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
      .option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
      .option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
      .option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
      .load()

    rawDF.show()
    spark.stop()
  }
}

The DataFrame is printed to the IDEA console (screenshot omitted).


Read verification complete.

3.2 Verify writing

For the write test, we take the rows we just read, append yyds to each rowKey, and write them back with that new column (named yyds) specified as the rowkey column.

The code:

import com.spark.hbase.HBaseRelation
import org.apache.spark.sql.SparkSession

object Test {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("test hbase")
      .getOrCreate()

    val rawDF = spark.read.format("com.spark.hbase")
      .option(HBaseRelation.HBASE_TABLE, "student")
      .option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
      .option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
      .option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
      .option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
      .load()

    import spark.implicits._
    import org.apache.spark.sql.functions._
    rawDF.select(concat($"rowKey", lit("yyds")).as("yyds"), $"name", $"id")
      .write.format("com.spark.hbase")
      .option(HBaseRelation.HBASE_TABLE, "student")
      .option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
      .option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
      .option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
      .option(HBaseRelation.HBASE_TABLE_ROWKEY_NAME, "yyds")
      .save()

    spark.stop()
  }
}

Check the row with rowkey 20200722yyds in HBase (screenshot omitted).


As you can see, the data was written successfully!
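If you would rather verify from Spark than from the HBase shell, one option is a sketch like the following, which reads the table back with the same options as in section 3.1 and filters on the appended suffix (checkDF is just an illustrative name):

import org.apache.spark.sql.functions.col

val checkDF = spark.read.format("com.spark.hbase")
  .option(HBaseRelation.HBASE_TABLE, "student")
  .option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
  .option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
  .option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
  .option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
  .load()

// Only the rows written back in 3.2 have rowkeys ending in "yyds".
checkDF.filter(col("rowKey").endsWith("yyds")).show()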

With that, the custom Spark SQL implementation of HBase reads and writes is complete. Give it a try; it really is a pleasure to use~