PySpark: Top N Records In Each Group

Previously I blogged about extracting the top N records from each group using Hive. This post shows how to do the same in PySpark.

Compared to the earlier Hive version, this approach is much more efficient: it uses combiners (so part of the work happens map-side) and keeps only N records per key at any given time, on both the mapper and the reducer side.

import heapq

def takeOrderedByKey(self, num, sortValue=None, reverse=False):

    # createCombiner: start a per-key list from the first value
    def init(a):
        return [a]

    # mergeValue: fold in one more value, then trim back to the top N
    # (this runs map-side, acting as the combiner)
    def combine(agg, a):
        agg.append(a)
        return getTopN(agg)

    # mergeCombiners: merge two partial lists, then trim to the top N
    def merge(a, b):
        return getTopN(a + b)

    # Keep only the num best records, ordered by sortValue
    def getTopN(agg):
        if reverse:
            return heapq.nlargest(num, agg, key=sortValue)
        return heapq.nsmallest(num, agg, key=sortValue)

    return self.combineByKey(init, combine, merge)
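
The heap calls are what keep each per-key list bounded. As a quick standalone illustration (independent of Spark, with made-up scores), heapq.nlargest and heapq.nsmallest return only the requested number of elements and accept a key function, just like sorted:

import heapq

scores = [('s1', 3.89), ('s2', 3.13), ('s3', 3.87), ('s4', 2.50)]

# Top 2 by score, highest first -- mirrors the reverse=True branch above
print(heapq.nlargest(2, scores, key=lambda x: x[1]))   # [('s1', 3.89), ('s3', 3.87)]

# Bottom 2 by score -- mirrors the reverse=False branch
print(heapq.nsmallest(2, scores, key=lambda x: x[1]))  # [('s4', 2.5), ('s2', 3.13)]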


# Create a small fake student dataset. The objective is to identify the
# top 2 students in each class based on GPA.
data = [
        ('ClassA','Student1', 3.89),('ClassA','Student2', 3.13),('ClassA', 'Student3',3.87),
        ('ClassB','Student1', 2.89),('ClassB','Student2', 3.13),('ClassB', 'Student3',3.97)
    ]

# Add the takeOrderedByKey function to the RDD class (monkey-patch)
from pyspark.rdd import RDD
RDD.takeOrderedByKey = takeOrderedByKey

# Build a pair RDD keyed by class name (assumes a SparkContext sc,
# e.g. from the pyspark shell)
rdd1 = sc.parallelize(data).map(lambda x: (x[0], x))

# Extract the top 2 records in each class, ordered by GPA descending
for i in rdd1.takeOrderedByKey(2, sortValue=lambda x: x[2], reverse=True).flatMap(lambda x: x[1]).collect():
    print(i)

The output of the above program is:

('ClassB', 'Student3', 3.97)
('ClassB', 'Student2', 3.13)
('ClassA', 'Student1', 3.89)
('ClassA', 'Student3', 3.87)

The key piece to understand is the combineByKey call at the end of takeOrderedByKey. We use the combineByKey operator to split the dataset by key, and a heap to order each key's records by GPA while trimming the list to the top N as records arrive. You can find a good explanation of the combineByKey operator on Adam Shinn’s blog.
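
For readers new to combineByKey, here is a minimal sketch of its three-function contract, computing a per-key average; the data and names here are illustrative, not part of the post's solution:

# combineByKey takes three functions, mirroring init/combine/merge above
pairs = sc.parallelize([('a', 1), ('a', 3), ('b', 5)])

avg = pairs.combineByKey(
    lambda v: (v, 1),                          # createCombiner: first value per key
    lambda acc, v: (acc[0] + v, acc[1] + 1),   # mergeValue: fold in one value (map side)
    lambda a, b: (a[0] + b[0], a[1] + b[1])    # mergeCombiners: merge partial results
).mapValues(lambda acc: float(acc[0]) / acc[1])

print(avg.collect())  # e.g. [('a', 2.0), ('b', 5.0)] (order may vary)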

Finally, note that the x in sortValue=lambda x: x[2] refers to the value of the pair RDD created by sc.parallelize(data).map(lambda x: (x[0], x)) above.
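
To make that concrete, each element of rdd1 pairs the class name with the original input tuple:

# The first element of rdd1 (deterministic here, since parallelize
# preserves input order):
print(rdd1.first())  # ('ClassA', ('ClassA', 'Student1', 3.89))

So inside sortValue=lambda x: x[2], x is the full tuple and x[2] is the GPA.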

3 thoughts on “PySpark: Top N Records In Each Group”

  1. I think this should work:

    def combine(agg, a):
        agg.append(a)
        return agg

    def merge(a, b):
        agg = a + b
        return getTopN(agg)
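
(A note on this suggestion: dropping the getTopN trim from combine does save repeated heap calls, but the map-side lists can then grow well past num records per key before mergeCombiners runs, giving up the bounded-memory property described above. The original version pays a little more CPU per record to keep at most num elements at all times.)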
