🔍 Understanding cache() in PySpark: Functionality, Optimization & Best Use Cases


🔹 What is cache() in PySpark?

  • cache() stores the DataFrame in memory across worker nodes.
  • It avoids recomputation of previous transformations when the DataFrame is used multiple times.

🔧 How Does cache() Work Internally?

  1. When you call df.cache(), it does NOT store data immediately. It only marks the DataFrame as cacheable.
  2. The first action on the DataFrame (df.count(), df.show(), df.collect(), etc.) triggers the computation and stores the data in memory.
  3. Subsequent actions on the cached DataFrame use the in-memory data instead of recomputing from the source, as the example below illustrates.
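A minimal sketch of this lifecycle (assuming an existing SparkSession named spark; spark.range is used only to have something to cache):

df = spark.range(1_000_000)      # Any DataFrame works; spark.range is just a stand-in

df.cache()                       # Step 1: only marks the DataFrame as cacheable
print(df.storageLevel)           # A storage level is assigned, but nothing is materialized yet

df.count()                       # Step 2: the first action computes the data and fills the cache
df.select("id").show(5)          # Step 3: served from the in-memory copy, no recomputation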

🔹 How Does cache() Optimize Performance?

Avoids Recomputations:

  • PySpark transformations are lazy. Without caching, each action triggers full recomputation from the source.
  • cache() keeps the computed result of those transformations, so repeated queries run much faster (see the sketch below).
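A quick way to see the reuse (a sketch; the column name follows the examples later in this post): after the first action fills the cache, explain() on later queries shows an InMemoryTableScan instead of a fresh scan of the source.

df.cache()
df.count()                               # First action: computes and stores the result
df.groupBy("columnB").count().explain()  # Physical plan now reads from InMemoryRelation / InMemoryTableScan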

Reduces IO Load & Network Latency:

  • If data comes from external sources (S3, HDFS, DBs), cache() reduces repeated reads.
  • Cached data is distributed across the cluster.

Speeds Up Iterative Jobs (ML, Graph Processing, Multiple Queries on Same Data):

  • If a DataFrame is used multiple times in a script, cache() significantly reduces processing time.

Optimizes Joins & Aggregations:

  • When joining large datasets, caching intermediate DataFrames helps reduce expensive recomputations.

🔹 Does cache() Help in Repartitioning & Reshuffling?

🔸 No, cache() does not change partitions.

  • If you cache a DataFrame, it retains its existing partition structure.
  • If you need both caching & repartitioning, always repartition first and then cache:
df = df.repartition(10).cache()
df.count()  # Triggers the cache and avoids recomputation
  • If you cache before repartitioning, Spark will still shuffle the data when repartitioning, as the sketch below shows.
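A sketch of that reversed (discouraged) order, using rdd.getNumPartitions() to inspect the layout:

df_cached = df.cache()
df_cached.count()                               # Cached with the original partition layout
df_repartitioned = df_cached.repartition(10)    # Still triggers a full shuffle of the cached data
print(df_repartitioned.rdd.getNumPartitions())  # 10, but the shuffle was not avoided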

🔸 cache() can avoid repeated shuffles

  • If the same shuffled result (e.g. the output of a join or an aggregation) is needed repeatedly, caching it lets later queries reuse the pre-shuffled data instead of recomputing the shuffle, as the sketch below shows.
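A sketch of that pattern (df_orders, df_customers and their columns are assumptions for illustration):

df_joined = df_orders.join(df_customers, "customer_id").cache()
df_joined.count()                                  # One join + shuffle, result cached
df_joined.groupBy("country").count().show()        # Reuses the cached, pre-shuffled result
df_joined.groupBy("segment").sum("amount").show()  # No second join, no repeated shuffle of the sources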

🔹 When to Use cache()?

💡 Best Scenarios:

1️⃣ Repeated Access to Same Data

  • If a DataFrame is queried multiple times, caching prevents recomputation.
df.cache()
df.count()  # Triggers cache
df.select("columnA").show()  # Uses cached data
df.groupBy("columnB").count().show()  # Uses cached data

2️⃣ Machine Learning Pipelines (Iterative Computations)

  • ML algorithms iterate over the same dataset multiple times.
df_train = df.cache()
model = train_model(df_train)  # Uses cached data for training
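train_model above is just a placeholder; a more concrete (still hedged) sketch with pyspark.ml could look like this, where the column names feature1, feature2 and label are assumptions:

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression

assembler = VectorAssembler(inputCols=["feature1", "feature2"], outputCol="features")
df_train = assembler.transform(df).select("features", "label").cache()
df_train.count()  # Materialize the cache before the iterative fit

model = LogisticRegression(maxIter=20).fit(df_train)  # Each iteration reuses the cached data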

3️⃣ Optimizing Joins & Aggregations

  • If a DataFrame is used in multiple joins, cache it before the joins.
df_large.cache().count()  # Cache large table
df_joined = df_large.join(df_small, "id")

4️⃣ Reducing Recomputations in Large Queries

df_filtered = df.filter("amount > 100").cache()
df_filtered.count()  # Triggers cache
df_filtered.groupBy("category").sum("amount").show()  # Uses cached data

5️⃣ Avoiding Expensive File Reads (Parquet, CSV, Database Queries)

df = spark.read.parquet("s3://large-dataset.parquet").cache()
df.count()  # Triggers cache

🔹 When NOT to Use cache()?

🚫 Large DataFrames that Exceed Available Memory

  • If the dataset is too large for the available executor memory, cache() can cause heavy eviction, garbage-collection pressure, or even out-of-memory errors.

🚫 Data Used Only Once

  • If a DataFrame is used only in one query, caching is unnecessary.

🚫 Changing Data in Each Step

  • If transformations alter the DataFrame in each step, caching does not help.

🚫 When Using coalesce()

  • coalesce() reduces the number of partitions and is typically used right before writing data to storage.
  • Caching a coalesced DataFrame keeps that reduced parallelism in memory; if you need the result again, it is usually cheaper to write it to disk and read it back, as sketched below.
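A sketch of that write-and-reread pattern (the output path and the coalesce(1) single-file write are assumptions for illustration):

df.coalesce(1).write.mode("overwrite").parquet("/tmp/staged_output")
df_staged = spark.read.parquet("/tmp/staged_output")  # Cheap to reread later; no cache() needed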

🔹 Difference Between cache() and persist()

| Feature | cache() | persist(StorageLevel.MEMORY_AND_DISK) |
| --- | --- | --- |
| Storage Type | Stores in memory only | Can store in memory & disk |
| Eviction | Removed when memory is needed | Falls back to disk if memory is full |
| Best for | Small to medium-sized DataFrames | Large datasets that may not fit in memory |
| Usage | df.cache() | df.persist(StorageLevel.MEMORY_AND_DISK) |

Note: the memory-only behavior described for cache() applies to RDDs; for DataFrames, cache() already defaults to MEMORY_AND_DISK, so persist() mainly adds the ability to choose a different storage level.

🚀 Example: Using cache() vs persist()

from pyspark.sql import SparkSession
from pyspark import StorageLevel

spark = SparkSession.builder.appName("CachingExample").getOrCreate()

df = spark.read.parquet("large_dataset.parquet")

# Caching DataFrame (uses the default storage level)
df.cache()
df.count()  # Triggers cache

# Persisting with Memory + Disk storage
# Unpersist first: Spark will not change the storage level of an already-cached DataFrame
df.unpersist()
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()  # Triggers persistence

Use cache() when data fits in memory
Use persist() when data is large and may spill to disk


🔹 Summary: Best Practices for cache()

✅ Use when data is used multiple times to prevent recomputation
✅ Repartition first, then cache if needed
✅ Trigger execution (count/show) after caching
✅ Use persist() instead of cache() for large datasets
✅ Avoid caching temporary transformations that are only used once

# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
from pyspark import StorageLevel

# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Auto-broadcast threshold: 500 MB
    .config("spark.executor.memory", "4g")   # Increase executor memory
    .config("spark.driver.memory", "2g")     # Increase driver memory
    .config("spark.memory.fraction", "0.8")  # Allocate more memory for execution/storage
    .getOrCreate()
)

# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")

# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
        (2, "B", "inactive", 200),
        (3, "A", "active", 150),
        (4, "C", "active", 120),
        (5, "B", "inactive", 300)]

columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)

# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist(StorageLevel.MEMORY_AND_DISK)
df.count()  # Forces execution & caching

# Approach 1: Using Temp Views for Step-by-Step ETL

df.createOrReplaceTempView("source_data")

# Step 1: Filter Active Records
filtered_query = """
    SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).persist(StorageLevel.MEMORY_AND_DISK)
filtered_df.count()  # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")

# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")

# Step 2: Aggregation
aggregated_query = """
    SELECT category, SUM(amount) AS total_amount
    FROM filtered_data
    GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist(StorageLevel.MEMORY_AND_DISK)
aggregated_df.count()  # Forces execution
aggregated_df.show()

# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
    SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
    SELECT category, SUM(amount) AS total_amount
    FROM filtered_data
    GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).persist(StorageLevel.MEMORY_AND_DISK)
cte_df.count()  # Ensures execution
cte_df.show()

# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
    SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
    SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
    FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).persist(StorageLevel.MEMORY_AND_DISK)
ranked_df.count()  # Ensures execution
ranked_df.show()

# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)

# Decide whether to broadcast based on size
if df_small.count() < 1000000:  # Example: Broadcast if less than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner").persist(StorageLevel.MEMORY_AND_DISK)
else:
    df_final = df.join(df_small, "id", "inner").repartition("category").persist(StorageLevel.MEMORY_AND_DISK)

df_final.count()  # Forces execution before writing

df_final.write.mode("overwrite").parquet("optimized_output.parquet")

# Closing Spark Session
spark.stop()

