PySpark Performance Tuning: Spill, Skew, Shuffle, Storage, and Serialization
By Prabeesh Keezhathra
Performance tuning decides whether a Spark job runs in 10 minutes or 10 hours. Most slowdowns you’ll hit in production come from the same five areas: spill, skew, shuffle, storage, and serialization. This guide walks through each one with the cause, how to spot it in the Spark UI, and the PySpark code to fix it.
The examples use PySpark, but the concepts apply to Scala and Java Spark equally well.
You’ll get the most out of this guide if you’re already comfortable with basic PySpark DataFrame operations and have a working Spark installation. If you’re not set up yet, the Apache Spark installation guide for Ubuntu and macOS covers the setup.
| Problem | Symptom in Spark UI | First thing to try |
|---|---|---|
| Spill | “Spill (memory)” / “Spill (disk)” columns in stage metrics | Raise executor memory, use salted joins, enable AQE |
| Skew | One task takes 10× longer than the median | Salt the join key, enable AQE skew join, broadcast the small side |
| Shuffle | Wide transformations dominate the stage timeline | Repartition on join keys, broadcast small tables, prefer narrow ops |
| Storage | Lots of tiny files, slow reads | Compact with coalesce, specify schema, pick Parquet over CSV |
| Serialization | Python UDFs are the slowest stage | Replace UDFs with SQL functions or Pandas UDFs |
Spill happens when Spark can’t fit an operation’s working set in memory and starts writing temp files to disk. Disk I/O is orders of magnitude slower than memory, so every GB spilled is a proportional hit to job runtime. You’ll see it in the Spark UI under the Spill (memory) and Spill (disk) columns of a stage’s task list.
The first fix is simply more memory. Note that memory sizes are static configs: they take effect when the session starts, not via spark.conf.set on a running session.

```python
from pyspark.sql import SparkSession

spark = (SparkSession.builder
    .config("spark.executor.memory", "16g")
    .config("spark.driver.memory", "8g")
    # Reserve more of the unified region for execution + storage
    .config("spark.memory.fraction", "0.8")
    .getOrCreate())
```
A single hot key forces millions of rows through one task. Salting spreads that work across many tasks, each small enough to fit in memory: give the skewed side a random salt, and replicate the other side across the full salt range so every row still finds its match.
```python
from pyspark.sql import functions as F

SALT_BUCKETS = 10

# Salt the skewed side with a random bucket id (0-9)
df1 = df1.withColumn("salt", (F.rand() * SALT_BUCKETS).cast("int"))

# Replicate the other side across every bucket so each salted row matches
df2 = df2.withColumn("salt", F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)])))

result = df1.join(df2, on=["key", "salt"], how="inner").drop("salt")
```
Adaptive Query Execution (AQE) is the simplest win for spill. It re-plans stages at runtime using real partition statistics.
```python
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
```
Skew is an uneven distribution of data across partitions. The job is only as fast as its slowest task, so a single skewed partition can dominate total runtime. Spot it by sorting a stage’s task list in the Spark UI by duration: a healthy stage has a narrow spread; a skewed one has a long tail.
A small amount of skew (tasks within roughly 20% of the median) is normal and not worth chasing. Beyond that, fix it.
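Before reaching for the UI, a quick per-key count can confirm the suspicion. A minimal sketch, assuming df is one side of the join and key is the join column:

```python
from pyspark.sql import functions as F

# Rows per join key, largest first: one key dwarfing the rest means skew
df.groupBy("key").count().orderBy(F.desc("count")).show(10)
```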
```python
# Rebalance on a key column before joining
df_rebalanced = df.repartition(10, "key")

# Or persist with bucketing for repeated reads
df.write.bucketBy(10, "key").sortBy("key").saveAsTable("my_bucketed_table")
```
With spark.sql.adaptive.skewJoin.enabled, Spark detects skewed partitions at runtime and splits them into smaller tasks automatically. This is the default in Spark 3.2+.
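If the automatic detection doesn’t fire on your data, its thresholds are tunable. The values below are Spark’s documented defaults, shown for reference:

```python
# A partition is treated as skewed only when it exceeds BOTH limits:
# skewedPartitionFactor times the median partition size, and the byte threshold
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
```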
Shuffles move data across the network between executors. They’re the most expensive operations in Spark, so reducing shuffle is usually the biggest performance lever.
| Type | Examples | Shuffle? | Partition scope |
|---|---|---|---|
| Narrow | map, filter, select, withColumn, union | No | Within a single partition |
| Wide | join, groupBy, orderBy, distinct, repartition | Yes | Across partitions |
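You can see exactly where shuffles happen by printing the physical plan; wide operations show up as Exchange nodes:

```python
# Each Exchange node in the physical plan marks a shuffle boundary
df.groupBy("key").count().explain()
# Look for "Exchange hashpartitioning(...)" lines in the output
```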
Rewriting a wide operation into narrow operations, or pushing filters above joins, directly cuts shuffle volume.
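For example, applying a filter before a join means only the surviving rows get shuffled. A sketch with hypothetical orders and customers DataFrames:

```python
from pyspark.sql import functions as F

# Shuffles every order, then discards most of them
slow = orders.join(customers, "customer_id").filter(F.col("status") == "active")

# Shuffles only the active orders
fast = orders.filter(F.col("status") == "active").join(customers, "customer_id")
```

Catalyst usually pushes simple filters below the join on its own, but it can’t when the predicate involves a UDF or a column produced by the join, so writing the cheap order explicitly is a good habit.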
```python
# Repartitioning both sides on the join key means data is already co-located
df1 = df1.repartition(10, "key")
df2 = df2.repartition(10, "key")
joined = df1.join(df2, "key")
```
If one side of a join fits in memory on every executor (roughly < 10 MB by default, tunable via spark.sql.autoBroadcastJoinThreshold), broadcast it and skip the shuffle entirely.
```python
from pyspark.sql.functions import broadcast

joined = large_df.join(broadcast(small_df), "key")
```
How data is laid out on disk affects both read performance and shuffle behavior. The most common problems are tiny files and inferred schemas.
When a Spark job writes one file per partition, you can end up with thousands of small files. Each file has per-open overhead, so reads get slow. Aim for part-files between 128 MB and 1 GB.
```python
# Read many small files, write them back as one larger file
# (use a larger coalesce target for bigger datasets)
df.coalesce(1).write.mode("overwrite").parquet("output_path")
```
Use coalesce when reducing partition count (it avoids a shuffle); use repartition when you need to balance size evenly and are OK with a shuffle.
| | coalesce(n) | repartition(n) |
|---|---|---|
| Shuffles data? | No | Yes |
| Can increase partitions? | No (only decrease) | Yes |
| Partition size balance | Uneven | Even |
| Typical cost | Cheap | Expensive |
| Use when | You want fewer, larger output files | You need evenly-sized partitions for a downstream join/group |
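To check where you stand before picking one, df.rdd.getNumPartitions() reports the current partition count. A sketch, where the target of 16 is illustrative; size it so part-files land in the 128 MB to 1 GB range:

```python
print(df.rdd.getNumPartitions())         # e.g. thousands of tiny partitions
compacted = df.coalesce(16)              # no shuffle, just merges partitions
print(compacted.rdd.getNumPartitions())  # 16
```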
Schema inference scans the data, which is slow and flaky for large inputs. Declare the schema up front:
```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
])

df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .schema(schema)
    .load("data.csv")
)
```
Parquet is columnar, compressed, and stores the schema in the file. For anything you’ll read more than once, the conversion pays for itself quickly.
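The conversion itself is a single read-write pass. A sketch reusing the schema from the example above, with placeholder paths:

```python
# One-time conversion: CSV in, Parquet out
csv_df = spark.read.schema(schema).option("header", "true").csv("data.csv")
csv_df.write.mode("overwrite").parquet("data_parquet")

# Later reads get the stored schema and column pruning for free
df = spark.read.parquet("data_parquet")
```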
Serialization is how Spark moves data and code across the cluster. The biggest lever here is avoiding Python UDFs.
| Option | Serialization cost | Relative speed | When to use |
|---|---|---|---|
| SQL / DataFrame functions | None | Fastest | First choice whenever the logic is expressible in Spark functions |
| SQL higher-order functions (transform, filter, aggregate) | None | Fast | Array / map column transformations |
| Pandas UDF (vectorized) | Batched via Arrow | Fast | Custom logic that must run in Python, on large batches |
| Python UDF (row-at-a-time) | Per row, JVM ↔ Python | Slow | Last resort for custom Python logic |
A Python UDF forces Spark to serialize each row out of the JVM, run it through Python, then serialize the result back. That round trip is orders of magnitude slower than a native function.
```python
# Plain Python on driver-local data: no Spark, no serialization involved
numbers = [1, 2, 3, 4, 5]
doubled = list(map(lambda x: x * 2, numbers))

# The Spark equivalent without a UDF: column arithmetic stays in the JVM
df.select((df["col"] * 2).alias("doubled_col"))
```
Pandas UDFs send a batch of rows via Arrow, so the serialization cost is amortized across the batch instead of paid per row.
```python
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(DoubleType())
def double(x: pd.Series) -> pd.Series:
    return x * 2

df = df.withColumn("doubled_col", double(df["col"]))
```
For operations on array or map columns, Spark’s built-in higher-order functions (transform, filter, aggregate, etc.) run entirely in the JVM and skip the UDF round trip:
```python
from pyspark.sql.functions import expr

# Double every element of an array column without a UDF
df.select(expr("transform(my_array, x -> x * 2) as doubled_array"))
```
Setting a job description isn’t a performance win on its own, but it makes the Spark UI dramatically easier to navigate while you tune:
```python
sc = spark.sparkContext
sc.setJobDescription("Processing data for analysis")
results = df.filter(df.age > 30).collect()
```
The five problems compound. A job that suffers from skew is also spilling; a shuffle-heavy job usually has a storage problem feeding it. In practice, start with the upstream causes: fix storage and shuffle first, then skew and spill, and finish with serialization.
How do I know if my job is spilling?
Look at the task table for a stage in the Spark UI. If the Spill (memory) or Spill (disk) columns are non-zero, you’re spilling. Any non-trivial spill is worth chasing.
What’s the difference between coalesce and repartition?
coalesce(n) reduces the number of partitions without a shuffle: fast, but it can leave you with uneven partitions. repartition(n) does a full shuffle to rebalance evenly. Use coalesce for shrinking, repartition when you need even partition sizes.
When should I use broadcast()?
When one side of a join is small enough to fit in memory on every executor: roughly under 10 MB by default, controlled by spark.sql.autoBroadcastJoinThreshold. Broadcasting skips the shuffle on the large side entirely.
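The threshold is set in bytes or with a size suffix. For example, raising it to 50 MB, a value chosen here purely for illustration:

```python
# Allow auto-broadcast of tables up to 50 MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))

# Or disable auto-broadcast entirely and rely on explicit broadcast() hints
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
```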
Is AQE enabled by default?
Yes, since Spark 3.2. On older versions you need spark.sql.adaptive.enabled=true and, for skew handling, spark.sql.adaptive.skewJoin.enabled=true.
Why are Python UDFs slow?
Each row is serialized out of the JVM, passed to a Python process, executed, then serialized back. That round trip dominates runtime. Pandas UDFs batch rows through Arrow, which amortizes the cost; SQL functions avoid Python entirely.