Spark: Custom UDAF Example

Below is a simple example of how to write a custom aggregate function (also referred to as a user defined aggregate function, or UDAF) in Spark. This feature is fairly new; it was introduced in Spark 1.5 and is currently missing from PySpark.

In order to write a custom UDAF you need to extend UserDefinedAggregateFunction and define the following four methods (a minimal signature sketch follows the list):

  1. initialize: On a given node, this method is called once for each group.
  2. update: For a given group, Spark will call "update" for each input record of that group.
  3. merge: If the function supports partial aggregates, Spark may (as an optimization) compute partial results and combine them together.
  4. evaluate: Once all the entries for a group are exhausted, Spark will call "evaluate" to get the final result.
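
For reference, here is a minimal skeleton of the members the class has to provide. It is based on the Spark 1.5 API; the class name MyUDAF and the ??? placeholders are only illustrative:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyUDAF extends UserDefinedAggregateFunction {
  def inputSchema: StructType = ???   // schema of the input column(s)
  def bufferSchema: StructType = ???  // schema of the intermediate aggregation buffer
  def dataType: DataType = ???        // type of the final result
  def deterministic: Boolean = true   // same input always yields the same output

  def initialize(buffer: MutableAggregationBuffer): Unit = ???          // called once per group
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???  // called for each input record
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???  // combine partial aggregates
  def evaluate(buffer: Row): Any = ???                                  // produce the final result
}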

Depending on whether or not the function supports the combiner option, the order of execution can vary in the following two ways:

  1. If the function doesn't support partial aggregates (or a combiner):
     [Figure: execution pattern without partial aggregation]
  2. If the function supports partial aggregates:
     [Figure: execution pattern with partial aggregation]

You can read more about the execution pattern in my earlier blog on custom UDAFs in Hive.

Apart from defining the above four methods, you also need to define the input, intermediate, and final data types. Below is an example showing how to write a custom function that computes the mean.

package com.myudafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by ragrawal on 9/23/15.
 * Computes Mean 
 */

// Extend UserDefinedAggregateFunction to write a custom aggregate function.
// You can also specify any constructor arguments. For instance, you
// can have CustomMean(arg1: Int, arg2: String)
class CustomMean() extends UserDefinedAggregateFunction {

  // Input Data Type Schema
  def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

  // Intermediate Schema
  def bufferSchema: StructType = StructType(Array(
    StructField("sum", DoubleType),
    StructField("cnt", LongType)
  ))

  // Returned Data Type
  def dataType: DataType = DoubleType

  // The function is deterministic: the same input always produces the same output
  def deterministic: Boolean = true

  // Called once per group (on each node) to initialize the aggregation buffer
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = 0.toDouble // set sum to zero
    buffer(1) = 0L // set number of items to 0
  }

  // Iterate over each entry of a group
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Called after all the entries are exhausted.
  def evaluate(buffer: Row) = {
    buffer.getDouble(0)/buffer.getLong(1).toDouble
  }

}

Below is the code that shows how to use the UDAF with a DataFrame.

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.functions._

import com.myudafs.CustomMean

// define UDAF
val customMean = new CustomMean()

// create test dataset
val data = (1 to 1000).map { x: Int =>
  x match {
    case t if t <= 500 => Row("A", t.toDouble)
    case t => Row("B", t.toDouble)
  }
}

// create schema of the test dataset
val schema = StructType(Array(
  StructField("key", StringType),
  StructField("value", DoubleType)
))

// construct data frame
val rdd = sc.parallelize(data)
val df = sqlContext.createDataFrame(rdd, schema)

// Calculate average value for each group
df.groupBy("key").agg(
customMean(df.col("value")).as("custom_mean"),
avg("value").as("avg")
).show()

Output should be

+---+-----------+-----+
|key|custom_mean|  avg|
+---+-----------+-----+
|  A|      250.5|250.5|
|  B|      750.5|750.5|
+---+-----------+-----+
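
You can also call the UDAF from Spark SQL by registering it under a name. The short sketch below assumes the same df as above; the temporary table name "records" is just an illustrative example:

// Register the UDAF so it can be called by name from SQL (available since Spark 1.5)
sqlContext.udf.register("custom_mean", new CustomMean())

// Expose the DataFrame as a temporary table ("records" is a made-up name)
df.registerTempTable("records")

sqlContext.sql("SELECT key, custom_mean(value) AS custom_mean FROM records GROUP BY key").show()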

A few shortcomings of the UserDefinedAggregateFunction class:

  1. Missing generic data types: There is no way to define a generic data type. For instance, if a function works on both String and Numeric types, you have to duplicate the functionality for each of them; you can't easily reuse the code by defining a generic data type. One example of such a function is random sampling of values.
  2. Non-algebraic functions: Not all functions support a merge operation; "median" is one example. In such cases you don't want to implement the "merge" method, but currently you are required to override it (see the sketch after this list). In Pig, by contrast, you explicitly indicate whether a function is algebraic by implementing the Algebraic interface.
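
As an illustration of the second point, below is a rough sketch (my own simplified example, not part of Spark) of a "median" UDAF: the buffer has to hold every value, and merge still has to be implemented even though it does nothing smarter than concatenating the two buffers.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Sketch of a non-algebraic aggregate: keeps all values and sorts them at the end
class CustomMedian() extends UserDefinedAggregateFunction {

  def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

  // The whole group is buffered, so memory grows with the group size
  def bufferSchema: StructType = StructType(Array(StructField("values", ArrayType(DoubleType))))

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Double]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
  }

  // Required by the API, even though there is no real partial aggregate for a median
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
  }

  def evaluate(buffer: Row): Double = {
    val sorted = buffer.getSeq[Double](0).sorted
    sorted(sorted.length / 2) // upper median; empty groups are not handled
  }
}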