PySpark Design Patterns for Data Pipelines
The five most useful basic design patterns for PySpark data pipelines are Factory (swap data sources without changing pipeline code), Singleton, Builder, Observer, and Pipeline. Once those are familiar, three more patterns cover the complex cases that come up in production: switching algorithms at runtime (Strategy), adding cross-cutting concerns (Decorator), and sharing skeleton logic across pipeline variants (Template Method).
The Strategy pattern defines a family of algorithms, encapsulates each one, and makes them interchangeable. In data pipelines this is useful when the processing step varies by input characteristics or business requirements but the surrounding code should stay the same.
from abc import ABC, abstractmethod
from typing import Dict, Any
from pyspark.sql import DataFrame, SparkSession

class DataProcessingStrategy(ABC):
    """Abstract strategy for data processing."""

    @abstractmethod
    def process(self, data: DataFrame) -> DataFrame:
        pass

    @abstractmethod
    def get_processing_info(self) -> Dict[str, Any]:
        pass

class AggregationStrategy(DataProcessingStrategy):
    """Strategy for aggregation-based processing."""

    def __init__(self, group_by_cols: list, agg_cols: Dict[str, str]):
        self.group_by_cols = group_by_cols
        self.agg_cols = agg_cols

    def process(self, data: DataFrame) -> DataFrame:
        return data.groupBy(self.group_by_cols).agg(self.agg_cols)

    def get_processing_info(self) -> Dict[str, Any]:
        return {
            "strategy_type": "aggregation",
            "group_by_columns": self.group_by_cols,
            "aggregation_columns": self.agg_cols,
        }

class FilteringStrategy(DataProcessingStrategy):
    """Strategy for filtering-based processing."""

    def __init__(self, filter_condition: str):
        self.filter_condition = filter_condition

    def process(self, data: DataFrame) -> DataFrame:
        return data.filter(self.filter_condition)

    def get_processing_info(self) -> Dict[str, Any]:
        return {
            "strategy_type": "filtering",
            "filter_condition": self.filter_condition,
        }

class DataProcessor:
    """Context class that uses different processing strategies."""

    def __init__(self, strategy: DataProcessingStrategy):
        self.strategy = strategy

    def set_strategy(self, strategy: DataProcessingStrategy):
        """Change the processing strategy at runtime."""
        self.strategy = strategy

    def process_data(self, data: DataFrame) -> DataFrame:
        return self.strategy.process(data)

    def get_strategy_info(self) -> Dict[str, Any]:
        return self.strategy.get_processing_info()

# Usage example
spark = SparkSession.builder.appName("StrategyPattern").getOrCreate()

data = spark.createDataFrame(
    [(1, "A", 100), (1, "B", 200), (2, "A", 150), (2, "B", 250)],
    ["id", "category", "value"],
)

agg_strategy = AggregationStrategy(
    group_by_cols=["id"],
    agg_cols={"value": "sum", "category": "count"},
)
processor = DataProcessor(agg_strategy)
processor.process_data(data).show()

# Switch to filtering at runtime
processor.set_strategy(FilteringStrategy("value > 150"))
processor.process_data(data).show()
When to use it: when the same context needs to pick between several distinct algorithms based on runtime state or config. The caller holds the DataProcessor; the concrete strategy is a swappable detail.
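For instance, the choice can be driven by config. Here is a minimal sketch; the config keys and the strategy_from_config helper are illustrative assumptions, not a fixed API:

# Hypothetical config-driven strategy selection; the config shape is an assumption.
def strategy_from_config(config: Dict[str, Any]) -> DataProcessingStrategy:
    if config["strategy"] == "aggregation":
        return AggregationStrategy(config["group_by_cols"], config["agg_cols"])
    if config["strategy"] == "filtering":
        return FilteringStrategy(config["filter_condition"])
    raise ValueError(f"Unknown strategy: {config['strategy']}")

processor = DataProcessor(
    strategy_from_config({"strategy": "filtering", "filter_condition": "value > 150"})
)

The calling code never mentions a concrete strategy class, so adding a new algorithm is a new subclass plus one branch in the lookup.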
The Decorator pattern wraps an object to add behaviour without changing its class. In a data pipeline this is the cleanest way to layer on cross-cutting concerns like logging, validation, or timing without polluting each transformation class.
from abc import ABC, abstractmethod
from typing import Callable, Dict
import time
import logging

from pyspark.sql import DataFrame

class DataTransformation(ABC):
    """Abstract base class for data transformations."""

    @abstractmethod
    def transform(self, data: DataFrame) -> DataFrame:
        pass

class BaseTransformation(DataTransformation):
    """Trivial base transformation that just returns the data unchanged."""

    def __init__(self, name: str):
        self.name = name

    def transform(self, data: DataFrame) -> DataFrame:
        return data

class TransformationDecorator(DataTransformation):
    """Base decorator that delegates to the wrapped transformation."""

    def __init__(self, transformation: DataTransformation):
        self._transformation = transformation

    @property
    def name(self) -> str:
        return getattr(self._transformation, "name", "unknown")

    def transform(self, data: DataFrame) -> DataFrame:
        return self._transformation.transform(data)

class LoggingDecorator(TransformationDecorator):
    """Decorator that adds logging around the wrapped transformation."""

    def transform(self, data: DataFrame) -> DataFrame:
        start = time.time()
        logging.info(f"Starting transformation: {self.name}")
        result = self._transformation.transform(data)
        logging.info(f"Finished transformation: {self.name} ({time.time() - start:.2f}s)")
        return result

class ValidationDecorator(TransformationDecorator):
    """Decorator that runs column-level validation rules before transforming."""

    def __init__(self, transformation: DataTransformation, validation_rules: Dict[str, Callable]):
        super().__init__(transformation)
        self.validation_rules = validation_rules

    def transform(self, data: DataFrame) -> DataFrame:
        for column, validation_func in self.validation_rules.items():
            if column in data.columns:
                invalid_count = data.filter(~validation_func(data[column])).count()
                if invalid_count > 0:
                    logging.warning(f"Found {invalid_count} invalid values in column {column}")
        return self._transformation.transform(data)

# Usage
def is_positive(col):
    return col > 0

def is_not_null(col):
    return col.isNotNull()

base = BaseTransformation("data_cleaning")
validated = ValidationDecorator(base, {"value": is_positive, "id": is_not_null})
logged = LoggingDecorator(validated)

result = logged.transform(data)  # `data` from the Strategy example above
Each decorator wraps one concern (logging, validation). Stack them in whatever order you need. Adding a new concern (e.g. timing, retry, caching) is a new class plus one more wrapper call at the construction site; the existing transformations don’t change.
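As a sketch of that one-class cost, here is what a caching concern might look like. CacheDecorator is a hypothetical name, not part of the example above; it relies only on Spark's standard DataFrame.cache():

class CacheDecorator(TransformationDecorator):
    """Hypothetical decorator that caches the wrapped transformation's output."""

    def transform(self, data: DataFrame) -> DataFrame:
        result = self._transformation.transform(data)
        # cache() marks the result for reuse; Spark materialises it on the first action
        return result.cache()

# Stacks like any other decorator:
cached = CacheDecorator(LoggingDecorator(validated))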
When to use it: cross-cutting concerns that apply to many transformations but don’t belong inside any of them.
The Template Method pattern defines the skeleton of an algorithm in a base class and lets subclasses fill in specific steps. In a data pipeline this is the natural fit for “every job follows the same shape (validate → preprocess → apply logic → postprocess → validate output), but each job’s steps differ.”
from abc import ABC, abstractmethod
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
import logging

class DataPipelineTemplate(ABC):
    """Template class for data pipeline workflows."""

    def run_pipeline(self, input_data: DataFrame) -> DataFrame:
        """Template method that fixes the order of pipeline stages."""
        try:
            validated_data = self.validate_input(input_data)
            preprocessed_data = self.preprocess_data(validated_data)
            processed_data = self.apply_business_logic(preprocessed_data)
            postprocessed_data = self.postprocess_data(processed_data)
            final_data = self.validate_output(postprocessed_data)
            self.log_results(final_data)
            return final_data
        except Exception as e:
            self.handle_error(e)
            raise

    @abstractmethod
    def validate_input(self, data: DataFrame) -> DataFrame: ...

    @abstractmethod
    def preprocess_data(self, data: DataFrame) -> DataFrame: ...

    @abstractmethod
    def apply_business_logic(self, data: DataFrame) -> DataFrame: ...

    @abstractmethod
    def postprocess_data(self, data: DataFrame) -> DataFrame: ...

    @abstractmethod
    def validate_output(self, data: DataFrame) -> DataFrame: ...

    def log_results(self, data: DataFrame):
        logging.info(f"Pipeline completed. Output rows: {data.count()}")

    def handle_error(self, error: Exception):
        logging.error(f"Pipeline failed: {error}")

class SalesDataPipeline(DataPipelineTemplate):
    """Concrete implementation for sales data."""

    REQUIRED_COLUMNS = ["sale_id", "product_id", "amount", "date"]

    def validate_input(self, data: DataFrame) -> DataFrame:
        missing = [c for c in self.REQUIRED_COLUMNS if c not in data.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        return data

    def preprocess_data(self, data: DataFrame) -> DataFrame:
        return (
            data.dropDuplicates(["sale_id"])
            .withColumn("date", F.to_date("date"))
            .withColumn("year", F.year("date"))
            .withColumn("month", F.month("date"))
        )

    def apply_business_logic(self, data: DataFrame) -> DataFrame:
        daily = data.groupBy("date").agg(
            F.sum("amount").alias("daily_total"),
            F.count("*").alias("daily_transactions"),
        )
        return data.join(daily, "date", "left")

    def postprocess_data(self, data: DataFrame) -> DataFrame:
        return data.withColumn("amount", F.round("amount", 2))

    def validate_output(self, data: DataFrame) -> DataFrame:
        negative = data.filter(F.col("amount") < 0).count()
        if negative > 0:
            logging.warning(f"Found {negative} rows with negative amounts")
        return data

# Usage (sales_df: any DataFrame with the REQUIRED_COLUMNS above)
sales_pipeline = SalesDataPipeline()
result = sales_pipeline.run_pipeline(sales_df)
The base class fixes the stage order and the error-handling shape; every subclass only has to fill in what each stage means for its domain. Adding a second subclass (say LogDataPipeline) reuses the whole scaffold for free.
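A minimal sketch of what that second subclass could look like; the column names and per-stage steps here are illustrative assumptions, and the inherited scaffold is the point:

class LogDataPipeline(DataPipelineTemplate):
    """Illustrative second subclass; column names are assumptions."""

    def validate_input(self, data: DataFrame) -> DataFrame:
        if "timestamp" not in data.columns:
            raise ValueError("Missing required column: timestamp")
        return data

    def preprocess_data(self, data: DataFrame) -> DataFrame:
        return data.withColumn("timestamp", F.to_timestamp("timestamp"))

    def apply_business_logic(self, data: DataFrame) -> DataFrame:
        return data.withColumn("hour", F.hour("timestamp"))

    def postprocess_data(self, data: DataFrame) -> DataFrame:
        return data.dropDuplicates()

    def validate_output(self, data: DataFrame) -> DataFrame:
        return data

# run_pipeline, logging, and error handling are all inherited from the template:
# LogDataPipeline().run_pipeline(log_df)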
When to use it: any time you have several pipeline variants that share a common shape and differ only in the content of each stage.
Strategy, Decorator, and Template Method each solve a different problem: runtime algorithm swap, orthogonal concerns, and shared skeletons. Combined with the five basic patterns, they cover most of the structural shapes you’ll reach for in a production pipeline. Start simple; only reach for an advanced pattern when the basic one stops fitting.