Spark: Custom UDAF Example

Below is a simple example of how to write a custom aggregate function (also referred to as a user-defined aggregate function, or UDAF) in Spark. This feature is fairly new; it was introduced in Spark 1.5. Furthermore, it's currently missing from PySpark.

To write a custom UDAF, you need to extend UserDefinedAggregateFunction and define the following four methods:

  1. initialize: On a given node, this method is called once for each group.
  2. update: For a given group, Spark calls update for each input record of that group.
  3. merge: If the function supports partial aggregates, Spark might (as an optimization) compute partial results and combine them together.
  4. evaluate: Once all the entries for a group are exhausted, Spark calls evaluate to get the final result.
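To make the role of each method concrete, here is a minimal sketch in plain Scala (no Spark dependency) that mimics the same four-method contract for a mean. The trait SimpleUdaf, the objects MeanAgg and FourMethodsDemo, and their helpers are hypothetical names for illustration, not Spark's API. An immutable (sum, count) tuple stands in for Spark's mutable aggregation buffer.

```scala
// Hypothetical plain-Scala model of the four UDAF methods (not Spark's API).
// B is the buffer type, I the input type, O the result type.
trait SimpleUdaf[B, I, O] {
  def initialize(): B                  // fresh buffer for a group
  def update(buffer: B, input: I): B   // fold in one input record
  def merge(buffer1: B, buffer2: B): B // combine two partial buffers
  def evaluate(buffer: B): O           // produce the final result
}

// Mean over Doubles; the buffer is (sum, count), like CustomMean's bufferSchema.
object MeanAgg extends SimpleUdaf[(Double, Long), Double, Double] {
  def initialize(): (Double, Long) = (0.0, 0L)
  def update(buffer: (Double, Long), input: Double): (Double, Long) =
    (buffer._1 + input, buffer._2 + 1)
  def merge(buffer1: (Double, Long), buffer2: (Double, Long)): (Double, Long) =
    (buffer1._1 + buffer2._1, buffer1._2 + buffer2._2)
  def evaluate(buffer: (Double, Long)): Double = buffer._1 / buffer._2
}

object FourMethodsDemo {
  // One buffer for the whole group: initialize, update per record, evaluate.
  def singlePass(values: Seq[Double]): Double =
    MeanAgg.evaluate(values.foldLeft(MeanAgg.initialize())((b, x) => MeanAgg.update(b, x)))

  // Group split across partitions: one buffer per partition, merged before evaluate.
  def partitioned(parts: Seq[Seq[Double]]): Double = {
    val partials = parts.map(p => p.foldLeft(MeanAgg.initialize())((b, x) => MeanAgg.update(b, x)))
    MeanAgg.evaluate(partials.reduce((b1, b2) => MeanAgg.merge(b1, b2)))
  }

  def main(args: Array[String]): Unit = {
    println(singlePass(Seq(1.0, 2.0, 3.0, 4.0)))
    println(partitioned(Seq(Seq(1.0, 2.0), Seq(3.0, 4.0))))
  }
}
```

Both helpers should give the same mean. The partitioned version exercises merge on buffers that were built independently.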

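Depending on whether the function supports partial aggregates (a combiner), the order of execution follows one of two patterns:

  1. If the function doesn't support partial aggregates (or a combiner): initialize → update → update → … → evaluate
  2. If the function supports partial aggregates: initialize → update → … on each partition separately, then merge to combine the partial results, and finally evaluate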


You can read more about the execution pattern in my earlier blog post on custom UDAFs in Hive.
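The two orders can also be traced in plain Scala. TracingSum is a hypothetical aggregator (not part of Spark) that records every method call. How the group is split into "partitions" here is an assumption for illustration, because in Spark the planner decides that.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical aggregator that sums Ints and logs every method call,
// so the order of initialize/update/merge/evaluate is visible.
class TracingSum {
  val log = ListBuffer.empty[String]
  def initialize(): Int = { log += "initialize"; 0 }
  def update(buf: Int, x: Int): Int = { log += s"update($x)"; buf + x }
  def merge(b1: Int, b2: Int): Int = { log += "merge"; b1 + b2 }
  def evaluate(buf: Int): Int = { log += "evaluate"; buf }
}

object ExecutionOrder {
  // Without partial aggregation: one buffer sees every record of the group.
  def withoutPartials(values: Seq[Int]): (Int, List[String]) = {
    val agg = new TracingSum
    val buf = values.foldLeft(agg.initialize())((b, x) => agg.update(b, x))
    (agg.evaluate(buf), agg.log.toList)
  }

  // With partial aggregation: each partition builds its own buffer,
  // the partial buffers are merged, and evaluate runs once at the end.
  def withPartials(partitions: Seq[Seq[Int]]): (Int, List[String]) = {
    val agg = new TracingSum
    val partials = partitions.map(p => p.foldLeft(agg.initialize())((b, x) => agg.update(b, x)))
    val merged = partials.reduce((b1, b2) => agg.merge(b1, b2))
    (agg.evaluate(merged), agg.log.toList)
  }

  def main(args: Array[String]): Unit = {
    println(withoutPartials(Seq(1, 2, 3)))
    println(withPartials(Seq(Seq(1, 2), Seq(3))))
  }
}
```

For a group split as Seq(Seq(1, 2), Seq(3)), the log shows two initialize calls (one per partition), followed by a single merge and a single evaluate.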

Apart from defining the above four methods, you also need to define the input, intermediate, and final data types. Below is an example showing how to write a custom function that computes the mean.

package com.myudafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by ragrawal on 9/23/15.
 * Computes Mean
 */

// Extend UserDefinedAggregateFunction to write a custom aggregate function.
// You can also specify constructor arguments. For instance, you
// can have CustomMean(arg1: Int, arg2: String)
class CustomMean() extends UserDefinedAggregateFunction {

  // Input data type schema
  def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

  // Intermediate (buffer) schema: running sum and count
  def bufferSchema: StructType = StructType(Array(
    StructField("sum", DoubleType),
    StructField("cnt", LongType)
  ))

  // Returned data type
  def dataType: DataType = DoubleType

  // The same input always produces the same output
  def deterministic: Boolean = true

  // Called once for each new aggregation buffer (i.e., whenever the key changes)
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.toDouble // set sum to zero
    buffer(1) = 0L // set number of items to 0
  }

  // Iterate over each entry of a group; null inputs are skipped, as in the built-in avg
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Called after all the entries are exhausted: mean = sum / count
  // (returns null for a group with no non-null values)
  def evaluate(buffer: Row): Any = {
    if (buffer.getLong(1) == 0L) null
    else buffer.getDouble(0) / buffer.getLong(1).toDouble
  }
}

Below is code that shows how to use the UDAF with a DataFrame.

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.functions._

import com.myudafs.CustomMean

// sc (SparkContext) and sqlContext are assumed to be provided by spark-shell

// define UDAF
val customMean = new CustomMean()

// create test dataset: values 1..500 get key "A", 501..1000 get key "B"
val data = (1 to 1000).map { x: Int => x match {
  case t if t <= 500 => Row("A", t.toDouble)
  case t => Row("B", t.toDouble)
}}

// create schema of the test dataset
val schema = StructType(Array(
  StructField("key", StringType),
  StructField("value", DoubleType)
))

// construct data frame
val rdd = sc.parallelize(data)
val df = sqlContext.createDataFrame(rdd, schema)

// Calculate average value for each group, with the UDAF and with the built-in avg
df.groupBy("key").agg(
  customMean(df.col("value")).as("custom_mean"),
  avg("value").as("avg")
).show()

The output should be:

key  custom_mean  avg
A    250.5        250.5
B    750.5        750.5
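The expected numbers are easy to double-check. Keys A and B each hold 500 consecutive values, so their means are (1 + 500) / 2 = 250.5 and (501 + 1000) / 2 = 750.5. The plain-Scala sketch below uses ordinary collections, not Spark, and serves only as a cross-check on the arithmetic. It rebuilds the same split and computes the per-key mean.

```scala
object ExpectedMeans {
  // Same split as the test dataset: values 1..500 go to key "A", 501..1000 to "B".
  val data: Seq[(String, Double)] =
    (1 to 1000).map(t => (if (t <= 500) "A" else "B", t.toDouble))

  // Mean per key using ordinary collection operations.
  def meansByKey: Map[String, Double] =
    data.groupBy(_._1).map { case (k, rows) =>
      k -> (rows.map(_._2).sum / rows.size)
    }

  def main(args: Array[String]): Unit =
    meansByKey.toSeq.sortBy(_._1).foreach { case (k, m) => println(s"$k $m") }
}
```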

A few shortcomings of the UserDefinedAggregateFunction class:

  1. Missing generic data type: There is no way to define a generic data type. For instance, if a function works on both String and numeric types, you need to duplicate the functionality for each type; you can't easily reuse the code through a generic type. One example of such a function is random sampling of values.
  2. Non-algebraic functions: Not all functions, such as median, support a merge operation. In such cases you don't want to implement the merge method, but currently you are required to override it. In Pig, you explicitly indicate whether a function is algebraic by implementing the Algebraic interface.
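Both points can be illustrated outside Spark. Below is a hypothetical plain-Scala aggregator (not Spark's API) for an exact lower median. A type parameter with an Ordering lets one implementation serve both String and Int values. Because median is not algebraic, merge cannot reduce the partial results to a small summary and can only stack the collected values together, so memory grows with the group size. Approximate-median techniques, which would scale better, are not shown.

```scala
// Hypothetical plain-Scala aggregator (not Spark's API) illustrating both
// shortcomings: one type-generic implementation, and a non-algebraic merge.
final class LowerMedian[T](implicit ord: Ordering[T]) {
  // The buffer keeps every value seen so far.
  def initialize(): Vector[T] = Vector.empty
  def update(buf: Vector[T], x: T): Vector[T] = buf :+ x
  // Median has no compact partial result; merge can only stack the buffers.
  def merge(b1: Vector[T], b2: Vector[T]): Vector[T] = b1 ++ b2
  // Sort once at the end; the lower median is the element at index (n - 1) / 2.
  def evaluate(buf: Vector[T]): Option[T] =
    if (buf.isEmpty) None
    else {
      val sorted = buf.sorted(ord)
      Some(sorted((sorted.size - 1) / 2))
    }

  // Simulate partial aggregation: one buffer per partition, merged at the end.
  def run(partitions: Seq[Seq[T]]): Option[T] = {
    val partials = partitions.map(p => p.foldLeft(initialize())((b, x) => update(b, x)))
    evaluate(partials.foldLeft(initialize())((b1, b2) => merge(b1, b2)))
  }
}
```

For example, new LowerMedian[String].run(Seq(Seq("pear", "apple"), Seq("fig", "kiwi"))) returns Some("fig"). The same class, unchanged, handles Int values.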

18 thoughts on “Spark: Custom UDAF Example”

  1. Hi Ritesh,
    Can you explain at which step each function of the UDAF gets called? From the above explanation I can't tell how it works internally, i.e., which method is called at which step.

  2. Hello,

    Can you explain Spark UDAFs step by step? I understand the abstract methods in it, but when working through the example I don't see which method is called at which step (i.e., how it executes for each input row). Can you explain with multiple examples?


  3. Hi,
    Thanks for the post. If I want to calculate a median without using merge, just iterating over a set of rows, is there an alternative?


    1. To compute the true median, just combine the two buffers from the update stage into one and re-sort them. Note, however, that this approach won't scale. A better approach is to compute an approximate median; several implementations are available.

  4. You mentioned “merge: if the function supports partial aggregates, Spark might (as an optimization) compute partial results and combine them together”. Looking at your previous thread and the current example, it seems like merge is part of the reduce step, correct? Or can it run on the mapper side as well? If not, how can we enable it to run on the mapper side?

    We have a custom UDAF, and the merge step is really expensive because it runs only on the reducer side. I could not find any other thread on enabling it to run on the mapper side, and the Spark implementation doesn't help either.

    Thanks, and let me know if you need more information; I would be happy to provide it.

    1. Hi Hardik,

      Merge doesn't have to be part of the reduce step. The reduce step is related to the evaluate function. Merge is the same as a combiner in MapReduce. If required, I will be happy to talk more offline.

      1. OK, that makes sense! I was wondering if there is a way to do a partial merge locally on the mapper side before grouping the data by key on the combiner side and then performing another level of merge. This way some of the computation can be done early, which can help reduce computation time.

        Sure, we can talk offline about the use case.

        1. Regarding this, I have one more question about UDAFs. In our combiner (partial merge) step, we see calls to InitializeBuffer, and we are not sure why the buffers are being initialized again when they were already initialized once during the update step.
          Do you know why this happens? One theory is that, because these buffers are re-initialized, we spend time copying data from the old buffers to the new ones, which also causes GC overhead.

  5. Hi,

    I am trying to create a UDAF that cannot be aggregated partially. If I provide an empty merge method, I don't get correct results. Is there any explicit step needed to make a UDAF non-partial?

    Merge definition in the aggregate function class:
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = ()

    1. Hi Puneeta,
      If your function doesn't support partial aggregation, then your merge method should combine the data from both buffers into a single array. For instance, if your update method collects all the input values in an array, then merge will just stack the arrays together. Let me know whether this makes sense.

  6. Hi, I have a Dataset of Track.class. I want to merge all tracks that fall within the same time interval, for example 5 minutes, i.e., any track that starts within 5 minutes after another track ends should become the same track. It looks like a fusion task.

    My input:

    | trackId | start_time | end_time |
    | 1 | 12:00:00 | 12:04:00 |
    | 2 | 12:05:00 | 12:08:00 |
    | 3 | 12:20:00 | 12:22:00 |
    Output (tracks 1 and 2 have been merged because the gap between the end of one and the start of the next is within 5 minutes):

    | trackId | start_time | end_time |
    | 1 | 12:00:00 | 12:08:00 |
    | 3 | 12:20:00 | 12:22:00 |
    So how can I do that?

  7. Hi, I don't want to implement the merge method.
    I want: initialize -> update -> update -> … -> evaluate

    What should I do? Thanks
