PySpark: Top N Records In Each Group

Previously I blogged about extracting top N records from each group using Hive. This post shows how to do the same in PySpark.

Compared to the earlier Hive version, this is much more efficient because it uses combiners (so that part of the computation happens on the map side) and stores only N records per key at any given time, on both the mapper and the reducer side.
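As a rough illustration of why keeping only N records bounds memory, the following pure-Python sketch (no Spark required) streams records through a min-heap that never holds more than n entries. Note that it swaps in heapq.heappushpop rather than the append-then-heapq.nlargest approach that takeOrderedByKey uses, and top_n_stream is a hypothetical helper name, not part of PySpark.

```python
import heapq


def top_n_stream(records, n, key):
    """Return (the n records with the largest key, largest first; peak heap size).

    Hypothetical helper for illustration only: it keeps a min-heap of at
    most n (key, position, record) entries and evicts the smallest entry
    whenever a better record arrives.
    """
    heap = []
    peak_size = 0
    for position, record in enumerate(records):
        # position breaks ties so records themselves are never compared
        entry = (key(record), position, record)
        if len(heap) < n:
            heapq.heappush(heap, entry)
        else:
            # Push and pop in one step, so the heap never exceeds n entries.
            heapq.heappushpop(heap, entry)
        peak_size = max(peak_size, len(heap))
    ranked = [record for _, _, record in sorted(heap, reverse=True)]
    return ranked, peak_size


grades = [('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13),
          ('ClassA', 'Student3', 3.87)]
top2, peak = top_n_stream(grades, 2, key=lambda r: r[2])
print(top2)  # [('ClassA', 'Student1', 3.89), ('ClassA', 'Student3', 3.87)]
print(peak)  # 2
```

However many records stream through, the heap holds at most n of them, which is the same property the combiners rely on.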

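import heapq

def takeOrderedByKey(self, num, sortValue=None, reverse=False):
    """Return an RDD of (key, [top num values]) pairs, ordered by sortValue."""

    def init(a):
        # Start a new per-key accumulator with the first value seen.
        return [a]

    def combine(agg, a):
        # Map side: add one value, then trim the list back to num records.
        agg.append(a)
        return getTopN(agg)

    def merge(a, b):
        # Reduce side: join two partial top-N lists and trim again.
        return getTopN(a + b)

    def getTopN(agg):
        # heapq keeps only the num largest (or smallest) records by sortValue.
        if reverse:
            return heapq.nlargest(num, agg, key=sortValue)
        return heapq.nsmallest(num, agg, key=sortValue)

    return self.combineByKey(init, combine, merge)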


# Create a fake student dataset. The objective is to identify the top 2
# students in each class based on GPA scores.
data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
    ('ClassB', 'Student1', 2.89), ('ClassB', 'Student2', 3.13), ('ClassB', 'Student3', 3.97)
]

# Add takeOrderedByKey function to RDD class 
from pyspark.rdd import RDD
RDD.takeOrderedByKey = takeOrderedByKey

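# Load the dataset as a pair RDD keyed by class name
# (sc is an existing SparkContext, e.g. the one the pyspark shell creates)
rdd1 = sc.parallelize(data).map(lambda x: (x[0], x))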

# Extract the top 2 records in each class, ordered by GPA in descending order
top2 = rdd1.takeOrderedByKey(2, sortValue=lambda x: x[2], reverse=True)
for i in top2.flatMap(lambda x: x[1]).collect():
    print(i)

Output of the above program is:

('ClassB', 'Student3', 3.97)
('ClassB', 'Student2', 3.13)
('ClassA', 'Student1', 3.89)
('ClassA', 'Student3', 3.87)

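The key line to understand is the call to self.combineByKey(init, combine, merge). The combineByKey operator groups the dataset by key: init starts a list for the first value of a key in a partition, combine adds further values on the map side, and merge joins the partial lists coming from different partitions. After each step, getTopN uses the heap-based heapq.nlargest (or heapq.nsmallest) to keep only the top num records by GPA score. You can find a good explanation of the combineByKey operator on Adam Shinn’s blog.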
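The following is a minimal local model, in plain Python, of how combineByKey applies the three functions: within each partition, the first value of a key goes through init and later ones through combine, and the per-partition lists are then joined with merge. It is a sketch under stated assumptions, not Spark's actual implementation; make_top_n_functions, simulate_combine_by_key, and the two-partition split are hypothetical names and choices made for illustration.

```python
import heapq


def make_top_n_functions(num, sort_value, reverse=False):
    """Return an (init, combine, merge) triple mirroring takeOrderedByKey."""
    pick = heapq.nlargest if reverse else heapq.nsmallest

    def init(a):
        return [a]

    def combine(agg, a):
        agg.append(a)
        return pick(num, agg, key=sort_value)

    def merge(a, b):
        return pick(num, a + b, key=sort_value)

    return init, combine, merge


def simulate_combine_by_key(partitions, init, combine, merge):
    """Hypothetical local model of combineByKey: combine within each
    partition (the map side), then merge the per-partition results."""
    per_partition = []
    for part in partitions:
        acc = {}
        for key, value in part:
            if key in acc:
                acc[key] = combine(acc[key], value)
            else:
                acc[key] = init(value)
        per_partition.append(acc)

    result = {}
    for acc in per_partition:
        for key, agg in acc.items():
            result[key] = merge(result[key], agg) if key in result else agg
    return result


data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
    ('ClassB', 'Student1', 2.89), ('ClassB', 'Student2', 3.13), ('ClassB', 'Student3', 3.97),
]
pairs = [(x[0], x) for x in data]
# Hypothetical split that puts records of both classes in both partitions.
partitions = [pairs[0::2], pairs[1::2]]
init, combine, merge = make_top_n_functions(2, lambda x: x[2], reverse=True)
top = simulate_combine_by_key(partitions, init, combine, merge)
print(top['ClassA'])  # [('ClassA', 'Student1', 3.89), ('ClassA', 'Student3', 3.87)]
print(top['ClassB'])  # [('ClassB', 'Student3', 3.97), ('ClassB', 'Student2', 3.13)]
```

Because both classes appear in both partitions, merge actually runs for each key, and every stored per-key list stays at two records or fewer.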

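Finally, note that in the call rdd1.takeOrderedByKey(2, sortValue=lambda x: x[2], reverse=True), the x in the lambda refers to the value of each element of the pair RDD rdd1. That value is the full (class, student, GPA) tuple produced by map(lambda x: (x[0], x)), so x[2] is the GPA.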
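To make the shapes concrete, here is a plain-Python sketch that mimics the pair RDD with a list of tuples (an assumption standing in for rdd1), checks what the sort function receives, and reproduces what flatMap(lambda x: x[1]) does to the (key, [records]) pairs. The per-key lists are copied by hand from the output above rather than computed.

```python
from itertools import chain

data = [
    ('ClassA', 'Student1', 3.89), ('ClassA', 'Student2', 3.13), ('ClassA', 'Student3', 3.87),
    ('ClassB', 'Student1', 2.89), ('ClassB', 'Student2', 3.13), ('ClassB', 'Student3', 3.97),
]

# Local stand-in for sc.parallelize(data).map(lambda x: (x[0], x)):
# each element is (class name, full original tuple).
pairs = [(x[0], x) for x in data]
key, value = pairs[0]


def sort_value(x):
    # Same as the lambda x: x[2] passed to takeOrderedByKey; it receives
    # the value (the full tuple), never the key.
    return x[2]


# (key, [records]) pairs as in the output above, written out by hand.
top2 = [
    ('ClassB', [('ClassB', 'Student3', 3.97), ('ClassB', 'Student2', 3.13)]),
    ('ClassA', [('ClassA', 'Student1', 3.89), ('ClassA', 'Student3', 3.87)]),
]
# flatMap(lambda x: x[1]) concatenates the per-key lists into one sequence.
flat = list(chain.from_iterable(x[1] for x in top2))
for record in flat:
    print(record)
```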

Posted in Programming