🔧 Optimizing Repartitioning & Minimizing Shuffling in PySpark
Repartitioning is essential in distributed computing to optimize parallel execution, but excessive shuffling can degrade performance. Here’s how to handle it efficiently:
🔹 1️⃣ Understanding Repartitioning Methods
1. repartition(n)
– Increases parallelism but causes full shuffle
df = df.repartition(10) # Redistributes into 10 partitions
✔ Use Case: When load balancing is needed (e.g., skewed data).
❌ Downside: Full shuffle across worker nodes.
2. repartition(col)
– Redistributes based on column values
df = df.repartition("category") # Partition based on 'category'
✔ Use Case: Optimizes joins and aggregations that key on 'category'.
❌ Downside: Shuffles data across the cluster.
3. coalesce(n)
– Reduces partitions without full shuffle
df = df.coalesce(4) # Reduce to 4 partitions
✔ Use Case: Used after filtering to reduce shuffle & optimize performance.
❌ Downside: Cannot increase partitions, only reduces.
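To make the difference concrete, you can check the partition count before and after each call. A minimal sketch (the starting DataFrame df and the partition counts in the comments are illustrative):
print(df.rdd.getNumPartitions())             # e.g., 8 partitions after the initial read
df_repart = df.repartition(10)               # full shuffle; can increase the partition count
print(df_repart.rdd.getNumPartitions())      # 10
df_coalesced = df.coalesce(4)                # merges existing partitions; no full shuffle
print(df_coalesced.rdd.getNumPartitions())   # 4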
🔹 2️⃣ How to Minimize Shuffling?
1️⃣ Repartition Early in ETL
If you know that downstream joins or aggregations will key on a specific column, repartition before those transformations:
df = df.repartition("category") # Avoid unnecessary shuffle in later joins/aggregations
✔ Prevents multiple shuffle operations later.
2️⃣ Use broadcast() Instead of Repartitioning for Small Tables
For small lookup tables, use broadcast join instead of repartitioning:
from pyspark.sql.functions import broadcast
df_final = df_large.join(broadcast(df_small), "id", "inner")
✔ Reduces shuffle when joining a large table with a small table.
3️⃣ Optimize Joins Using DISTRIBUTE BY Instead of repartition()
If joining two large tables, use DISTRIBUTE BY instead of repartition():
df.createOrReplaceTempView("big_table")
df_other.createOrReplaceTempView("other_table")
query = """
SELECT *
FROM (SELECT * FROM big_table DISTRIBUTE BY category) b
JOIN (SELECT * FROM other_table DISTRIBUTE BY category) s
ON b.category = s.category
"""
df_final = spark.sql(query)
✔ Distributes data efficiently before joining, reducing shuffle.
4️⃣ Use coalesce() Instead of repartition() for Output
If reducing partitions before saving to disk, use coalesce() to avoid full shuffle:
df_final.coalesce(1).write.csv("output.csv", header=True)
✔ Avoids full shuffle while reducing partitions.
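Keep in mind that coalesce(1) pushes the entire output through a single task, which can be slow for large data. A hedged variation of the same idea, with an illustrative path and partition count:
# Coalesce to a small number of partitions (not 1) so the write stays parallel
# while still producing only a few output files. Path and count are illustrative.
df_final.coalesce(8).write.mode("overwrite").parquet("output_parquet")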
🔹 3️⃣ Example: Combining All Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark
spark = SparkSession.builder \
.appName("Optimized_ETL") \
.config("spark.sql.autoBroadcastJoinThreshold", "524288000") \
.config("spark.executor.memory", "4g") \
.config("spark.memory.fraction", "0.8") \
.getOrCreate()
# Load Large Data
df_large = spark.read.parquet("large_data.parquet").repartition("category") # Repartition early
df_small = spark.read.parquet("small_data.parquet")
# Optimize Join
if df_small.count() < 1000000:
    df_final = df_large.join(broadcast(df_small), "id", "inner")
else:
    df_final = df_large.join(df_small, "id", "inner").repartition("category")
# Reduce Partitions Before Writing
df_final = df_final.coalesce(4)
df_final.write.parquet("optimized_output.parquet")
spark.stop()
✔ Repartitioned Early
✔ Used Broadcast Join for small tables
✔ Distributed Large Table Before Join
✔ Reduced Partitions Before Writing
🚀 Summary: Best Practices for Repartitioning
✅ Use repartition(col) early to reduce shuffling later
✅ Use broadcast() for small tables instead of repartitioning
✅ Use DISTRIBUTE BY instead of repartition() for large joins
✅ Use coalesce() to reduce partitions before writing
🔧 Ensuring Early Repartition & Sequential Execution in PySpark
By default, PySpark follows lazy evaluation, meaning transformations (like repartition()) are not executed immediately. They are only triggered when an action (e.g., .count(), .show(), .write()) is called.
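You can verify this with explain(): repartition() only adds an exchange step to the physical plan, and the shuffle actually runs once an action is called. A small sketch, reusing the df and 'category' column from the earlier examples:
df_repart = df.repartition("category")   # transformation only: nothing executes yet
df_repart.explain()                      # plan shows an Exchange (hash partitioning on 'category')
df_repart.count()                        # the action that actually triggers the shuffle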
🔹 1️⃣ Ensuring Repartition Happens Early
Since repartition is a transformation, Spark does not execute it immediately. To force execution, we can use an action immediately after repartitioning:
df = df.repartition("category") # Repartition before heavy transformations
df.count() # Triggers the repartition immediately
✔ Ensures data is shuffled before moving to the next steps.
🔹 2️⃣ Forcing Sequential Execution
To make sure operations happen in order, use actions after key transformations:
df = df.repartition("category") # Step 1: Repartition
df.cache().count() # Step 2: Trigger execution & cache result
df_filtered = df.filter("amount > 100") # Step 3: Apply filter
df_filtered.count() # Step 4: Force execution
df_aggregated = df_filtered.groupBy("category").sum("amount") # Step 5: Aggregate
df_aggregated.show() # Step 6: Trigger execution
✔ Each step is executed in sequence
✔ Avoids unnecessary recomputation (because of cache())
🔹 3️⃣ Using persist() or cache() for Sequential Execution
Instead of relying on a throwaway .count() alone, persist the DataFrame so the triggering action also caches the repartitioned data for reuse:
df = df.repartition("category").persist() # Persist after repartition
df.count() # Triggers execution & caches partitioned data
df_filtered = df.filter("amount > 100")
df_filtered.persist()
df_filtered.count() # Executes filter before moving ahead
✔ Ensures data is partitioned before filtering
✔ Reduces recomputation in later stages
🔹 4️⃣ Using checkpoint() for Strict Sequential Execution
If data is very large, use checkpointing instead of caching:
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
df = df.repartition("category").checkpoint() # Save intermediate state
df.count() # Executes the repartition
df_filtered = df.filter("amount > 100").checkpoint() # Checkpoint filtered data
df_filtered.count() # Ensures execution before proceeding
✔ Forces Spark to save intermediate results to disk
✔ Prevents re-execution of previous steps
🔹 5️⃣ Using foreachPartition() to Trigger Execution
Another way to force execution is foreachPartition(), an action that runs a function on every partition:
df.repartition("category").foreachPartition(lambda x: list(x)) # Forces execution
✔ Ensures repartitioning is completed before moving ahead.
🚀 Best Practices for Sequential Execution in PySpark
✅ Use .count() after repartitioning to trigger execution
✅ Use .persist() or .cache() to avoid recomputation
✅ Use .checkpoint() for large datasets
✅ Use .foreachPartition() to force execution per partition
# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Auto-broadcast threshold: 500 MB
    .config("spark.executor.memory", "4g")  # Increase executor memory
    .config("spark.driver.memory", "2g")  # Increase driver memory
    .config("spark.memory.fraction", "0.8")  # Allocate more memory for computation
    .getOrCreate()
)
# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
(2, "B", "inactive", 200),
(3, "A", "active", 150),
(4, "C", "active", 120),
(5, "B", "inactive", 300)]
columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)
# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist()
df.count() # Forces execution & caching
# Approach 1: Using Temp Views for Step-by-Step ETL
df.createOrReplaceTempView("source_data")
# Step 1: Filter Active Records
filtered_query = """
SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).checkpoint()
filtered_df.count() # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")
# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")
# Step 2: Aggregation
aggregated_query = """
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist()
aggregated_df.count() # Forces execution
aggregated_df.show()
# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).checkpoint()
cte_df.count() # Ensures execution
cte_df.show()
# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).checkpoint()
ranked_df.count() # Ensures execution
ranked_df.show()
# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)
# Decide whether to broadcast based on size
if df_small.count() < 1000000:  # Example: broadcast if less than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner")
else:
    df_final = df.join(df_small, "id", "inner").repartition("category")
df_final.persist()
df_final.count() # Forces execution before writing
df_final.write.mode("overwrite").parquet("optimized_output.parquet")
# Closing Spark Session
spark.stop()