继上一篇:Spark分区器探索(HashPartitioner、RangePartitioner),现来看看自定义分区器。 继承 org.apache.spark.Partitioner 类并实现下面三个方法。 (1)numPartitions: Int:设置分区数。 (2)getPartition(key: Any): Int:返回给定key计算出的分区编号(0到numPartitions-1)。 (3)equals():判断分区器是否相等的方法。用于判断当前使用分区器对象与其他分区器实例是否相等相同,以此判断两个 RDD 的分区方式是否相同。
示例代码如下:
import org.apache.spark.rdd.RDD import org.apache.spark.{ SparkConf, SparkContext} class MyPartitioner(numParts:Int) extends org.apache.spark.Partitioner{ //number of partitions override def numPartitions: Int = numParts //get partition id base on key override def getPartition(key: Any): Int = { /* val ckey: String = key.toString ckey.substring(ckey.length-1).toInt%numParts*/ key.toString.toInt%numPartitions } override def equals(other: Any): Boolean = other match { case h: MyPartitioner => h.numPartitions == numPartitions case _ => false } } object partitionTest { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("TestPartitioner").setMaster("local[*]") val sc = new SparkContext(conf) //map to kv val r1: RDD[(Int, Int)] = sc.parallelize(1 to 10,5).map((_,1)) //partition by MyPatitioner, and map to resource val rp: RDD[Int] = r1.partitionBy(new MyPartitioner(2)).map(_._1) rp.mapPartitionsWithIndex((index,items)=>items.map((index,_))).collect.foreach(println) sc.stop() } }运行结果:
(0,2) (0,4) (0,6) (0,8) (0,10) (1,1) (1,3) (1,5) (1,7) (1,9)
在以上代码中创建了一个分区数为5的RDD,随后运用partitionBy(new MyPartitioner(2)重新分为两个分区并输出对应分区数据。