Today I'd like to share a way to read and write HBase from Spark SQL using
spark.write.format("hbase").save()
spark.read.format("hbase").load()
1. Add the Maven dependencies
Only the Spark SQL dependency and the hbase-mapreduce package are needed to give Spark SQL the ability to read and write HBase.
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<scope>provided</scope>
<version>2.4.0</version>
</dependency>
<!--hbase-->
<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase-mapreduce</artifactId>
<version>2.1.0</version>
</dependency>2、创建包和类如下

2.1. Scala package object
This package object tells Spark that the code under this package is our source and sink: at runtime, when the user calls spark.read.format("hbase").load(), Spark locates the code under this package to execute the read or write.
package com.spark
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter}
/**
* Custom HBase source and sink for Spark SQL
* @author ket
* @since 2021-11-23
*/
package object hbase {
implicit class HBaseDataFrameReader(reader: DataFrameReader) {
def hbase: DataFrame = reader.format("com.spark.hbase").load
}
implicit class HBaseDataFrameWriter[T](writer: DataFrameWriter[T]) {
def hbase:Unit = writer.format("com.spark.hbase").save
}
}

In the code above, replace com.spark.hbase with your own package path.
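With these implicits in scope, reads and writes can be expressed a bit more tersely once the rest of the classes in this section are in place. A minimal sketch (ImplicitSyntaxDemo is a hypothetical demo object; the table, family, and ZooKeeper settings reuse the ones from the verification section below, and the option keys come from the HBaseRelation companion object defined in 2.3):

import com.spark.hbase._                 // brings the implicit .hbase helpers into scope
import com.spark.hbase.HBaseRelation._   // the option keys defined in section 2.3
import org.apache.spark.sql.SparkSession

object ImplicitSyntaxDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("implicit hbase syntax").getOrCreate()
    // read: equivalent to spark.read.format("com.spark.hbase")...load()
    val df = spark.read
      .option(HBASE_TABLE, "student")
      .option(HBASE_TABLE_FAMILY, "info")
      .option(HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
      .option(HBASE_ZK_PORT_KEY, "2181")
      .option(HBASE_TABLE_SELECT_FIELDS, "name,id")
      .hbase
    // write: equivalent to ...format("com.spark.hbase").save(), using the rowKey column produced by the reader
    df.write
      .option(HBASE_TABLE, "student")
      .option(HBASE_TABLE_FAMILY, "info")
      .option(HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
      .option(HBASE_ZK_PORT_KEY, "2181")
      .option(HBASE_TABLE_ROWKEY_NAME, "rowKey")
      .hbase
    spark.stop()
  }
}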
2.2. SerializableConfiguration (a serializable configuration wrapper, used to carry the custom settings for reading and writing HBase from Spark SQL)
package com.spark.hbase
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.conf.Configuration
import scala.util.control.NonFatal
class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
out.defaultWriteObject()
value.write(out)
}
private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
value = new Configuration(false)
value.readFields(in)
}
def tryOrIOException(block: => Unit): Unit = {
try {
block
} catch {
case e: IOException => throw e
case NonFatal(t) => throw new IOException(t)
}
}
}
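Why the wrapper is needed: Hadoop's Configuration is not Serializable, so it cannot be captured directly by a closure that Spark ships to executors; wrapping it lets the relation keep the configuration as a field. A minimal, self-contained sketch of the idea (SerializableConfDemo is a hypothetical demo object; HBaseRelation below uses the wrapper the same way through its wrappedConf field):

import com.spark.hbase.SerializableConfiguration
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.spark.sql.SparkSession

object SerializableConfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("serializable conf demo").getOrCreate()
    // Configuration itself is not Serializable, so wrap it before the closure captures it
    val wrapped = new SerializableConfiguration(HBaseConfiguration.create())
    spark.sparkContext.parallelize(1 to 2, 2).foreachPartition { _ =>
      // on each executor the Configuration is rebuilt by readObject
      val conf: Configuration = wrapped.value
      println(conf.get("hbase.zookeeper.property.clientPort"))
    }
    spark.stop()
  }
}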
2.3. HBaseRelation (the actual reads from and writes to HBase are implemented in this class)

It extends these three Spark SQL types: org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}.
First we define custom configuration keys for the HBase connection; the corresponding values can later be passed in through spark.read.format("hbase").option(...).
// ZooKeeper quorum of the HBase cluster
val HBASE_ZK_QUORUM_KEY: String = "hbase.zookeeper.quorum"
// ZooKeeper client port
val HBASE_ZK_PORT_KEY: String = "hbase.zookeeper.property.clientPort"
// name of the HBase table to read from or write to
val HBASE_TABLE: String = "hbase.table"
// HBase column family to read from or write to
val HBASE_TABLE_FAMILY: String = "hbase.family"
// separator used when specifying multiple column names
val SPLIT: String = ","
// read-only option: the column names to read
val HBASE_TABLE_SELECT_FIELDS: String = "hbase.select.family.column"
// write-only option: which DataFrame column to use as the rowkey
val HBASE_TABLE_ROWKEY_NAME: String = "hbase.rowkey.column.name"

Because Spark SQL data types and HBase data types do not map one-to-one, all values are read and written uniformly as String.
The full class is shown below:
package com.spark.hbase
import com.spark.hbase.HBaseRelation._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.{Put, Result, Scan}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{TableInputFormat, TableOutputFormat}
import org.apache.hadoop.hbase.protobuf.ProtobufUtil
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructType}
import java.util.Base64
/**
* @author Ket
* @Date 2021/11/22
*/
object HBaseRelation {
// ZooKeeper quorum of the HBase cluster
val HBASE_ZK_QUORUM_KEY: String = "hbase.zookeeper.quorum"
// ZooKeeper client port
val HBASE_ZK_PORT_KEY: String = "hbase.zookeeper.property.clientPort"
// name of the HBase table to read from or write to
val HBASE_TABLE: String = "hbase.table"
// HBase column family to read from or write to
val HBASE_TABLE_FAMILY: String = "hbase.family"
// separator used when specifying multiple column names
val SPLIT: String = ","
// read-only option: the column names to read
val HBASE_TABLE_SELECT_FIELDS: String = "hbase.select.family.column"
// write-only option: which DataFrame column to use as the rowkey
val HBASE_TABLE_ROWKEY_NAME: String = "hbase.rowkey.column.name"
}
case class HBaseRelation(
params: Map[String, String],
override val schema: StructType)
(@transient val sparkSession: SparkSession)
extends BaseRelation
with TableScan
with InsertableRelation
with Serializable {
override def sqlContext: SQLContext = sparkSession.sqlContext
private val wrappedConf = {
val hConf = {
val conf = HBaseConfiguration.create
val hbaseParams = params.filterKeys(_.contains("hbase"))
hbaseParams.foreach(f => conf.set(f._1, f._2))
conf
}
new SerializableConfiguration(hConf)
}
def hbaseConf: Configuration = wrappedConf.value
/**
* Read data from HBase and return an RDD[Row]
* @return
*/
override def buildScan(): RDD[Row] = {
hbaseConf.set(TableInputFormat.INPUT_TABLE, params(HBASE_TABLE))
val scan: Scan = new Scan()
val cfBytes = Bytes.toBytes(params(HBASE_TABLE_FAMILY))
scan.addFamily(cfBytes)
val fields = params(HBASE_TABLE_SELECT_FIELDS).split(SPLIT)
fields.foreach { field =>
scan.addColumn(cfBytes, Bytes.toBytes(field))
}
val scanToString = new String(Base64.getEncoder.encode(ProtobufUtil.toScan(scan).toByteArray))
hbaseConf.set(
TableInputFormat.SCAN,
scanToString
)
val dataRdd: RDD[(ImmutableBytesWritable, Result)] = sqlContext.sparkContext
.newAPIHadoopRDD(
hbaseConf,
classOf[TableInputFormat],
classOf[ImmutableBytesWritable],
classOf[Result]
)
val rowsRdd: RDD[Row] = dataRdd.mapPartitions { iter =>
iter.map { case (_, result) =>
val rowKey = Bytes.toString(result.getRow)
val values: Seq[String] = fields.map { field =>
val fieldBytes: Array[Byte] = result.getValue(cfBytes, Bytes.toBytes(field))
val fieldValue: String = Bytes.toString(fieldBytes)
fieldValue
}
val data = rowKey +: values
Row.fromSeq(data)
}
}
rowsRdd
}
/**
* Write the DataFrame into HBase
*/
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
val columns = data.columns
val dataCols = columns.filter(!_.equals(params(HBASE_TABLE_ROWKEY_NAME)))
val cfBytes = Bytes.toBytes(params(HBASE_TABLE_FAMILY))
val stringTypeCols = columns.map(m => col(m).cast(StringType).as(m))
val putsRdd: RDD[(ImmutableBytesWritable, Put)] = data.select(stringTypeCols: _*)
.rdd.mapPartitions(iter => {
iter.map(row => {
val rowKeyValue = row.getAs[String](params(HBASE_TABLE_ROWKEY_NAME))
val rowKey = new ImmutableBytesWritable(Bytes.toBytes(rowKeyValue))
val put = new Put(rowKey.get())
dataCols.foreach(column => {
val columnValue = row.getAs[String](column)
if (columnValue != null && columnValue.nonEmpty) { // only write non-empty values
put.addColumn(cfBytes, Bytes.toBytes(column), Bytes.toBytes(columnValue))
}
})
(rowKey, put)
})
})
hbaseConf.set(TableOutputFormat.OUTPUT_TABLE, params(HBASE_TABLE))
val job = Job.getInstance(hbaseConf)
job.setOutputKeyClass(classOf[ImmutableBytesWritable])
job.setOutputValueClass(classOf[Put])
job.setOutputFormatClass(classOf[TableOutputFormat[ImmutableBytesWritable]])
putsRdd.saveAsNewAPIHadoopDataset(job.getConfiguration)
}
}

2.4. DefaultSource (the default class name; it must not be changed)
Spark SQL resolves read and write calls by loading a class named DefaultSource through reflection; the name is hard-coded in the Spark source. Spark locates the package we specify and then invokes this class inside it to perform the actual read and write, so the class name must not be changed.
The class needs to implement these three interfaces: org.apache.spark.sql.sources.{CreatableRelationProvider, DataSourceRegister, RelationProvider}.
After implementing them, overriding the following method assigns the data source a short name, used to recognize spark.read.format("hbase"), much like the package name does (see the note on registering the short name after the class below).
override def shortName(): String = "hbase"

package com.spark.hbase
import com.spark.hbase.HBaseRelation.{HBASE_TABLE_SELECT_FIELDS, SPLIT}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
/**
* @author Ket
* @Date 2021/11/23
*/
class DefaultSource extends RelationProvider
with CreatableRelationProvider
with DataSourceRegister
with Serializable {
/**
* Reads are ultimately performed through this method
*/
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
// the schema of the RDD[Row] returned by HBaseRelation's scan is defined here
val fields = parameters(HBASE_TABLE_SELECT_FIELDS)
.split(SPLIT)
.map { field =>
StructField(field, StringType, nullable = true)
}
val schema: StructType = StructType(
StructField("rowKey",StringType,nullable = false) +: fields)
HBaseRelation(parameters, schema)(sqlContext.sparkSession)
}
/**
* Writes are ultimately performed through this method
*/
override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
val relation = HBaseRelation(parameters, data.schema)(sqlContext.sparkSession)
relation.insert(data,overwrite = false)
relation
}
override def shortName(): String = "hbase"
}

Our custom Spark SQL data source is now complete; let's verify it.
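Before that, one side note on the short name: Spark discovers DataSourceRegister implementations through the JDK ServiceLoader, so format("hbase") (as opposed to the full package name) only resolves if the class is registered. To enable it, add a resource file named org.apache.spark.sql.sources.DataSourceRegister under src/main/resources/META-INF/services/ whose content is the single line below (the class name assumes the com.spark.hbase package used in this article):

com.spark.hbase.DefaultSource

Without this file, use the full package name, format("com.spark.hbase"), which is what the verification code below does.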
3. Verification
3.1. Verifying reads
The HBase table student currently contains the following data:

Read the columns name and id with the code below:
import com.spark.hbase.HBaseRelation
import org.apache.spark.sql.SparkSession

object Test {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]")
.appName("test hbase")
.getOrCreate()
val rawDF = spark.read.format("com.spark.hbase")
.option(HBaseRelation.HBASE_TABLE, "student")
.option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
.option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
.option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
.option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
.load()
rawDF.show()
spark.stop()
}
}

The IDEA console prints the following:

Read verification done.
3.2. Verifying writes
For the write test, we take the rows we just read, append yyds to each rowKey, and write them back, specifying the yyds column as the rowkey column.
The code is as follows:
import com.spark.hbase.HBaseRelation
import org.apache.spark.sql.SparkSession

object Test {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]")
.appName("test hbase")
.getOrCreate()
val rawDF = spark.read.format("com.spark.hbase")
.option(HBaseRelation.HBASE_TABLE, "student")
.option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
.option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
.option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
.option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
.load()
import spark.implicits._
import org.apache.spark.sql.functions._
rawDF.select(concat($"rowKey", lit("yyds")).as("yyds"), $"name", $"id")
.write.format("com.spark.hbase")
.option(HBaseRelation.HBASE_TABLE, "student")
.option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
.option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
.option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
.option(HBaseRelation.HBASE_TABLE_ROWKEY_NAME, "yyds")
.save()
spark.stop()
}
}

Check the row with rowkey 20200722yyds in HBase to validate the write:

As we can see, the data was written successfully!
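You can also double-check from Spark itself rather than the HBase shell. A small sketch, assuming it is appended to the same main method above before spark.stop() (so HBaseRelation and col are already imported; rowKey is the column our reader adds):

spark.read.format("com.spark.hbase")
  .option(HBaseRelation.HBASE_TABLE, "student")
  .option(HBaseRelation.HBASE_TABLE_FAMILY, "info")
  .option(HBaseRelation.HBASE_ZK_QUORUM_KEY, "dev0,dev1,dev2")
  .option(HBaseRelation.HBASE_ZK_PORT_KEY, "2181")
  .option(HBaseRelation.HBASE_TABLE_SELECT_FIELDS, "name,id")
  .load()
  .filter(col("rowKey") === "20200722yyds")  // the filter runs in Spark after the scan; no pushdown is implemented
  .show()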
That completes the custom Spark SQL source and sink for reading and writing HBase. Give it a try; it's genuinely nice to use.