🔧 Optimizing Repartitioning & Minimizing Shuffling in PySpark
Repartitioning is essential in distributed computing to optimize parallel execution, but excessive shuffling can degrade performance. Here’s how to handle it efficiently:
🔹 1️⃣ Understanding Repartitioning Methods
1. repartition(n) – Increases parallelism but triggers a full shuffle
df = df.repartition(10)  # Redistributes into 10 partitions
✅ Use Case: When load balancing is needed (e.g., skewed data).
❌ Downside: Full shuffle across worker nodes.
2. repartition(col) – Redistributes rows based on a column's values
df = df.repartition("category")  # Partition based on 'category'
✅ Use Case: Optimizes later joins and aggregations on that column.
❌ Downside: Still shuffles data across the cluster.
3. coalesce(n) – Reduces the number of partitions without a full shuffle
df = df.coalesce(4)  # Reduce to 4 partitions
✅ Use Case: Used after filtering to merge small partitions and speed up later stages or writes.
❌ Downside: Can only reduce the partition count, never increase it (see the comparison sketch below).
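The three methods differ mainly in how many partitions you end up with and whether a full shuffle happens. A quick way to compare them is to check the partition count before and after each call; the sketch below assumes an existing DataFrame df with a category column, and the counts in the comments are illustrative.
# Minimal sketch: compare partition counts (numbers in comments are illustrative)
print(df.rdd.getNumPartitions())                          # e.g. 8 input partitions
print(df.repartition(10).rdd.getNumPartitions())          # 10 (full shuffle)
print(df.repartition("category").rdd.getNumPartitions())  # typically spark.sql.shuffle.partitions (200 by default; may differ with AQE)
print(df.coalesce(4).rdd.getNumPartitions())              # 4 (partitions merged, no full shuffle)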
🔹 2️⃣ How to Minimize Shuffling?
1️⃣ Repartition Early in the ETL Pipeline
If you know your data will later be joined or aggregated on a specific column, repartition on that column before those transformations:
df = df.repartition("category")  # Avoid unnecessary shuffles in later joins/aggregations
✅ Prevents multiple shuffle operations later.
2️⃣ Use broadcast() Instead of Repartitioning for Small Tables
For small lookup tables, use a broadcast join instead of repartitioning:
from pyspark.sql.functions import broadcast
df_final = df_large.join(broadcast(df_small), "id", "inner")
✅ Reduces shuffle when joining a large table with a small table.
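To double-check that the broadcast actually happens, you can inspect the physical plan. This is a minimal sketch reusing the df_large/df_small names from the snippet above; the exact plan output depends on your data and settings.
# Sketch: verify the broadcast join in the physical plan
joined = df_large.join(broadcast(df_small), "id", "inner")
joined.explain()  # a BroadcastHashJoin node (instead of SortMergeJoin) confirms the broadcast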
3️⃣ Optimize Joins Using DISTRIBUTE BY Instead of repartition()
If joining two large tables, you can use DISTRIBUTE BY on each side of the join instead of calling repartition():
df.createOrReplaceTempView("big_table")
df_small.createOrReplaceTempView("small_table")
query = """
SELECT *
FROM (SELECT * FROM big_table DISTRIBUTE BY category) b
JOIN (SELECT * FROM small_table DISTRIBUTE BY category) s
  ON b.category = s.category
"""
df_final = spark.sql(query)
✅ Distributes both sides on the join key before joining, reducing extra shuffles.
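For reference, the same idea can be expressed with the DataFrame API by repartitioning both sides on the join key before joining. This is a sketch using the df/df_small DataFrames behind the temp views above; column names match the SQL version.
# Sketch: DataFrame-API equivalent of distributing both sides by the join key
df_big_dist = df.repartition("category")          # co-locate rows by the join key
df_small_dist = df_small.repartition("category")
df_joined = df_big_dist.join(df_small_dist, "category", "inner")
df_joined.explain()  # check where Exchange (shuffle) operators appear in the plan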
4️⃣ Use coalesce() Instead of repartition() for Output
If reducing the number of partitions before saving to disk, use coalesce() to avoid a full shuffle (note that coalesce(1) funnels all data through a single task, so reserve it for small outputs):
df_final.coalesce(1).write.csv("output.csv", header=True)
✅ Avoids a full shuffle while reducing partitions.
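For comparison, here is a sketch of both ways to reduce the number of output files (df_final and the output paths are illustrative): coalesce() merges existing partitions without a full shuffle, while repartition() shuffles the data but produces more evenly sized files.
# Sketch: two ways to control the number of output files
df_final.coalesce(4).write.mode("overwrite").csv("output_coalesced", header=True)         # no full shuffle
df_final.repartition(4).write.mode("overwrite").csv("output_repartitioned", header=True)  # full shuffle, balanced files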
🔹 3️⃣ Example: Combining All Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark
spark = SparkSession.builder \
.appName("Optimized_ETL") \
.config("spark.sql.autoBroadcastJoinThreshold", "524288000") \
.config("spark.executor.memory", "4g") \
.config("spark.memory.fraction", "0.8") \
.getOrCreate()
# Load Large Data
df_large = spark.read.parquet("large_data.parquet").repartition("category") # Repartition early
df_small = spark.read.parquet("small_data.parquet")
# Optimize Join
if df_small.count() < 1000000:
    df_final = df_large.join(broadcast(df_small), "id", "inner")
else:
    df_final = df_large.join(df_small, "id", "inner").repartition("category")
# Reduce Partitions Before Writing
df_final = df_final.coalesce(4)
df_final.write.parquet("optimized_output.parquet")
spark.stop()
✅ Repartitioned early
✅ Used a broadcast join for the small table
✅ Distributed the large table before the join
✅ Reduced partitions before writing
📌 Summary: Best Practices for Repartitioning
✅ Use repartition(col) early to reduce shuffling later
✅ Use broadcast() for small tables instead of repartitioning
✅ Use DISTRIBUTE BY instead of repartition() for large joins
✅ Use coalesce() to reduce partitions before writing
🔧 Ensuring Early Repartition & Sequential Execution in PySpark
By default, PySpark follows lazy evaluation, meaning transformations (like repartition()) are not executed immediately. They are only triggered when an action (e.g., .count(), .show(), .write()) is called.
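A quick way to see this lazy behavior is to call explain() on the transformed DataFrame before running any action. The sketch below assumes a DataFrame df with a category column.
# Sketch: transformations only build a plan; actions execute it
df_repart = df.repartition("category")  # returns immediately; only the query plan changes
df_repart.explain()                     # shows an Exchange (shuffle) step that has not run yet
df_repart.count()                       # an action: the shuffle actually executes now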
🔹 1️⃣ Ensuring Repartition Happens Early
Since repartition is a transformation, Spark does not execute it immediately. To force execution, we can use an action immediately after repartitioning:
df = df.repartition("category") # Repartition before heavy transformations
df.count() # Triggers the repartition immediately
✅ Ensures the shuffle runs before the next steps (cache or persist the result if you also want to avoid recomputing it later).
🔹 2️⃣ Forcing Sequential Execution
To make sure operations happen in order, use actions after key transformations:
df = df.repartition("category") # Step 1: Repartition
df.cache().count() # Step 2: Trigger execution & cache result
df_filtered = df.filter("amount > 100") # Step 3: Apply filter
df_filtered.count() # Step 4: Force execution
df_aggregated = df_filtered.groupBy("category").sum("amount") # Step 5: Aggregate
df_aggregated.show() # Step 6: Trigger execution
✅ Each step is executed in sequence
✅ Avoids unnecessary recomputation (because of cache())
🔹 3️⃣ Using persist() or cache() for Sequential Execution
Instead of triggering execution with .count(), use caching:
df = df.repartition("category").persist() # Persist after repartition
df.count() # Triggers execution & caches partitioned data
df_filtered = df.filter("amount > 100")
df_filtered.persist()
df_filtered.count() # Executes filter before moving ahead
✅ Ensures data is partitioned before filtering
✅ Reduces recomputation in later stages
🔹 4️⃣ Using checkpoint() for Strict Sequential Execution
If data is very large, use checkpointing instead of caching:
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
df = df.repartition("category").checkpoint() # Save intermediate state
df.count() # Executes the repartition
df_filtered = df.filter("amount > 100").checkpoint() # Checkpoint filtered data
df_filtered.count() # Ensures execution before proceeding
✅ Forces Spark to save intermediate results to disk
✅ Prevents re-execution of previous steps
🔹 5️⃣ Using foreachPartition() to Trigger Execution
Another way to ensure sequential execution is using foreachPartition(), which triggers an action for each partition:
df.repartition("category").foreachPartition(lambda x: list(x)) # Forces execution
✅ Ensures repartitioning is completed before moving ahead.
📌 Best Practices for Sequential Execution in PySpark
✅ Use .count() after repartitioning to trigger execution
✅ Use .persist() or .cache() to avoid recomputation
✅ Use .checkpoint() for large datasets
✅ Use .foreachPartition() to force execution per partition
# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Auto-broadcast threshold of ~500 MB
    .config("spark.executor.memory", "4g")                        # Increase executor memory
    .config("spark.driver.memory", "2g")                          # Increase driver memory
    .config("spark.memory.fraction", "0.8")                       # Allocate more memory for computation
    .getOrCreate()
)
# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
(2, "B", "inactive", 200),
(3, "A", "active", 150),
(4, "C", "active", 120),
(5, "B", "inactive", 300)]
columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)
# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist()
df.count() # Forces execution & caching
# Approach 1: Using Temp Views for Step-by-Step ETL
df.createOrReplaceTempView("source_data")
# Step 1: Filter Active Records
filtered_query = """
SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).checkpoint()
filtered_df.count() # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")
# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")
# Step 2: Aggregation
aggregated_query = """
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist()
aggregated_df.count() # Forces execution
aggregated_df.show()
# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).checkpoint()
cte_df.count() # Ensures execution
cte_df.show()
# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).checkpoint()
ranked_df.count() # Ensures execution
ranked_df.show()
# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)
# Decide whether to broadcast based on size
if df_small.count() < 1000000:  # Example: broadcast if fewer than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner")
else:
    df_final = df.join(df_small, "id", "inner").repartition("category")
df_final.persist()
df_final.count() # Forces execution before writing
df_final.write.mode("overwrite").parquet("optimized_output.parquet")
# Closing Spark Session
spark.stop()