







/**
 * Calculates `x` modulo `mod`, taking the sign of `x` into account: when `x` is negative,
 * `x % mod` is negative too, so in that case the function returns `(x % mod) + mod`.
 *
 * @param x   the dividend
 * @param mod the divisor; the result lies in `[0, mod)` only when `mod` is positive.
 *            A zero `mod` throws `ArithmeticException` from the `%` operator.
 * @return `x % mod`, shifted up by `mod` when the raw remainder is negative
 */
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  // JVM `%` keeps the sign of the dividend (e.g. -7 % 3 == -1), hence the correction.
  if (rawMod < 0) rawMod + mod else rawMod
}






/**
 * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
 *
 * If any of the RDDs already has a partitioner with at least one partition, choose that one,
 * preferring the RDD with the most partitions. Otherwise, we use a default HashPartitioner.
 * For the number of partitions, if spark.default.parallelism is set, then we'll use the value
 * from SparkContext defaultParallelism, otherwise we'll use the max number of upstream
 * partitions. Unless spark.default.parallelism is set, the number of partitions will be the
 * same as the number of partitions in the largest upstream RDD, as this should be least
 * likely to cause out-of-memory errors.
 *
 * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
 *
 * @param rdd    the first RDD taking part in the operation
 * @param others any further RDDs taking part in the operation
 * @return an existing non-empty partitioner, or a new HashPartitioner
 */
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
  // Largest RDD first, so an existing partitioner on a bigger RDD takes precedence.
  val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse

  // `find` replaces a `for (...) { return ... }` loop, whose `return` inside the closure
  // is a non-local return implemented by throwing an exception.
  val existing: Option[Partitioner] =
    bySize.find(_.partitioner.exists(_.numPartitions > 0)).flatMap(_.partitioner)

  existing.getOrElse {
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(bySize.head.partitions.size)
    }
  }
}




/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 *
 * @param partitions number of output partitions; must be non-negative
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  /**
   * Null keys go to partition 0; any other key goes to `hashCode` modulo `numPartitions`,
   * shifted to be non-negative. NOTE(review): with 0 partitions a non-null key makes
   * `nonNegativeMod` divide by zero.
   */
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  // Two HashPartitioners with the same partition count place every key identically,
  // so they are considered equal.
  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner => h.numPartitions == numPartitions
    case _ => false
  }

  // Consistent with equals: equal partitioners share the same partition count.
  override def hashCode: Int = numPartitions
}





/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 *
 * @param partitions requested number of partitions; must be non-negative
 * @param rdd        key-value RDD whose keys are sampled to compute the range bounds
 * @param ascending  whether partition IDs increase (true) or decrease (false) with the key
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      // A single (or zero) partition needs no bounds: numPartitions = 0 + 1.
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.size).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  /** Index of the range containing `key`, mirrored when `ascending` is false. */
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  // Combines exactly the fields compared by equals (rangeBounds and ascending).
  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer => out.defaultWriteObject()
      case _ =>
        // Write order must mirror the read order in readObject.
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}





import org.apache.spark.Partitioner
/**
 * Custom Spark partitioner that groups records by the host name of their key.
 *
 * Each key is converted to a string and parsed as a URL; keys with the same host land in the
 * same partition. Keys that are not valid URLs (for example plain words) are hashed as-is
 * instead of failing the task, and `null` keys go to partition 0, as in HashPartitioner.
 *
 * A custom partitioner has to provide:
 *  - `numPartitions`: returns the number of partitions you want to create;
 *  - `getPartition`: computes the partition ID for the given key; the result must lie in the
 *    range 0 to numPartitions - 1;
 *  - `equals()`: the standard Java equality check. Users must implement it because Spark
 *    internally compares whether the partitioners of two RDDs are the same.
 *    `hashCode` is overridden alongside it to keep the equals/hashCode contract.
 *
 * Created by Jeff Yang on 2017/3/30
 * Update date:
 * Time: 18:03
 * Description:
 * Result of Test:
 * Command:
 * Email:
 *
 * @param numParts number of output partitions; assumed positive, because getPartition
 *                 takes a remainder modulo it
 */
class MySparkPartition(numParts: Int) extends Partitioner {

  override def numPartitions: Int = numParts

  /**
   * The custom partitioning algorithm: route a key by the hash of its URL host.
   *
   * @param key the record key (may be null)
   * @return partition ID in `[0, numPartitions)`
   */
  override def getPartition(key: Any): Int = key match {
    case null => 0
    case _ =>
      val code = hostOf(key.toString).hashCode % numPartitions
      // `%` keeps the sign of the hash, so shift negative remainders into range.
      if (code < 0) code + numPartitions else code
  }

  /** Host part of `raw` when it parses as a URL, otherwise `raw` itself. */
  private def hostOf(raw: String): String =
    try new java.net.URL(raw).getHost
    catch { case _: java.net.MalformedURLException => raw }

  override def equals(other: Any): Boolean = other match {
    case that: MySparkPartition => that.numPartitions == numPartitions
    case _ => false
  }

  override def hashCode: Int = numPartitions
}




import org.apache.spark.{SparkConf, SparkContext}

/**
 * Driver that repartitions the words of a text file with [[MySparkPartition]].
 *
 * Usage: `UseMyPartitioner [inputPath] [outputPath]`. When an argument is omitted, the
 * original hard-coded HDFS input and local output paths are used.
 *
 * Created by Jeff Yang on 2017/3/30
 * Update date:
 * Time: 18:47
 * Description: uses the custom partitioner
 * Result of Test:
 * Command:
 * Email:
 */
object UseMyPartitioner {

  private val DefaultInput = "hdfs://hadoop2:8020/user/test/word.txt"
  private val DefaultOutput = "F:/partrion/test"

  def main(args: Array[String]): Unit = {
    val input = args.lift(0).getOrElse(DefaultInput)
    val output = args.lift(1).getOrElse(DefaultOutput)

    val conf = new SparkConf()
    val sc = new SparkContext(conf)
    try {
      val lines = sc.textFile(input)
      // Note: partitionBy needs a key-value RDD, so each tab-separated word becomes (word, 2).
      val splitMap = lines.flatMap(line => line.split("\t")).map(word => (word, 2))

      splitMap.partitionBy(new MySparkPartition(3)).saveAsTextFile(output)
    } finally {
      // Release cluster resources even if the job fails.
      sc.stop()
    }
  }
}


