🔧 Optimizing Repartitioning & Minimizing Shuffling in PySpark

Repartitioning is essential in distributed computing to optimize parallel execution, but excessive shuffling can degrade performance. Here’s how to handle it efficiently:


🔹 1️⃣ Understanding Repartitioning Methods

1. repartition(n) – Increases parallelism but causes full shuffle

df = df.repartition(10)  # Redistributes into 10 partitions

Use Case: When load balancing is needed (e.g., skewed data).
Downside: Full shuffle across worker nodes.
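A quick way to see the effect is to check the partition count before and after (a minimal sketch, assuming an existing DataFrame df):

print(df.rdd.getNumPartitions())  # e.g. the partition count from the initial read
df = df.repartition(10)           # full shuffle into 10 partitions
print(df.rdd.getNumPartitions())  # 10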

2. repartition(col) – Redistributes based on column values

df = df.repartition("category")  # Partition based on 'category'

Use Case: Optimizes later joins and aggregations keyed on 'category'.
Downside: Shuffles data across the cluster.
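To verify that rows with the same category land in the same partition, tag each row with its partition id (a small sketch; note that repartition(col) produces spark.sql.shuffle.partitions partitions, 200 by default, unless you also pass a count such as repartition(50, "category")):

from pyspark.sql.functions import spark_partition_id, countDistinct

df = df.repartition("category")
df.withColumn("pid", spark_partition_id()) \
  .groupBy("category") \
  .agg(countDistinct("pid").alias("partitions_per_category")) \
  .show()  # expect 1 partition per category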

3. coalesce(n) – Reduces partitions without full shuffle

df = df.coalesce(4)  # Reduce to 4 partitions

Use Case: Reducing the number of partitions after heavy filtering (e.g., before writing output).
Downside: Can only reduce partitions, never increase them, and the merged partitions may be uneven in size.
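Because coalesce() only merges existing partitions, asking for more partitions than currently exist is a no-op; a minimal illustration:

df4 = df.repartition(10).coalesce(4)
print(df4.rdd.getNumPartitions())      # 4
df_more = df4.coalesce(16)             # cannot go above the current count
print(df_more.rdd.getNumPartitions())  # still 4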


🔹 2️⃣ How to Minimize Shuffling?

1️⃣ Repartition Early in ETL

If you know your data will later be joined, grouped, or filtered on a specific column, repartition on that column before the heavy transformations:

df = df.repartition("category")  # Avoid unnecessary shuffle in later joins/aggregations

✔ Prevents multiple shuffle operations later.

2️⃣ Use broadcast() Instead of Repartitioning for Small Tables

For small lookup tables, use a broadcast join instead of repartitioning:

from pyspark.sql.functions import broadcast
df_final = df_large.join(broadcast(df_small), "id", "inner")

✔ Reduces shuffle when joining a large table with a small table.
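To confirm the broadcast actually took effect, inspect the physical plan; it should contain a BroadcastHashJoin rather than a SortMergeJoin (a quick check, assuming df_large and df_small from above):

df_final = df_large.join(broadcast(df_small), "id", "inner")
df_final.explain()  # look for BroadcastHashJoin in the physical plan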

3️⃣ Optimize Joins Using DISTRIBUTE BY Instead of repartition()

When joining two large tables, you can pre-distribute them in SQL with DISTRIBUTE BY instead of calling repartition():

df.createOrReplaceTempView("big_table")
df_small.createOrReplaceTempView("small_table")

query = """
SELECT * FROM big_table
DISTRIBUTE BY category
JOIN small_table
ON big_table.category = small_table.category
"""
df_final = spark.sql(query)

✔ Distributes each table by the join key before joining, reducing shuffle during the join itself.
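The same idea can be expressed with the DataFrame API by repartitioning both sides on the join key before joining; Spark can typically reuse that partitioning instead of shuffling again (a sketch using hypothetical df_big and df_other frames):

df_big = df_big.repartition("category")
df_other = df_other.repartition("category")
df_joined = df_big.join(df_other, "category", "inner")  # both sides already hash-partitioned on the key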

4️⃣ Use coalesce() Instead of repartition() for Output

If you only need to reduce partitions before saving to disk, use coalesce() to avoid a full shuffle:

df_final.coalesce(1).write.csv("output.csv", header=True)  # writes a single part file inside the output.csv directory

✔ Avoids a full shuffle while reducing partitions. Keep in mind that coalesce(1) funnels all data through a single task, so use a larger number for big outputs.
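If the surviving partitions are badly skewed after filtering, a small repartition(n) before the write accepts one shuffle in exchange for evenly sized output files (a hedged alternative, with n chosen for your data size):

df_final.repartition(8).write.mode("overwrite").parquet("balanced_output.parquet")  # full shuffle, ~8 evenly sized files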


🔹 3️⃣ Example: Combining All Optimizations

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

# Initialize Spark
spark = SparkSession.builder \
    .appName("Optimized_ETL") \
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000") \
    .config("spark.executor.memory", "4g") \
    .config("spark.memory.fraction", "0.8") \
    .getOrCreate()

# Load Large Data
df_large = spark.read.parquet("large_data.parquet").repartition("category")  # Repartition early
df_small = spark.read.parquet("small_data.parquet")

# Optimize Join
if df_small.count() < 1000000:
    df_final = df_large.join(broadcast(df_small), "id", "inner")
else:
    df_final = df_large.join(df_small, "id", "inner").repartition("category")

# Reduce Partitions Before Writing
df_final = df_final.coalesce(4)
df_final.write.parquet("optimized_output.parquet")

spark.stop()

✔ Repartitioned early
✔ Used a broadcast join for the small table
✔ Distributed the large table by category before the join
✔ Reduced partitions before writing


🚀 Summary: Best Practices for Repartitioning

✔ Use repartition(col) early to reduce shuffling later
✔ Use broadcast() for small tables instead of repartitioning
✔ Use DISTRIBUTE BY instead of repartition() for large joins
✔ Use coalesce() to reduce partitions before writing

🔧 Ensuring Early Repartition & Sequential Execution in PySpark

By default, PySpark follows lazy evaluation, meaning transformations (like repartition()) are not executed immediately. They are only triggered when an action (e.g., .count(), .show(), .write()) is called.
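You can observe this laziness directly: chaining transformations only builds the query plan, and nothing runs until an action (a minimal illustration, assuming an existing DataFrame df):

df2 = df.repartition("category").filter("amount > 100")  # nothing executes yet
df2.explain()  # prints the query plan without running a job
df2.count()    # the action finally triggers the shuffle and the filter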

🔹 1️⃣ Ensuring Repartition Happens Early

Since repartition() is a transformation, Spark does not execute it immediately. To force the shuffle to run right away, call an action immediately after repartitioning:

df = df.repartition("category")  # Repartition before heavy transformations
df.count()  # Action that triggers the shuffle now

✔ Runs the shuffle up front. Note that without cache() or persist() (covered below), later actions will recompute the repartition, so combine this with caching whenever the repartitioned data is reused.


🔹 2️⃣ Forcing Sequential Execution

To materialize each stage in order before moving on, call an action after each key transformation:

df = df.repartition("category")  # Step 1: Repartition
df.cache().count()  # Step 2: Trigger execution & cache result

df_filtered = df.filter("amount > 100")  # Step 3: Apply filter
df_filtered.count()  # Step 4: Force execution

df_aggregated = df_filtered.groupBy("category").sum("amount")  # Step 5: Aggregate
df_aggregated.show()  # Step 6: Trigger execution

✔ Each step executes in sequence
✔ Avoids unnecessary recomputation (thanks to cache())


🔹 3️⃣ Using persist() or cache() for Sequential Execution

Rather than relying on .count() alone, persist the result so the repartitioned data is kept and not recomputed:

df = df.repartition("category").persist()  # Persist after repartition
df.count()  # Triggers execution & caches partitioned data

df_filtered = df.filter("amount > 100")
df_filtered.persist()
df_filtered.count()  # Executes filter before moving ahead

✔ Ensures data is partitioned before filtering
✔ Reduces recomputation in later stages
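Once a cached intermediate result is no longer needed, release it so executors can reclaim memory (a small housekeeping sketch for the frames above):

df_filtered.unpersist()  # drop the filtered data once downstream results are materialized
df.unpersist()           # drop the repartitioned data as well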


🔹 4️⃣ Using checkpoint() for Strict Sequential Execution

If data is very large, use checkpointing instead of caching:

spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")

df = df.repartition("category").checkpoint()  # Save intermediate state (checkpoint() is eager by default)
df.count()  # Data is already materialized; this simply confirms it

df_filtered = df.filter("amount > 100").checkpoint()  # Checkpoint filtered data
df_filtered.count()  # Ensures execution before proceeding

✔ Forces Spark to save intermediate results to disk
✔ Prevents re-execution of previous steps
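If writing to the reliable checkpoint directory is too slow, PySpark also provides localCheckpoint(), which truncates the lineage but stores the data on executor-local storage, so it is faster though not fault-tolerant (a hedged sketch):

df = df.repartition("category").localCheckpoint()  # eager by default; lineage is cut here
df_filtered = df.filter("amount > 100")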


🔹 5️⃣ Using foreachPartition() to Trigger Execution

Another way to force execution is foreachPartition(), an action that runs a function over every partition:

df.repartition("category").foreachPartition(lambda part: sum(1 for _ in part))  # Consumes each partition without collecting it into memory, forcing execution

✔ Ensures repartitioning is completed before moving ahead.


🚀 Best Practices for Sequential Execution in PySpark

✔ Use .count() after repartitioning to trigger execution
✔ Use .persist() or .cache() to avoid recomputation
✔ Use .checkpoint() for large datasets
✔ Use .foreachPartition() to force execution per partition

# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

# Initialize Spark Session
spark = (SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Auto-broadcast threshold: 500 MB
    .config("spark.executor.memory", "4g")  # Executor memory
    .config("spark.driver.memory", "2g")  # Driver memory
    .config("spark.memory.fraction", "0.8")  # Fraction of heap for execution/storage
    .getOrCreate())

# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")

# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
        (2, "B", "inactive", 200),
        (3, "A", "active", 150),
        (4, "C", "active", 120),
        (5, "B", "inactive", 300)]

columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)

# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist()
df.count()  # Forces execution & caching

# Approach 1: Using Temp Views for Step-by-Step ETL

df.createOrReplaceTempView("source_data")

# Step 1: Filter Active Records
filtered_query = """
    SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).checkpoint()
filtered_df.count()  # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")

# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")

# Step 2: Aggregation
aggregated_query = """
    SELECT category, SUM(amount) AS total_amount
    FROM filtered_data
    GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist()
aggregated_df.count()  # Forces execution
aggregated_df.show()

# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
    SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
    SELECT category, SUM(amount) AS total_amount
    FROM filtered_data
    GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).checkpoint()
cte_df.count()  # Ensures execution
cte_df.show()

# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
    SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
    SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
    FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).checkpoint()
ranked_df.count()  # Ensures execution
ranked_df.show()

# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)

# Decide whether to broadcast based on size
if df_small.count() < 1000000:  # Example: Broadcast if less than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner")
else:
    df_final = df.join(df_small, "id", "inner").repartition("category")

df_final.persist()
df_final.count()  # Forces execution before writing

df_final.write.mode("overwrite").parquet("optimized_output.parquet")

# Closing Spark Session
spark.stop()
