PySpark Explained: The explode and collect_list Functions


PySpark SQL, the Python interface for SQL in Apache Spark, is a powerful set of tools for data transformation and analysis. Built to emulate the most common types of operations available in database SQL systems, PySpark SQL is also able to leverage the dataframe paradigm available in Spark to offer additional functionality.

In short, PySpark SQL provides a rich set of functions that enable developers to manipulate and process data efficiently.

Among these functions, two of the less well-known ones are particularly noteworthy for their ability to transform and aggregate data in unique ways: the explode and collect_list functions.

In this article, I'll explain exactly what each of these does and show some use cases and sample PySpark code for each.

Explode

The explode function in PySpark SQL is a versatile tool for transforming and flattening nested data structures, such as arrays or maps, into individual rows. This function is particularly useful when working with complex datasets that contain nested collections, as it allows you to analyze and manipulate individual elements within these structures.

Arrays in PySpark are very similar to arrays in other programming languages – namely, a data structure that holds a collection of elements, typically of the same type, in a specific order.

Maps in Spark are equivalent to dictionaries in languages such as Python. They hold a series of key-value pairs and are useful for very fast lookups of values by key. We'll see examples of using explode with both later on.

When applied to an array column, the explode function creates a new row for each element in the array, with the element value stored in a new column. By default, this new column is named col, but you can specify a custom column name using an alias.

Similarly, when applied to a map column, the explode function creates two new columns: one for the keys and another for the values. By default, these columns are named key and value, respectively, but, again, you can provide custom column names using aliases.
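
To make those defaults concrete, here's a minimal, self-contained sketch. The column names arr and m are invented purely for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode

spark = SparkSession.builder.appName("ExplodeDefaults").getOrCreate()

# One row with a hypothetical array column "arr" and map column "m"
df = spark.createDataFrame([(["a", "b"], {"k1": "v1", "k2": "v2"})], ["arr", "m"])

df.select(explode("arr")).printSchema()         # output column is named "col"
df.select(explode("arr").alias("item")).show()  # aliased to "item"
df.select(explode("m")).printSchema()           # output columns are "key" and "value"
df.select(explode("m").alias("k", "v")).show()  # aliased to "k" and "v"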

Collect_list

The collect_list function in PySpark SQL is an aggregation function that gathers values from a column and converts them into an array. It is particularly useful when you need to reconstruct or aggregate data that has been flattened or transformed using other PySpark SQL functions, such as explode. In many ways, it can be thought of as a complementary function to explode.

This function is often used in conjunction with the groupBy operator to group the data based on one or more columns before aggregating the values.
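
As a minimal sketch of that pattern (the data and column names here are invented for illustration; a fuller worked example follows later in the article):

from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list

spark = SparkSession.builder.appName("CollectListSketch").getOrCreate()

df = spark.createDataFrame(
    [("fruit", "apple"), ("fruit", "pear"), ("veg", "leek")],
    ["category", "item"]
)

# Group the rows by category, then gather each group's items back into an array
# (note that the order of elements within each array is not guaranteed)
df.groupBy("category").agg(collect_list("item").alias("items")).show(truncate=False)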

Accessing a FREE PySpark development environment

If you want to follow along with the code in this article, you'll need access to a PySpark development environment.

If you're lucky enough to have access to PySpark either through your work, via the cloud, or a local install, go ahead and use that. If not, please take a look at the link below where I go into detail about how you can access a great FREE online PySpark development environment called the Databricks Community Edition.

Databricks is a cloud-based platform for data engineering, machine learning, and analytics built around Apache Spark, and it provides a unified environment for working with big data workloads. The founders of Databricks created Spark, so they know their stuff.

How to access a FREE online Spark development environment

Example use cases

Now that we know a bit more about what explode and collect_list do, let's consider some use cases for them.

The explode function

We'll start by using the explode function to transform an array. If you recall, in Spark an array is a data structure that stores an ordered collection of elements of the same type.

We'll set up a PySpark dataframe that holds people's names in one text column and their favourite fruits in an array column.

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode

# Initialize Spark session
spark = SparkSession.builder.appName("ArrayExplodeExample").getOrCreate()

# Create a DataFrame with an array column - "fruits"
data = [
    ("John", ["apple", "banana", "cherry"]),
    ("Mary", ["orange", "grape"]),
    ("Jane", ["strawberry", "blueberry", "raspberry"]),
    ("Mark", ["watermelon"])
]

# Define schema and create DataFrame
df = spark.createDataFrame(data, ["name", "fruits"])

# Show the original DataFrame
df.show(truncate=False)

+----+----------------------------------+
|name|fruits                            |
+----+----------------------------------+
|John|[apple, banana, cherry]           |
|Mary|[orange, grape]                   |
|Jane|[strawberry, blueberry, raspberry]|
|Mark|[watermelon]                      |
+----+----------------------------------+

For many cases where we want to analyze this data, it makes things much easier if each different combination of "name" and "fruit" is on a separate record. We can use the explode function to achieve that.

Python"># Use explode function to flatten the array column
exploded_df = df.withColumn("fruit", explode(df.fruits))

# Show the exploded DataFrame
exploded_df["name","fruit"].show(truncate=False)

+----+----------+
|name|fruit     |
+----+----------+
|John|apple     |
|John|banana    |
|John|cherry    |
|Mary|orange    |
|Mary|grape     |
|Jane|strawberry|
|Jane|blueberry |
|Jane|raspberry |
|Mark|watermelon|
+----+----------+

The data now looks more like a regular data table and is better organised for any further dataframe or SQL operations we want to perform on it.

Using explode to deal with maps in PySpark is very similar.

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col

# Initialize Spark session
spark = SparkSession.builder.appName("ExplodeExample").getOrCreate()

# Sample data
data = [
    ("Tom", {"Salary": "£5000", "Bonus": "£0"}),
    ("Dick", {"Salary": "£2690", "Bonus": None}),
    ("Harry", {"Salary": "£45000", "Bonus": "£20000"})
]

# Create DataFrame
df = spark.createDataFrame(data, ["Name", "Remuneration"])

# Show original DataFrame
df.show(truncate=False)

+-------+-----------------------------------+
|Name   |Remuneration                       |
+-------+-----------------------------------+
|Tom    |{Salary -> £5000, Bonus -> £0}     |
|Dick   |{Salary -> £2690, Bonus -> null}   |
|Harry  |{Salary -> £45000, Bonus -> £20000}|
+-------+-----------------------------------+

Applying explode, this time we get our key-value pairings split into individual records. As with our last example, this results in a much better organisation for further analysis.

remuneration_exploded = df.select(
    col("Name"),
    explode(col("Remuneration")).alias("key", "value")
)

# Show the transformed DataFrame
remuneration_exploded.show(truncate=False)

+-------+------+-------+
|Name   |key   |value  |
+-------+------+-------+
|Tom    |Salary|£5000  |
|Tom    |Bonus |£0     |
|Dick   |Salary|£2690  |
|Dick   |Bonus |null   |
|Harry  |Salary|£45000 |
|Harry  |Bonus |£20000 |
+-------+------+-------+

Finishing things off with a slightly more complicated example, assume we have the following PySpark dataframe.

+----+-----------+-----------+
|col1|     col2  |     col3  |
+----+-----------+-----------+
| a  | [1, 2, 3] | [4, 5, 6] |
+----+-----------+-----------+

And we would like to obtain the following output.

+------+------+------+
|col1  |col2  |col3  |
+------+------+------+
|   a  |   1  |   4  |
|   a  |   2  |   5  |
|   a  |   3  |   6  |
+------+------+------+

This is trickier than it looks. First, let's create our input test data.

testData = [('a',[1,2,3],[4,5,6]),]

df = spark.createDataFrame(data=testData, schema = ['col1','col2','col3'])

At first glance, you might think that you could just explode col2 and col3, but that won't work as you can only explode one column at a time. Let's try it and you'll see what I mean.

df.select ("col1",explode("col2").alias("col2"),"col3").select("col1","col2",explode("col3").alias("col3")).show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   a|   1|   4|
|   a|   1|   5|
|   a|   1|   6|
|   a|   2|   4|
|   a|   2|   5|
|   a|   2|   6|
|   a|   3|   4|
|   a|   3|   5|
|   a|   3|   6|
+----+----+----+

Hmmm, not quite what we're looking for. To do what we want, we need an intermediate step that, for PySpark arrays, does the equivalent of a Python zip operation. Recall that the Python zip function takes two iterables and pairs up their elements positionally. For example, if we had,

numbers = [1, 2, 3]

letters = ['a', 'b', 'c']

zipped=zip(numbers,letters)

print(list(zipped))

[(1, 'a'), (2, 'b'), (3, 'c')]

The equivalent function for arrays is handily called arrays_zip. So we have to use that first to "knit" together our arrays and then do our explode. It can be implemented using PySpark SQL or dataframe operations. Here's the solution in SQL.

from pyspark.sql.functions import *

# Create a temporary view of our input data
df.createOrReplaceTempView("test_table")

spark.sql("""
    select col1, tmp.col2, tmp.col3
    from (
        select col1, explode(tmp) as tmp
        from (
            select col1, arrays_zip(col2, col3) as tmp
            from test_table
        )
    )
""").show()

+------+------+------+
|col1  |col2  |col3  |
+------+------+------+
|   a  |   1  |   4  |
|   a  |   2  |   5  |
|   a  |   3  |   6  |
+------+------+------+
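
For completeness, the same result can also be produced with dataframe operations rather than SQL. This is just one possible sketch of that approach:

from pyspark.sql.functions import arrays_zip, explode, col

# Zip the two arrays element-wise, explode the zipped array of structs,
# then unpack the struct fields back into ordinary columns
df.select("col1", explode(arrays_zip("col2", "col3")).alias("tmp")) \
  .select("col1", col("tmp.col2").alias("col2"), col("tmp.col3").alias("col3")) \
  .show()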

The collect_list function

The collect_list function takes values stored record by record in a dataframe column and returns them as a single collection. In that sense, it does the opposite of what the explode function does. A quick example will show what I mean. Suppose we have this input data set,

testData = [['a'], ['b'], ['c']]

df = spark.createDataFrame(data=testData, schema=['letter_column'])

df.printSchema()

df.show()

+-------------+ 
|letter_column| 
+-------------+ 
|            a| 
|            b| 
|            c| 
+-------------+ 

Applying collect_list to the data we get,

from pyspark.sql.functions import collect_list

df.select(collect_list("letter_column").alias("letter_row")).show()

+----------+  
|letter_row|  
+----------+  
| [a, b, c]| 
+----------+ 

Normally we're not dealing with just one column of data, so for a more complicated problem, consider the following PySpark dataframe showing the wholesale price of Gas and Electricity over three days.

+-----------+----------+------+ 
|Fuel       |      Date| Price| 
+-----------+----------+------+ 
|Gas        |2019-10-11|121.56| 
|Gas        |2019-10-10|120.56| 
|Electricity|2019-10-11|100.00| 
|Gas        |2019-10-12|119.56| 
|Electricity|2019-10-10| 99.00| 
|Electricity|2019-10-12|101.00| 
+-----------+----------+------+

We want to return a new data set in the following format. The important point to note is that the prices for each fuel, reading from left to right, must be in date order.


+-----------+------------------------+
|Fuel       |Price_hist              |
+-----------+------------------------+
|Electricity|[99.0, 100.0, 101.0]    |
|Gas        |[120.56, 121.56, 119.56]|
+-----------+------------------------+

We'll start by writing some code to create our input data set.

data = [
    ("Gas", "2019-10-11", 121.56),
    ("Gas", "2019-10-10", 120.56),
    ("Electricity", "2019-10-11", 100.00),
    ("Gas", "2019-10-12", 119.56),
    ("Electricity", "2019-10-10", 99.00),
    ("Electricity", "2019-10-12", 101.00)
]

# Create DataFrame
df = spark.createDataFrame(data, ["Fuel", "Date", "Price"])

# Show DataFrame
df.show()

Now, running our code,

from pyspark.sql.functions import collect_list

df.select("Fuel",collect_list("Price").alias("Price Hist")).show(truncate=False)

Returns the error,

...
...
AnalysisException: [MISSING_GROUP_BY] The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.;
Aggregate [Fuel#2, collect_list(Price#4, 0, 0) AS Price Hist#22]
+- LogicalRDD [Fuel#2, Date#3, Price#4], false

That's not good, and clearly we are going to have to perform some kind of grouping on the Fuel column. As luck would have it, collect_list is an aggregate function, so we can use it with the groupBy and agg operations, along with a pre-sort on the dataframe, to get our desired result. Running this,

from pyspark.sql.functions import collect_list

# Sort by Date to ensure prices are in date order
sorted_df = df.sort("Fuel", "Date")

# Group by Fuel and collect prices into a list
result_df = sorted_df.groupBy("Fuel").agg(collect_list("Price").alias("Price_hist"))

# Show the result DataFrame
result_df.show(truncate=False)

Gives us our required output.

+-----------+------------------------+
|Fuel       |Price_hist              |
+-----------+------------------------+
|Electricity|[99.0, 100.0, 101.0]    |
|Gas        |[120.56, 121.56, 119.56]|
+-----------+------------------------+
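
One caveat worth noting: Spark doesn't formally guarantee that the row order produced by a pre-sort survives the shuffle that groupBy triggers, so on larger, heavily partitioned datasets the collected prices may not come back in date order. A more defensive pattern, sketched below rather than taken from the solution above, is to collect (Date, Price) structs and sort the array itself:

from pyspark.sql.functions import collect_list, sort_array, struct, col

# Collect (Date, Price) structs; sort_array orders structs by their first field (Date),
# so the result is date-ordered regardless of how rows were shuffled before aggregation
robust_df = (
    df.groupBy("Fuel")
      .agg(sort_array(collect_list(struct("Date", "Price"))).alias("hist"))
      .select("Fuel", col("hist.Price").alias("Price_hist"))
)

robust_df.show(truncate=False)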

Summary

In this article, I've introduced two of PySpark SQL's more unusual data manipulation functions and given you some use cases where they can be invaluable.

Use the explode function if you need to transform array or map columns in a dataframe into their constituent parts, placing each element in a separate record.

The collect_list function can be thought of as the inverse of the explode function. Use this to aggregate items from individual dataframe records into collections.

_OK, that's all for me just now. I hope you found this article useful. If you did, please check out my profile page at this link. From there, you can see my other published stories and subscribe to get notified when I post new content._

I know times are tough and wallets constrained, but if you got real value from this article, please consider buying me a wee dram.

If you liked this content, I think you'll find these articles interesting too.

SQL Explained: Common Table Expressions

Python on Steroids: The Numba Boost

Tags: Data Engineering Data Science Pyspark Python Sql
