Advanced PySpark Design Patterns: Real-World Implementation Examples
Building upon the fundamental performance tuning concepts covered in our previous blog post on Performance Tuning on Apache Spark, this bonus article explores advanced optimization techniques that can dramatically improve PySpark pipeline performance in production environments. While the previous post focused on essential concepts like spill prevention, skew handling, shuffle optimization, storage management, and serialization, this article delves into modern PySpark features, sophisticated optimization strategies, and production-ready implementations that go beyond basic tuning.
If you haven’t read our foundational performance tuning guide yet, we recommend starting there to understand the basics of Apache Spark optimization, including techniques for preventing spills, reducing data skew, minimizing shuffle operations, optimizing storage, and improving serialization efficiency.
Modern PySpark versions include powerful adaptive query execution capabilities that automatically optimize query plans based on runtime statistics. Understanding and leveraging these features is crucial for production performance and builds upon the manual optimization techniques discussed in our previous performance tuning guide.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast

class AdaptiveQueryOptimizer:
    """Advanced optimizer leveraging PySpark's Adaptive Query Execution."""

    def __init__(self, spark: SparkSession):
        self.spark = spark
        self._configure_aqe()

    def _configure_aqe(self):
        """Configure Adaptive Query Execution for optimal performance."""
        # Enable AQE and related optimizations
        self.spark.conf.set("spark.sql.adaptive.enabled", "true")
        self.spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
        self.spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
        self.spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")

        # Advanced AQE configurations
        self.spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128m")
        self.spark.conf.set("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold", "0")
        self.spark.conf.set("spark.sql.adaptive.forceOptimizeSkewedJoin", "true")
        self.spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
        self.spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")

        # Enable cost-based optimization
        self.spark.conf.set("spark.sql.cbo.enabled", "true")
        self.spark.conf.set("spark.sql.cbo.joinReorder.enabled", "true")
        self.spark.conf.set("spark.sql.cbo.joinReorder.dp.threshold", "12")

    def optimize_join_strategy(self, left_df, right_df, join_cols, join_type="inner"):
        """Intelligently choose the best join strategy based on data characteristics."""

        # Get table statistics for cost-based optimization
        left_stats = self._get_table_statistics(left_df)
        right_stats = self._get_table_statistics(right_df)

        # Determine the optimal join strategy
        if self._should_broadcast(left_stats, right_stats):
            return left_df.join(broadcast(right_df), join_cols, join_type)
        elif self._should_sort_merge_join(left_stats, right_stats):
            return self._optimize_sort_merge_join(left_df, right_df, join_cols, join_type)
        else:
            return left_df.join(right_df, join_cols, join_type)

    def _get_table_statistics(self, df):
        """Get table statistics for optimization decisions."""
        # Simplified implementation: count() triggers a full scan, so in
        # production you'd use Spark's statistics API (ANALYZE TABLE) instead.
        return {
            "size_bytes": df.count() * 100,  # Rough estimate: ~100 bytes per row
            "partition_count": df.rdd.getNumPartitions(),
            "skew_factor": self._calculate_skew_factor(df)
        }

    def _should_broadcast(self, left_stats, right_stats):
        """Determine if a broadcast join is optimal."""
        return right_stats["size_bytes"] < 10 * 1024 * 1024  # 10 MB threshold

    def _should_sort_merge_join(self, left_stats, right_stats):
        """Determine if a sort-merge join is optimal."""
        return (left_stats["size_bytes"] > 100 * 1024 * 1024 and
                right_stats["size_bytes"] > 100 * 1024 * 1024)

    def _optimize_sort_merge_join(self, left_df, right_df, join_cols, join_type):
        """Optimize a sort-merge join with range partitioning on the join keys."""
        # Range-partition both DataFrames on the join columns; the partition
        # count defaults to spark.sql.shuffle.partitions.
        repartitioned_left = left_df.repartitionByRange(*join_cols)
        repartitioned_right = right_df.repartitionByRange(*join_cols)

        return repartitioned_left.join(repartitioned_right, join_cols, join_type)

    def _calculate_skew_factor(self, df):
        """Calculate the data skew factor."""
        # Simplified skew calculation
        return 1.0  # Placeholder

# Usage example
spark = SparkSession.builder.appName("AdvancedOptimization").getOrCreate()
optimizer = AdaptiveQueryOptimizer(spark)

# Optimize complex join operations
large_table = spark.read.parquet("/path/to/large/table")
medium_table = spark.read.parquet("/path/to/medium/table")

optimized_join = optimizer.optimize_join_strategy(
    large_table,
    medium_table,
    ["user_id", "date"]
)
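The `_calculate_skew_factor` placeholder can be filled in with a simple ratio of the largest partition to the mean. One hedged sketch: gather per-partition row counts with `df.rdd.glom().map(len).collect()` (fine for illustration, though it scans the data) and score them with a pure-Python helper — the threshold interpretation mirrors AQE's `skewedPartitionFactor`:

```python
def skew_factor(partition_sizes):
    """Ratio of the largest partition to the mean partition size.

    1.0 means perfectly balanced; AQE's default skewedPartitionFactor
    treats roughly 5x as skewed.
    """
    sizes = [s for s in partition_sizes if s > 0]
    if not sizes:
        return 1.0
    return max(sizes) / (sum(sizes) / len(sizes))

# Inside _calculate_skew_factor one could then write (assumption: row
# counts are a reasonable proxy for partition bytes):
#   sizes = df.rdd.glom().map(len).collect()
#   return skew_factor(sizes)
```

This keeps the Spark-dependent part to a single, optional line and makes the scoring itself trivially testable.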
Modern PySpark can automatically optimize queries by pushing down predicates and pruning partitions at runtime. This advanced technique builds upon the storage optimization concepts we discussed in our performance tuning guide, taking them to the next level with automatic optimization.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year, month, dayofmonth
from datetime import datetime, timedelta

class PartitionOptimizer:
    """Optimizer for partition-based performance improvements."""

    def __init__(self, spark: SparkSession):
        self.spark = spark
        self._configure_partition_optimizations()

    def _configure_partition_optimizations(self):
        """Configure partition-related optimizations."""
        # Enable dynamic partition pruning
        self.spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
        self.spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.useStats", "true")
        self.spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.fallbackFilterRatio", "0.5")

        # Enable predicate pushdown
        self.spark.conf.set("spark.sql.parquet.filterPushdown", "true")
        self.spark.conf.set("spark.sql.parquet.mergeSchema", "false")
        self.spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")

    def create_time_partitioned_table(self, df, timestamp_col, base_path):
        """Create a time-partitioned table optimized for queries."""

        # Add time-based partition columns
        partitioned_df = df.withColumn("year", year(col(timestamp_col))) \
                           .withColumn("month", month(col(timestamp_col))) \
                           .withColumn("day", dayofmonth(col(timestamp_col)))

        # Write with optimal partitioning
        partitioned_df.write \
            .partitionBy("year", "month", "day") \
            .mode("overwrite") \
            .parquet(base_path)

        return partitioned_df

    def optimize_time_range_query(self, table_path, start_date, end_date,
                                  additional_filters=None):
        """Optimize queries with time range filters."""

        # Parse the date range for partition pruning
        start_dt = datetime.strptime(start_date, "%Y-%m-%d")
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")

        # Build partition filters so only matching directories are scanned
        partition_filters = []
        current_dt = start_dt
        while current_dt <= end_dt:
            partition_filters.append(
                f"(year = {current_dt.year} AND month = {current_dt.month} AND day = {current_dt.day})"
            )
            current_dt += timedelta(days=1)

        # Combine the partition filters
        partition_condition = " OR ".join(partition_filters)

        # Read with partition pruning
        df = self.spark.read.parquet(table_path)

        # Apply the partition filters
        if partition_filters:
            df = df.filter(partition_condition)

        # Apply additional filters
        if additional_filters:
            for filter_condition in additional_filters:
                df = df.filter(filter_condition)

        return df

    def optimize_multi_table_join_with_pruning(self, fact_table_path,
                                               dimension_table_paths,
                                               join_conditions,
                                               date_filter):
        """Optimize multi-table joins with partition pruning."""

        # Read the fact table with partition pruning
        fact_df = self.optimize_time_range_query(
            fact_table_path,
            date_filter["start_date"],
            date_filter["end_date"]
        )

        # Read the dimension tables
        dimension_dfs = {}
        for table_name, table_path in dimension_table_paths.items():
            dimension_dfs[table_name] = self.spark.read.parquet(table_path)

        # Perform the optimized joins
        result_df = fact_df
        for table_name, join_condition in join_conditions.items():
            result_df = result_df.join(
                dimension_dfs[table_name],
                join_condition,
                "left"
            )

        return result_df

# Usage example
partition_optimizer = PartitionOptimizer(spark)

# Create a time-partitioned table
sales_data = spark.read.csv("/path/to/sales/data", header=True, inferSchema=True)
partitioned_sales = partition_optimizer.create_time_partitioned_table(
    sales_data,
    "transaction_date",
    "/path/to/partitioned/sales"
)

# Optimize a time range query
filtered_sales = partition_optimizer.optimize_time_range_query(
    "/path/to/partitioned/sales",
    "2024-01-01",
    "2024-01-31",
    additional_filters=["amount > 1000", "region = 'North'"]
)
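For long ranges, the day-by-day OR chain grows to hundreds of clauses. Since consecutive days within a month form a contiguous run, they can be collapsed into one BETWEEN clause per (year, month); the resulting string drops into the same `df.filter(...)` call. A sketch, pure string-building with no Spark dependency:

```python
from datetime import datetime, timedelta
from itertools import groupby

def compact_partition_filter(start_date: str, end_date: str) -> str:
    """Collapse a date range into one BETWEEN clause per (year, month)."""
    start = datetime.strptime(start_date, "%Y-%m-%d").date()
    end = datetime.strptime(end_date, "%Y-%m-%d").date()

    # Enumerate the days in the range
    days = []
    current = start
    while current <= end:
        days.append(current)
        current += timedelta(days=1)

    # Days within one (year, month) are contiguous, so min/max day suffice
    clauses = []
    for (y, m), run in groupby(days, key=lambda d: (d.year, d.month)):
        run = list(run)
        clauses.append(
            f"(year = {y} AND month = {m} AND day BETWEEN {run[0].day} AND {run[-1].day})"
        )
    return " OR ".join(clauses)
```

A full month collapses to a single clause, e.g. `compact_partition_filter("2024-01-01", "2024-01-31")` yields `(year = 2024 AND month = 1 AND day BETWEEN 1 AND 31)`.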
While our previous performance tuning guide covered basic memory management concepts like preventing spills and using appropriate storage levels, this section explores sophisticated caching frameworks and intelligent memory management strategies for production environments.
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel
from typing import Dict, Any
import time

class IntelligentCacheManager:
    """Advanced cache management with memory monitoring and eviction strategies."""

    def __init__(self, spark: SparkSession, max_cache_size_gb: float = 10.0):
        self.spark = spark
        self.max_cache_size_bytes = max_cache_size_gb * 1024 * 1024 * 1024
        self.cache_registry: Dict[str, Dict[str, Any]] = {}
        self.access_patterns: Dict[str, list] = {}

    def cache_with_strategy(self, df, name: str,
                            access_frequency: str = "medium",
                            data_volatility: str = "low") -> Any:
        """Cache a DataFrame with intelligent strategy selection."""

        # Determine the optimal storage level based on characteristics
        storage_level = self._select_storage_level(access_frequency, data_volatility)

        # Check memory availability
        if not self._check_memory_availability(df, storage_level):
            self._evict_least_valuable_cache()

        # Cache the DataFrame
        cached_df = df.persist(storage_level)

        # Register the cache entry
        self.cache_registry[name] = {
            "dataframe": cached_df,
            "storage_level": storage_level,
            "access_frequency": access_frequency,
            "data_volatility": data_volatility,
            "cache_time": time.time(),
            "access_count": 0,
            "last_access": time.time(),
            "estimated_size": self._estimate_dataframe_size(df)
        }

        self.access_patterns[name] = []

        return cached_df

    def _select_storage_level(self, access_frequency: str, data_volatility: str) -> StorageLevel:
        """Select the optimal storage level based on access patterns."""
        # Note: PySpark always serializes cached data, so the Python API has
        # no _SER storage levels.
        if access_frequency == "high" and data_volatility == "low":
            return StorageLevel.MEMORY_ONLY
        elif access_frequency == "high" and data_volatility == "high":
            return StorageLevel.MEMORY_AND_DISK
        elif access_frequency == "medium":
            return StorageLevel.MEMORY_AND_DISK
        else:
            return StorageLevel.DISK_ONLY

    def _check_memory_availability(self, df, storage_level: StorageLevel) -> bool:
        """Check if sufficient memory is available for caching."""

        estimated_size = self._estimate_dataframe_size(df)
        current_cache_size = sum(
            entry["estimated_size"] for entry in self.cache_registry.values()
        )

        return (current_cache_size + estimated_size) <= self.max_cache_size_bytes

    def _estimate_dataframe_size(self, df) -> int:
        """Estimate the DataFrame size in bytes."""
        # Simplified estimation - in production, use more sophisticated methods
        total_rows = df.count()
        if total_rows == 0:
            return 0

        sample_size = min(1000, total_rows)
        sample_df = df.limit(sample_size)
        sample_bytes = sample_df.rdd.map(lambda row: len(str(row))).sum()

        return int((sample_bytes / sample_size) * total_rows)

    def _evict_least_valuable_cache(self):
        """Evict the least valuable cache entry based on access patterns."""

        if not self.cache_registry:
            return

        # Calculate cache value scores
        cache_scores = {}
        current_time = time.time()

        for name, entry in self.cache_registry.items():
            # Score based on access frequency, recency, and size
            time_factor = 1.0 / (current_time - entry["last_access"] + 1)
            access_factor = entry["access_count"]
            size_factor = 1.0 / (entry["estimated_size"] + 1)

            cache_scores[name] = access_factor * time_factor * size_factor

        # Find and evict the least valuable cache entry
        least_valuable = min(cache_scores.items(), key=lambda x: x[1])[0]
        self.uncache_dataframe(least_valuable)

    def get_cached_dataframe(self, name: str):
        """Get a cached DataFrame and update access patterns."""

        if name in self.cache_registry:
            entry = self.cache_registry[name]
            entry["access_count"] += 1
            entry["last_access"] = time.time()

            self.access_patterns[name].append(time.time())

            return entry["dataframe"]
        else:
            raise KeyError(f"DataFrame '{name}' not found in cache")

    def uncache_dataframe(self, name: str):
        """Uncache a DataFrame and remove it from the registry."""

        if name in self.cache_registry:
            entry = self.cache_registry[name]
            entry["dataframe"].unpersist()
            del self.cache_registry[name]

            if name in self.access_patterns:
                del self.access_patterns[name]

    def get_cache_statistics(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics."""

        total_size = sum(entry["estimated_size"] for entry in self.cache_registry.values())
        total_accesses = sum(entry["access_count"] for entry in self.cache_registry.values())

        return {
            "total_cached_entries": len(self.cache_registry),
            "total_cache_size_bytes": total_size,
            "total_cache_size_gb": total_size / (1024**3),
            "total_accesses": total_accesses,
            "cache_hit_rate": self._calculate_hit_rate(),
            "memory_utilization": total_size / self.max_cache_size_bytes,
            "entries": {
                name: {
                    "size_gb": entry["estimated_size"] / (1024**3),
                    "access_count": entry["access_count"],
                    "storage_level": str(entry["storage_level"]),
                    "cache_age_hours": (time.time() - entry["cache_time"]) / 3600
                }
                for name, entry in self.cache_registry.items()
            }
        }

    def _calculate_hit_rate(self) -> float:
        """Calculate the cache hit rate."""
        # Placeholder: only successful lookups are recorded, so this always
        # returns 1.0; a real implementation must also track misses.
        total_accesses = sum(entry["access_count"] for entry in self.cache_registry.values())
        return total_accesses / max(total_accesses, 1)

# Usage example
cache_manager = IntelligentCacheManager(spark, max_cache_size_gb=20.0)

# Cache frequently accessed data with an intelligent strategy
user_profile_data = spark.read.parquet("/path/to/user/profiles")
cached_profiles = cache_manager.cache_with_strategy(
    user_profile_data,
    "user_profiles",
    access_frequency="high",
    data_volatility="low"
)

# Cache medium-frequency data
transaction_data = spark.read.parquet("/path/to/transactions")
cached_transactions = cache_manager.cache_with_strategy(
    transaction_data,
    "transactions",
    access_frequency="medium",
    data_volatility="high"
)

# Get cache statistics
stats = cache_manager.get_cache_statistics()
print(f"Cache utilization: {stats['memory_utilization']:.2%}")
print(f"Hit rate: {stats['cache_hit_rate']:.2%}")
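One caveat: `_calculate_hit_rate` can only ever return 1.0, because the registry records successful lookups but never misses. A meaningful hit rate needs both sides counted. A minimal, framework-agnostic sketch (the `record` API is an assumption, not part of the manager above — you would call it from `get_cached_dataframe` on success and on `KeyError`):

```python
class HitRateTracker:
    """Track cache hits and misses so the hit rate is meaningful."""

    def __init__(self):
        self.hits = 0
        self.misses = 0

    def record(self, hit: bool):
        # hit=True when the lookup found a cached entry, False otherwise
        if hit:
            self.hits += 1
        else:
            self.misses += 1

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```

With this in place, `get_cache_statistics` could report `tracker.hit_rate` instead of the placeholder.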
Understanding and influencing query planning can significantly improve performance. This section builds upon the shuffle and join optimization concepts from our performance tuning guide by introducing sophisticated query plan analysis and optimization techniques.
from pyspark.sql import SparkSession, DataFrame
from typing import List, Dict, Any

class QueryPlanOptimizer:
    """Analyze and optimize Spark query plans for better performance."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def analyze_query_plan(self, df: DataFrame) -> Dict[str, Any]:
        """Analyze the query plan and identify optimization opportunities."""

        # Access the plans through the internal JVM handle; this relies on
        # Spark internals, so treat it as a debugging aid rather than a stable API
        query_execution = df._jdf.queryExecution()
        logical_plan = query_execution.analyzed()
        optimized_plan = query_execution.optimizedPlan()
        physical_plan = query_execution.executedPlan()

        # Analyze plan characteristics
        analysis = {
            "logical_plan": str(logical_plan),
            "optimized_plan": str(optimized_plan),
            "physical_plan": str(physical_plan),
            "optimization_opportunities": self._identify_optimizations(physical_plan),
            "estimated_cost": self._estimate_query_cost(physical_plan),
            "shuffle_operations": self._count_shuffle_operations(physical_plan),
            "broadcast_joins": self._count_broadcast_joins(physical_plan),
            "sort_merge_joins": self._count_sort_merge_joins(physical_plan)
        }

        return analysis

    def _identify_optimizations(self, physical_plan) -> List[str]:
        """Identify potential optimization opportunities."""

        plan_str = str(physical_plan)
        opportunities = []

        # Check for expensive operations
        if "SortMergeJoin" in plan_str:
            opportunities.append("Consider a broadcast join for smaller tables")

        if self._count_shuffle_operations(physical_plan) > 0:
            opportunities.append("Consider repartitioning to reduce shuffle")

        if "BroadcastHashJoin" in plan_str and "BroadcastExchange" in plan_str:
            opportunities.append("Broadcast join detected - ensure the table size is appropriate")

        if "FileScan" in plan_str and "PartitionFilters" not in plan_str:
            opportunities.append("Consider adding partition filters for better pruning")

        return opportunities

    def _estimate_query_cost(self, physical_plan) -> Dict[str, Any]:
        """Estimate the query execution cost."""

        plan_str = str(physical_plan)

        # Count expensive operations. Substring counts overlap, so subtract
        # "SortMergeJoin" occurrences from the raw "Sort" count.
        shuffle_count = self._count_shuffle_operations(physical_plan)
        sort_count = plan_str.count("Sort") - plan_str.count("SortMergeJoin")
        join_count = plan_str.count("Join")

        # Calculate a weighted cost score
        cost_score = shuffle_count * 10 + sort_count * 5 + join_count * 3

        return {
            "cost_score": cost_score,
            "shuffle_operations": shuffle_count,
            "sort_operations": sort_count,
            "join_operations": join_count,
            "complexity_level": self._get_complexity_level(cost_score)
        }

    def _get_complexity_level(self, cost_score: int) -> str:
        """Determine the query complexity level."""
        if cost_score < 10:
            return "Low"
        elif cost_score < 30:
            return "Medium"
        else:
            return "High"

    def _count_shuffle_operations(self, physical_plan) -> int:
        """Count shuffle exchanges in the plan."""
        # Spark 3 plan strings render shuffles as "Exchange ..."; broadcast
        # exchanges ("BroadcastExchange") are not shuffles, so subtract them.
        plan_str = str(physical_plan)
        return plan_str.count("Exchange") - plan_str.count("BroadcastExchange")

    def _count_broadcast_joins(self, physical_plan) -> int:
        """Count broadcast joins in the plan."""
        return str(physical_plan).count("BroadcastHashJoin")

    def _count_sort_merge_joins(self, physical_plan) -> int:
        """Count sort-merge joins in the plan."""
        return str(physical_plan).count("SortMergeJoin")

    def optimize_query_with_hints(self, df: DataFrame,
                                  join_strategy_hints: Dict[str, str] = None,
                                  repartition_hints: Dict[str, int] = None) -> DataFrame:
        """Apply optimization hints to improve query performance."""

        # Apply join strategy hints
        if join_strategy_hints:
            for table_name, strategy in join_strategy_hints.items():
                if strategy == "broadcast":
                    df = df.hint("broadcast", table_name)
                elif strategy == "shuffle_hash":
                    df = df.hint("shuffle_hash", table_name)
                elif strategy == "shuffle_replicate_nl":
                    df = df.hint("shuffle_replicate_nl", table_name)

        # Apply repartition hints. Note: DataFrame.hint applies to this
        # DataFrame; the table name key is informational here.
        if repartition_hints:
            for table_name, num_partitions in repartition_hints.items():
                df = df.hint("repartition", num_partitions)

        return df

    def compare_query_plans(self, original_df: DataFrame,
                            optimized_df: DataFrame) -> Dict[str, Any]:
        """Compare two query plans to measure optimization effectiveness."""

        original_analysis = self.analyze_query_plan(original_df)
        optimized_analysis = self.analyze_query_plan(optimized_df)

        comparison = {
            "original_cost_score": original_analysis["estimated_cost"]["cost_score"],
            "optimized_cost_score": optimized_analysis["estimated_cost"]["cost_score"],
            "cost_improvement": (
                original_analysis["estimated_cost"]["cost_score"] -
                optimized_analysis["estimated_cost"]["cost_score"]
            ),
            "shuffle_reduction": (
                original_analysis["shuffle_operations"] -
                optimized_analysis["shuffle_operations"]
            ),
            "join_strategy_changes": {
                "original_broadcast_joins": original_analysis["broadcast_joins"],
                "optimized_broadcast_joins": optimized_analysis["broadcast_joins"],
                "original_sort_merge_joins": original_analysis["sort_merge_joins"],
                "optimized_sort_merge_joins": optimized_analysis["sort_merge_joins"]
            }
        }

        return comparison

# Usage example
query_optimizer = QueryPlanOptimizer(spark)

# Analyze a complex query
complex_query = large_table.join(medium_table, "user_id") \
    .groupBy("category") \
    .agg({"amount": "sum"})

analysis = query_optimizer.analyze_query_plan(complex_query)
print(f"Query complexity: {analysis['estimated_cost']['complexity_level']}")
print(f"Optimization opportunities: {analysis['optimization_opportunities']}")

# Apply optimization hints
optimized_query = query_optimizer.optimize_query_with_hints(
    complex_query,
    join_strategy_hints={"medium_table": "broadcast"},
    repartition_hints={"large_table": 200}
)

# Compare the plans
comparison = query_optimizer.compare_query_plans(complex_query, optimized_query)
print(f"Cost improvement: {comparison['cost_improvement']}")
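A caveat on substring-based plan analysis: operator names overlap ("SortMergeJoin" also contains "Sort" and "Join"), and the exact spellings vary across Spark versions. Counting whole tokens with a regex avoids the overlap problem; a sketch, with the operator list as an assumption to adapt to your Spark version:

```python
import re
from collections import Counter

# Longer names first, so the alternation matches "SortMergeJoin" before "Sort"
_OPERATORS = ["BroadcastHashJoin", "SortMergeJoin", "ShuffledHashJoin",
              "ShuffleExchange", "Exchange", "Sort", "FileScan"]
_OPERATOR_RE = re.compile(r"\b(" + "|".join(_OPERATORS) + r")\b")

def count_plan_operators(plan_str: str) -> Counter:
    """Count physical operators by whole token instead of raw substring."""
    return Counter(_OPERATOR_RE.findall(plan_str))
```

On a plan fragment like `"SortMergeJoin ... Sort ... Exchange hashpartitioning(...)"`, this counts one `Sort`, whereas `plan_str.count("Sort")` would report two because of the `SortMergeJoin` node.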
Comprehensive monitoring is essential for maintaining optimal performance in production. This section extends the monitoring concepts mentioned in our performance tuning guide with advanced analytics and alerting capabilities.
import json
import time
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Any
from pyspark.sql import SparkSession

class ProductionPerformanceMonitor:
    """Comprehensive performance monitoring for production PySpark applications."""

    def __init__(self, spark: SparkSession):
        self.spark = spark
        self.operation_metrics: Dict[str, List[Dict[str, Any]]] = {}
        self.system_metrics: List[Dict[str, Any]] = []
        self.alert_thresholds = {
            "execution_time_seconds": 300,  # 5 minutes
            "memory_usage_percent": 80,
            "cpu_usage_percent": 90,
            "disk_io_percent": 85
        }

    def monitor_operation(self, operation_name: str, operation_func,
                          *args, **kwargs) -> Any:
        """Monitor operation performance with comprehensive metrics."""

        start_time = time.time()
        start_system_metrics = self._get_system_metrics()

        # Execute the operation
        result = operation_func(*args, **kwargs)

        end_time = time.time()
        end_system_metrics = self._get_system_metrics()

        # Calculate metrics (the memory figure is a delta in percentage points,
        # since psutil reports utilization as a percent)
        execution_time = end_time - start_time
        memory_delta = end_system_metrics["memory_usage"] - start_system_metrics["memory_usage"]
        cpu_delta = end_system_metrics["cpu_usage"] - start_system_metrics["cpu_usage"]

        # Store metrics
        operation_metric = {
            "operation_name": operation_name,
            "execution_time_seconds": execution_time,
            "memory_delta_percent": memory_delta,
            "cpu_delta_percent": cpu_delta,
            "timestamp": datetime.now().isoformat(),
            "system_metrics": {
                "start": start_system_metrics,
                "end": end_system_metrics
            }
        }

        if operation_name not in self.operation_metrics:
            self.operation_metrics[operation_name] = []

        self.operation_metrics[operation_name].append(operation_metric)

        # Check for alerts
        self._check_alerts(operation_metric)

        return result

    def _get_system_metrics(self) -> Dict[str, Any]:
        """Get current driver-side system metrics."""
        return {
            "memory_usage": psutil.virtual_memory().percent,
            "cpu_usage": psutil.cpu_percent(interval=1),  # blocks for one second
            "disk_usage": psutil.disk_usage('/').percent,
            "timestamp": datetime.now().isoformat()
        }

    def _check_alerts(self, operation_metric: Dict[str, Any]):
        """Check if operation metrics exceed alert thresholds."""

        alerts = []

        if operation_metric["execution_time_seconds"] > self.alert_thresholds["execution_time_seconds"]:
            alerts.append(f"Slow operation: {operation_metric['operation_name']} took {operation_metric['execution_time_seconds']:.2f} seconds")

        if operation_metric["system_metrics"]["end"]["memory_usage"] > self.alert_thresholds["memory_usage_percent"]:
            alerts.append(f"High memory usage: {operation_metric['system_metrics']['end']['memory_usage']:.1f}%")

        if operation_metric["system_metrics"]["end"]["cpu_usage"] > self.alert_thresholds["cpu_usage_percent"]:
            alerts.append(f"High CPU usage: {operation_metric['system_metrics']['end']['cpu_usage']:.1f}%")

        if alerts:
            self._send_alerts(alerts)

    def _send_alerts(self, alerts: List[str]):
        """Send performance alerts."""
        # In production, this would send to a monitoring system
        print(f"PERFORMANCE ALERTS: {alerts}")

    def get_performance_report(self, time_window_hours: int = 24) -> Dict[str, Any]:
        """Generate a comprehensive performance report."""

        cutoff_time = datetime.now() - timedelta(hours=time_window_hours)

        # Filter to recent metrics
        recent_metrics = {}
        for operation_name, metrics in self.operation_metrics.items():
            recent_metrics[operation_name] = [
                m for m in metrics
                if datetime.fromisoformat(m["timestamp"]) > cutoff_time
            ]

        # Calculate statistics
        report = {
            "time_window_hours": time_window_hours,
            "total_operations": sum(len(metrics) for metrics in recent_metrics.values()),
            "operations": {}
        }

        for operation_name, metrics in recent_metrics.items():
            if metrics:
                execution_times = [m["execution_time_seconds"] for m in metrics]
                memory_deltas = [m["memory_delta_percent"] for m in metrics]

                report["operations"][operation_name] = {
                    "count": len(metrics),
                    "avg_execution_time": sum(execution_times) / len(execution_times),
                    "max_execution_time": max(execution_times),
                    "min_execution_time": min(execution_times),
                    "avg_memory_delta_percent": sum(memory_deltas) / len(memory_deltas),
                    "max_memory_delta_percent": max(memory_deltas),
                    "performance_trend": self._calculate_trend(execution_times)
                }

        return report

    def _calculate_trend(self, values: List[float]) -> str:
        """Calculate the performance trend."""
        if len(values) < 2:
            return "insufficient_data"

        # Simple trend calculation: compare the averages of the two halves
        first_half = values[:len(values)//2]
        second_half = values[len(values)//2:]

        first_avg = sum(first_half) / len(first_half)
        second_avg = sum(second_half) / len(second_half)

        if second_avg < first_avg * 0.9:
            return "improving"
        elif second_avg > first_avg * 1.1:
            return "degrading"
        else:
            return "stable"

    def export_metrics(self, file_path: str):
        """Export metrics to a JSON file."""
        with open(file_path, 'w') as f:
            json.dump({
                "operation_metrics": self.operation_metrics,
                "system_metrics": self.system_metrics,
                "export_timestamp": datetime.now().isoformat()
            }, f, indent=2)

# Usage example
monitor = ProductionPerformanceMonitor(spark)

# Monitor expensive operations
def expensive_data_processing(df):
    # "*": "count" counts rows; "count": "count" would require a column named "count"
    return df.groupBy("category").agg({"amount": "sum", "*": "count"})

result = monitor.monitor_operation(
    "category_aggregation",
    expensive_data_processing,
    large_dataset
)

# Generate a performance report
report = monitor.get_performance_report(time_window_hours=6)
print(f"Total operations in last 6 hours: {report['total_operations']}")

# Export metrics for analysis
monitor.export_metrics("/path/to/performance_metrics.json")
Building upon the foundational concepts from our Performance Tuning on Apache Spark guide, here are advanced best practices for production environments:
Adaptive Query Execution: enable AQE with skew-join handling and partition coalescing so Spark can re-optimize plans from runtime statistics instead of relying solely on static estimates.
Partition Optimization: partition tables on the columns your queries filter by (typically time), and keep dynamic partition pruning and Parquet filter pushdown enabled.
Intelligent Caching: choose storage levels based on access frequency and data volatility, track cache usage, and evict low-value entries before memory pressure builds.
Query Plan Analysis: inspect logical and physical plans for unnecessary shuffles and suboptimal join strategies, and apply join or repartition hints where the planner needs help.
Production Monitoring: record per-operation execution times and system metrics, alert when thresholds are exceeded, and watch trends to catch gradual degradation.
Memory Management: estimate DataFrame sizes before caching, cap the total cache size, and unpersist DataFrames as soon as they are no longer needed.
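As a quick reference, the key session settings used throughout this article can also be applied cluster-wide in `spark-defaults.conf` rather than per-session (the values shown are the ones from the examples above; tune them to your workload):

```properties
spark.sql.adaptive.enabled                           true
spark.sql.adaptive.coalescePartitions.enabled        true
spark.sql.adaptive.skewJoin.enabled                  true
spark.sql.adaptive.localShuffleReader.enabled        true
spark.sql.adaptive.advisoryPartitionSizeInBytes      128m
spark.sql.cbo.enabled                                true
spark.sql.cbo.joinReorder.enabled                    true
spark.sql.optimizer.dynamicPartitionPruning.enabled  true
spark.sql.parquet.filterPushdown                     true
spark.sql.parquet.enableVectorizedReader             true
```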
By implementing these advanced optimization techniques, you can achieve significant performance improvements in production PySpark environments. Remember to profile your specific use cases and continuously monitor performance to ensure optimal results. For more foundational performance tuning concepts, refer to our comprehensive guide on Performance Tuning on Apache Spark.