An Efficient Way To Compute Trimmed Mean

One of my friends recently faced an interesting problem: how do you efficiently compute a trimmed mean over billions of data points, removing the top and bottom 1%? In Hive/Presto you can easily compute percentiles in a streaming and distributed fashion using approx_percentile, but there is no built-in function to compute a trimmed mean.
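To make the problem concrete, here is the naive computation being approximated: sort everything, drop the tails, average the rest. The function name and percentile parameters below are my own, but they spell out why this is hard at scale: it needs the full dataset sorted in one place.

```python
# Exact trimmed mean for reference: sort, drop the bottom and top tails,
# and average what remains. Requires the whole dataset in memory, which
# is exactly what a streaming approach avoids.
def exact_trimmed_mean(values, lower_pct, upper_pct):
    s = sorted(values)
    n = len(s)
    lo = int(n * lower_pct / 100)   # index of the first kept value
    hi = int(n * upper_pct / 100)   # index one past the last kept value
    kept = s[lo:hi]
    return sum(kept) / len(kept)

print(exact_trimmed_mean(range(1, 101), 1, 99))  # drops 1 and 100 -> 50.5
```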

After going through the internals of how approx_percentile works, I came across t-digest. It's a very efficient way of representing the distribution of data in a streaming fashion and can be leveraged to compute many different kinds of statistics on the data. Below is an example of a Spark UDAF leveraging the Python t-digest implementation (you can find the complete notebook with an example over here).

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
from pyspark.sql.types import DoubleType
from tdigest import TDigest

def trimmed_mean(lb, ub):
  """Returns a grouped-agg pandas UDF computing the mean of the values
  between the lb-th and ub-th percentiles."""

  @pandas_udf(DoubleType(), functionType=PandasUDFType.GROUPED_AGG)
  def _wrapper(x):
    # Stream every value into a t-digest, then ask the digest for the
    # trimmed mean directly.
    t = TDigest()
    for cur in x:
      t.update(cur)
    return t.trimmed_mean(lb, ub)

  return _wrapper

data.agg(trimmed_mean(5, 95)("Feature").alias('mean')).show()

Not only can we compute the trimmed mean with a single scan of the data, we can also compute it in a distributed fashion. Each partition is summarized by a compact digest and only the digests travel over the network, which reduces the amount of data that has to stream from one machine to another. The code below demonstrates how to compute partial aggregations and then combine them together.

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
from pyspark.sql.types import DoubleType, StringType
from tdigest import TDigest
import json

@pandas_udf(StringType(), functionType=PandasUDFType.GROUPED_AGG)
def partial(x):
  # Summarize one group's values as a t-digest and serialize it to JSON
  # so it can be shipped to the final aggregation step.
  t = TDigest()
  t.batch_update(x)
  return json.dumps(t.to_dict())


def merge(lb, ub):
  @pandas_udf(DoubleType(), functionType=PandasUDFType.GROUPED_AGG)
  def _merge(x):
    # Deserialize each partial digest and add them together; t-digests
    # merge into a single digest without losing their accuracy bounds.
    t = TDigest()
    for p in x:
      t = t + TDigest().update_from_dict(json.loads(p))
    return t.trimmed_mean(lb, ub)

  return _merge


# Below I am using a groupby operation mainly to create multiple partitions.
# TODO: replace groupby with a repartition, generate a partial digest for
# each partition, and then combine the digests.
partials = data.groupby('Id').agg(partial("Feature").alias("partial"))
partials.agg(merge(5, 95)("partial").alias('TrimmedMean')).show()

You can find the link to the notebook with an example over here.
