
Combining rows into an array in pyspark

Yeah, I know how to explode in Spark, but what is the opposite, and how do I do it? Hint: collect_list.


Overview

I’ve just spent a bit of time working out how to group a Spark DataFrame by a given column and then aggregate the grouped rows into a single ArrayType column.

Given the input:

transaction_id  item
1               a
1               b
1               c
1               d
2               a
2               d
3               c
4               b
4               c
4               d

I want to turn that into the following:

transaction_id  items
1               [a, b, c, d]
2               [a, d]
3               [c]
4               [b, c, d]

To achieve this, I can use the following query:

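from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list

# Reuse the active session (notebooks usually already provide `spark`).
spark = SparkSession.builder.getOrCreate()
# Assumes a table or view named transaction_data is already registered.
df = spark.sql('select transaction_id, item from transaction_data')

# One row per transaction_id, with that transaction's items gathered into an array column.
grouped_transactions = df.groupBy('transaction_id').agg(collect_list('item').alias('items'))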
This post is licensed under CC BY 4.0 by the author.