Overview
I’ve just spent some time working out how to group a Spark DataFrame by a given column and then aggregate the rows into a single ArrayType column.
Given the input:
| transaction_id | item |
|---|---|
| 1 | a |
| 1 | b |
| 1 | c |
| 1 | d |
| 2 | a |
| 2 | d |
| 3 | c |
| 4 | b |
| 4 | c |
| 4 | d |
I want to turn that into the following:
| transaction_id | items |
|---|---|
| 1 | [a, b, c, d] |
| 2 | [a, d] |
| 3 | [c] |
| 4 | [b, c, d] |
To achieve this, I can use the following query:
```python
from pyspark.sql.functions import collect_list

# Read the raw transaction rows, then collapse each transaction's
# items into a single ArrayType column.
df = spark.sql('select transaction_id, item from transaction_data')
grouped_transactions = df.groupBy('transaction_id').agg(
    collect_list('item').alias('items')
)
```
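For readers without a Spark session to hand, the same group-and-collect semantics can be sketched in plain Python. The rows below mirror the example table; this is an illustration of what `collect_list` produces, not Spark's implementation:

```python
from collections import defaultdict

# Sample rows mirroring the transaction_data table above.
rows = [
    (1, 'a'), (1, 'b'), (1, 'c'), (1, 'd'),
    (2, 'a'), (2, 'd'),
    (3, 'c'),
    (4, 'b'), (4, 'c'), (4, 'd'),
]

# Group by transaction_id, appending each item to a list -- the
# plain-Python analogue of groupBy(...).agg(collect_list(...)).
grouped = defaultdict(list)
for transaction_id, item in rows:
    grouped[transaction_id].append(item)

print(dict(grouped))
# {1: ['a', 'b', 'c', 'd'], 2: ['a', 'd'], 3: ['c'], 4: ['b', 'c', 'd']}
```

One caveat the sketch hides: in a distributed Spark job, `collect_list` makes no guarantee about the order of elements within each array. If you also want duplicates removed, swap in `collect_set`.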