spark下自定义UDAF的简单示例与使用

mac2026-04-02  3

自定义UDAF定义步骤如下:

继承抽象类 UserDefinedAggregateFunction，并重写以下八个成员：

- 输入类型：inputSchema: StructType
- 中间缓冲区（聚合过程中的中间值）类型：bufferSchema: StructType
- 输出类型：dataType: DataType
- 是否为确定性函数（即相同输入是否总是得到相同输出）：deterministic: Boolean
- 中间缓冲区初始化方法：initialize(buffer: MutableAggregationBuffer): Unit
- 分区内运算方法（每输入一行调用一次）：update(buffer: MutableAggregationBuffer, input: Row): Unit
- 分区间合并方法（合并两个部分聚合的结果）：merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
- 最终结果计算方法：evaluate(buffer: Row): Any

然后通过 session.udf.register("myavg", new MyAvg) 进行注册，即可在 SQL 中使用。具体代码如下（注释摘自 Spark 源码）：

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/** One row of the sample people file: a name and an integer age. */
case class Person(name: String, age: Int)

/**
 * Average of an integer column, returned as a `Double`.
 *
 * NULL inputs are ignored. If a group contains no non-NULL input the result is NULL
 * (matching Spark's built-in `avg`) instead of NaN.
 *
 * Note: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `org.apache.spark.sql.expressions.Aggregator` registered through `functions.udaf`;
 * it is used here because this example demonstrates the older API.
 */
class MyAvg extends UserDefinedAggregateFunction {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   *   new StructType()
   *    .add("doubleInput", DoubleType)
   *    .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * @since 1.5.0
   */
  // The running sum and count are Longs: accumulating many Int inputs into an Int
  // buffer would silently overflow and produce a wrong average.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   *
   * @since 1.5.0
   */
  override def dataType: DataType = DoubleType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   *
   * @since 1.5.0
   */
  override def deterministic: Boolean = true

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   *
   * @since 1.5.0
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0L // count
  }

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   *
   * @since 1.5.0
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // NULL inputs contribute to neither the sum nor the count.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   *
   * @since 1.5.0
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   *
   * @since 1.5.0
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // With no non-NULL input, return NULL (as SQL avg does) rather than 0 / 0.0 = NaN.
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}

/** Demo: registers [[MyAvg]] as `myavg` and averages the `age` column of a people file. */
object Udaf {

  /** Input file used when no path is passed as the first program argument. */
  private val DefaultPeoplePath =
    "C:\\Users\\Administrator\\Desktop\\tmp\\spark\\sql\\data\\people.txt"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` is the path of a text file whose lines look like
   *             `name, age`. Falls back to `DefaultPeoplePath` when absent.
   */
  def main(args: Array[String]): Unit = {
    val session: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("sparkSession")
      .getOrCreate()
    try {
      import session.implicits._
      session.udf.register("myavg", new MyAvg)

      val path = args.headOption.getOrElse(DefaultPeoplePath)
      val r1: RDD[String] = session.sparkContext.textFile(path)
      val pr: RDD[Person] = r1
        .filter(_.trim.nonEmpty) // a trailing blank line would otherwise fail on fields(1)
        .map { line =>
          val fields = line.split(",")
          Person(fields(0), fields(1).trim.toInt)
        }
      val df: Dataset[Person] = pr.toDS()
      df.createOrReplaceTempView("people")

      session.sql("select myavg(age) from people").show()
    } finally {
      // Release the Spark context even if reading or querying fails.
      session.stop()
    }
  }
}
最新回复(0)