Below is a simple example of how to write a custom aggregate function (also referred to as a user-defined aggregate function, or UDAF) in Spark. This feature is fairly new: it was introduced in Spark 1.5. Furthermore, it is currently missing from PySpark.
To write a custom UDAF, you need to extend UserDefinedAggregateFunction and define the following four methods:
- initialize: On a given node, this method is called once for each group to set up an empty aggregation buffer.
- update: For a given group, Spark calls update for each input record of that group.
- merge: If the function supports partial aggregates, Spark might (as an optimization) compute partial results and combine them together with merge.
- evaluate: Once all the entries for a group are exhausted, Spark calls evaluate to get the final result.
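To make the contract concrete without a cluster, here is a plain-Scala model of the four callbacks for a mean. It is a sketch, not Spark's API: `MeanBuffer` and `MeanContract` are hypothetical names, and the buffer is an immutable case class, whereas Spark hands you a mutable `MutableAggregationBuffer`.

```scala
// Hypothetical plain-Scala model of the four UDAF callbacks (no Spark required).
// MeanBuffer stands in for Spark's MutableAggregationBuffer holding (sum, cnt).
final case class MeanBuffer(sum: Double, cnt: Long)

object MeanContract {
  // initialize: a fresh, empty buffer for a group
  def initialize(): MeanBuffer = MeanBuffer(0.0, 0L)

  // update: fold one input record into the buffer
  def update(b: MeanBuffer, x: Double): MeanBuffer =
    MeanBuffer(b.sum + x, b.cnt + 1)

  // merge: combine two partial buffers
  def merge(b1: MeanBuffer, b2: MeanBuffer): MeanBuffer =
    MeanBuffer(b1.sum + b2.sum, b1.cnt + b2.cnt)

  // evaluate: turn the final buffer into the result
  def evaluate(b: MeanBuffer): Double = b.sum / b.cnt

  def main(args: Array[String]): Unit = {
    val buf = Seq(1.0, 2.0, 3.0, 4.0).foldLeft(initialize())(update)
    println(evaluate(buf)) // prints 2.5
  }
}
```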
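Depending on whether the function supports the combiner option, the order of execution can vary in the following two ways:
- If the function doesn't support partial aggregates (or a combiner): initialize, then update for every record of the group, then evaluate.
- If the function supports partial aggregates: initialize and update run separately on each partition's share of the group, the partial results are combined with merge, and evaluate runs once on the merged result.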
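For an algebraic function such as mean, both orders should give the same answer. The sketch below simulates them in plain Scala; modeling partitions as `Seq[Seq[Double]]` is an assumption for illustration (Spark decides partitioning itself), and `ExecutionOrder`, `withoutCombiner`, and `withCombiner` are hypothetical names.

```scala
// Simulates the two execution orders for a mean; not Spark code.
object ExecutionOrder {
  final case class Buf(sum: Double, cnt: Long)

  val init: Buf = Buf(0.0, 0L)
  def update(b: Buf, x: Double): Buf = Buf(b.sum + x, b.cnt + 1)
  def merge(a: Buf, b: Buf): Buf = Buf(a.sum + b.sum, a.cnt + b.cnt)
  def evaluate(b: Buf): Double = b.sum / b.cnt

  // No partial aggregation: a single buffer sees every record of the group.
  def withoutCombiner(rows: Seq[Double]): Double =
    evaluate(rows.foldLeft(init)(update))

  // Partial aggregation: one buffer per partition, then the buffers are merged.
  def withCombiner(partitions: Seq[Seq[Double]]): Double = {
    val partials = partitions.map(_.foldLeft(init)(update))
    evaluate(partials.foldLeft(init)(merge))
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1.0, 2.0), Seq(3.0), Seq(4.0, 5.0, 6.0))
    println(withoutCombiner(partitions.flatten)) // prints 3.5
    println(withCombiner(partitions))            // prints 3.5
  }
}
```

The two paths agree only because merging (sum, count) pairs loses no information; that is exactly the property a merge method must preserve.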
You can read more about the execution pattern in my earlier blog post on custom UDAFs in Hive.
Apart from defining the above four methods, you also need to define the input, intermediate (buffer), and final data types. Below is an example showing how to write a custom function that computes the mean.
```scala
package com.myudafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by ragrawal on 9/23/15.
 * Computes Mean
 */

// Extend UserDefinedAggregateFunction to write a custom aggregate function.
// You can also specify constructor arguments, for instance
// CustomMean(arg1: Int, arg2: String).
class CustomMean() extends UserDefinedAggregateFunction {

  // Input data type schema
  def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

  // Intermediate (buffer) schema
  def bufferSchema: StructType = StructType(Array(
    StructField("sum", DoubleType),
    StructField("cnt", LongType)
  ))

  // Returned data type
  def dataType: DataType = DoubleType

  // The same input always produces the same output
  def deterministic: Boolean = true

  // Called to set up a fresh buffer for each group
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.toDouble // set sum to zero
    buffer(1) = 0L         // set number of items to 0
  }

  // Iterate over each entry of a group; null inputs are skipped, like avg()
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Called after all the entries are exhausted; returns null for an empty group
  def evaluate(buffer: Row): Any = {
    if (buffer.getLong(1) == 0L) null
    else buffer.getDouble(0) / buffer.getLong(1).toDouble
  }
}
```
The code below shows how to use the UDAF with a DataFrame.
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.apache.spark.sql.functions._
import com.myudafs.CustomMean

// Run in spark-shell, which provides `sc` (SparkContext) and `sqlContext`.

// define UDAF
val customMean = new CustomMean()

// create test dataset: keys "A" (values 1..500) and "B" (values 501..1000)
val data = (1 to 1000).map { x: Int =>
  x match {
    case t if t <= 500 => Row("A", t.toDouble)
    case t => Row("B", t.toDouble)
  }
}

// create schema of the test dataset
val schema = StructType(Array(
  StructField("key", StringType),
  StructField("value", DoubleType)
))

// construct data frame
val rdd = sc.parallelize(data)
val df = sqlContext.createDataFrame(rdd, schema)

// Calculate average value for each group with both the UDAF and built-in avg
df.groupBy("key").agg(
  customMean(df.col("value")).as("custom_mean"),
  avg("value").as("avg")
).show()
```
The output should be:

| key | custom_mean | avg |
|---|---|---|
| A | 250.5 | 250.5 |
| B | 750.5 | 750.5 |
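As a quick sanity check on the expected values, the same group means can be computed directly with Scala collections, using the same key split as the test dataset. `ExpectedMeans` is a hypothetical name; this does not involve Spark at all.

```scala
// Recomputes the expected per-key means with plain Scala collections.
object ExpectedMeans {
  // Same split as the test dataset: 1..500 -> "A", 501..1000 -> "B"
  val means: Map[String, Double] =
    (1 to 1000).map(_.toDouble)
      .groupBy(v => if (v <= 500) "A" else "B")
      .map { case (k, vs) => k -> vs.sum / vs.size }

  def main(args: Array[String]): Unit =
    means.toSeq.sortBy(_._1).foreach { case (k, m) => println(s"$k $m") }
    // prints "A 250.5" then "B 750.5"
}
```

Note that Spark does not guarantee the row order of a groupBy result, so `show()` may list B before A.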
A few shortcomings of the UserDefinedAggregateFunction class:
- Missing generic data type: There is no way to define a generic data type. For instance, if a function works on both String and numeric types, you need to duplicate the functionality for each type; you can't easily reuse the code by defining a generic data type. One example of such a function is random sampling of values.
- Non-algebraic functions: Not all functions, such as median, support a merge operation. In such cases you don't want to implement the merge method, but currently you are required to override it. In Pig, you explicitly indicate whether a function is algebraic by implementing the Algebraic interface.
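Both limitations can be illustrated outside Spark. The plain-Scala sketch below is generic over any type with an `Ordering` (so one class serves both Double and String), and it computes a lower median, meaning the smaller middle element when the count is even, since strings cannot be averaged. Its merge can only concatenate the buffered values, so the buffer grows with the data instead of staying fixed-size as the mean's (sum, count) does; that is what makes median non-algebraic. `CollectBuffer`, `LowerMedian`, and `LowerMedianDemo` are hypothetical names, not part of Spark.

```scala
// Hypothetical generic aggregator (not Spark's API): the buffer keeps every value.
final case class CollectBuffer[T](values: Vector[T])

final class LowerMedian[T: Ordering] {
  def initialize(): CollectBuffer[T] = CollectBuffer(Vector.empty)

  def update(b: CollectBuffer[T], x: T): CollectBuffer[T] =
    CollectBuffer(b.values :+ x)

  // Merge can only concatenate: buffer size grows with the input (non-algebraic)
  def merge(b1: CollectBuffer[T], b2: CollectBuffer[T]): CollectBuffer[T] =
    CollectBuffer(b1.values ++ b2.values)

  // Lower median; None for an empty group
  def evaluate(b: CollectBuffer[T]): Option[T] =
    if (b.values.isEmpty) None
    else {
      val sorted = b.values.sorted
      Some(sorted((sorted.size - 1) / 2))
    }
}

object LowerMedianDemo {
  // Partial aggregation per "partition", then merge, then evaluate
  def run[T: Ordering](partitions: Seq[Seq[T]]): Option[T] = {
    val agg = new LowerMedian[T]
    val partials = partitions.map(_.foldLeft(agg.initialize())(agg.update))
    agg.evaluate(partials.foldLeft(agg.initialize())(agg.merge))
  }

  def main(args: Array[String]): Unit = {
    println(run(Seq(Seq(5.0, 1.0), Seq(3.0, 4.0, 2.0)))) // prints Some(3.0)
    println(run(Seq(Seq("pear", "apple"), Seq("fig"))))  // prints Some(fig)
  }
}
```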