🔍 Understanding cache() in PySpark: Functionality, Optimization & Best Use Cases
🔹 What is cache() in PySpark?
- cache() stores the DataFrame in memory across the worker nodes.
- It avoids recomputation of previous transformations when the DataFrame is used multiple times.
🔧 How Does cache() Work Internally?
- When you call df.cache(), Spark does NOT store the data immediately. It only marks the DataFrame as cacheable.
- The first action on the DataFrame (df.count(), df.show(), df.collect(), etc.) triggers the computation and stores the data in memory.
- Subsequent actions on the cached DataFrame use the in-memory data instead of recomputing from the source (see the sketch below).
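A minimal sketch of that lifecycle, assuming an active SparkSession named spark (the DataFrame and column name are illustrative):
# Mark, materialize, then reuse
df = spark.range(1_000_000).withColumnRenamed("id", "order_id")
df.cache()                       # Only marks the DataFrame as cacheable; nothing is stored yet
print(df.is_cached)              # True -- the plan is marked, even before any data is materialized
df.count()                       # First action: computes the DataFrame and stores it in memory
df.select("order_id").show(5)    # Reuses the cached data instead of recomputing from the source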
🔹 How Does cache() Optimize Performance?
✅ Avoids Recomputations:
- PySpark transformations are lazy. Without caching, each action triggers a full recomputation from the source.
- cache() saves the DataFrame after its transformations, so repeated queries run much faster (see the plan-inspection sketch after this list).
✅ Reduces I/O Load & Network Latency:
- If data comes from external sources (S3, HDFS, databases), cache() reduces repeated reads.
- Cached data is distributed across the cluster.
✅ Speeds Up Iterative Jobs (ML, Graph Processing, Multiple Queries on the Same Data):
- If a DataFrame is used multiple times in a script, cache() significantly reduces processing time.
✅ Optimizes Joins & Aggregations:
- When joining large datasets, caching intermediate DataFrames helps avoid expensive recomputations.
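One way to see the reuse is to inspect the physical plan after caching: a cached DataFrame is typically read through an InMemoryTableScan over an InMemoryRelation. A small hedged sketch, again assuming an active SparkSession named spark:
# Inspect the plan to confirm Spark reads from the cache
df = spark.range(1_000_000)
df.cache()
df.count()                            # Materializes the cache
df.groupBy("id").count().explain()    # Physical plan typically shows InMemoryTableScan / InMemoryRelation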
🔹 Does cache() Help in Repartitioning & Reshuffling?
🔸 No, cache() does not change partitions.
- If you cache a DataFrame, it retains its existing partition structure (see the sketch after this list).
- If you need both caching and repartitioning, always repartition first and then cache:
df = df.repartition(10).cache()
df.count()  # Triggers the cache and avoids recomputation
- If you cache before repartitioning, Spark will still shuffle the data when repartitioning, so the earlier cache is wasted.
🔸 cache() helps avoid unnecessary shuffling
- If a DataFrame is repeatedly shuffled in joins or aggregations, caching lets Spark reuse the already-shuffled data instead of recalculating it.
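A hedged sketch of the ordering rule (the partition counts are illustrative and depend on the cluster's default parallelism):
# cache() keeps whatever partition layout the DataFrame already has
df = spark.range(1_000_000)
print(df.rdd.getNumPartitions())   # e.g. the default parallelism of the cluster
df = df.repartition(10).cache()    # repartition FIRST, then cache
df.count()                         # materializes the cache with 10 partitions
print(df.rdd.getNumPartitions())   # 10 -- later queries reuse this cached, repartitioned copy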
🔹 When to Use cache()?
💡 Best Scenarios:
1️⃣ Repeated Access to the Same Data
- If a DataFrame is queried multiple times, caching prevents recomputation.
df.cache()
df.count() # Triggers cache
df.select("columnA").show() # Uses cached data
df.groupBy("columnB").count().show() # Uses cached data
2️⃣ Machine Learning Pipelines (Iterative Computations)
- ML algorithms iterate over the same dataset multiple times.
df_train = df.cache()
model = train_model(df_train) # Uses cached data for training
3️⃣ Optimizing Joins & Aggregations
- If a DataFrame is used in multiple joins, cache it before the joins.
df_large.cache().count() # Cache large table
df_joined = df_large.join(df_small, "id")
4️⃣ Reducing Recomputations in Large Queries
df_filtered = df.filter("amount > 100").cache()
df_filtered.count() # Triggers cache
df_filtered.groupBy("category").sum("amount").show() # Uses cached data
5️⃣ Avoiding Expensive File Reads (Parquet, CSV, Database Queries)
df = spark.read.parquet("s3://large-dataset.parquet").cache()
df.count() # Triggers cache
🔹 When NOT to Use cache()?
🚫 Large DataFrames that Exceed Available Memory
- If the dataset is too large for RAM, cache() may cause out-of-memory errors or constant evictions.
🚫 Data Used Only Once
- If a DataFrame is used in only one query, caching is unnecessary.
🚫 Changing Data in Each Step
- If transformations alter the DataFrame at every step, caching does not help.
🚫 When Using coalesce()
- coalesce() reduces partitions and is often used right before writing data to storage.
- Instead of caching, just write the result to disk and read it back (see the sketch below).
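A hedged sketch of that write-and-read-back pattern (the output path is illustrative):
# Materialize to storage instead of caching, then read the result back
df_out = df.coalesce(1)                                    # shrink partitions before the write
df_out.write.mode("overwrite").parquet("/tmp/stage_out")   # illustrative path
df_stage = spark.read.parquet("/tmp/stage_out")            # later steps read the materialized copy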
🔹 Difference Between cache() and persist()
Feature | cache() | persist(StorageLevel.MEMORY_AND_DISK) |
---|---|---|
Storage level | Uses the default storage level (MEMORY_ONLY for RDDs; MEMORY_AND_DISK for DataFrames in recent Spark versions) | Lets you choose the StorageLevel explicitly, here memory with disk fallback |
Eviction | Memory-only blocks are dropped and recomputed when memory is needed | Blocks that do not fit in memory spill to disk instead of being recomputed |
Best for | Small to medium-sized DataFrames | Large datasets that may not fit entirely in memory |
Usage | df.cache() | df.persist(StorageLevel.MEMORY_AND_DISK) |
🚀 Example: Using cache() vs persist()
from pyspark.sql import SparkSession
from pyspark import StorageLevel

spark = SparkSession.builder.appName("CachingExample").getOrCreate()

# Caching a DataFrame (default storage level)
df_cached = spark.read.parquet("large_dataset.parquet")
df_cached.cache()
df_cached.count()  # Triggers the cache

# Persisting a separate DataFrame with memory + disk storage
# (re-persisting the already-cached DataFrame would only log a warning and keep the old level)
df_persisted = spark.read.parquet("large_dataset.parquet")
df_persisted.persist(StorageLevel.MEMORY_AND_DISK)
df_persisted.count()  # Triggers persistence
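To double-check which level is actually in effect, DataFrame.storageLevel can be printed (a small verification step reusing the two DataFrames above):
# Verify the effective storage levels
print(df_cached.storageLevel)      # default level applied by cache()
print(df_persisted.storageLevel)   # should report both memory and disk enabled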
✔ Use cache() when the data fits in memory
✔ Use persist(StorageLevel.MEMORY_AND_DISK) when the data is large and may need to spill to disk
🔹 Summary: Best Practices for cache()
✅ Use it when data is accessed multiple times, to prevent recomputation
✅ Repartition first, then cache, if you need both
✅ Trigger execution (count/show) after caching so the data is actually materialized
✅ Use persist() instead of cache() for large datasets
✅ Avoid caching temporary transformations that are used only once
# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
from pyspark import StorageLevel
# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Auto-broadcast threshold ~500MB
    .config("spark.executor.memory", "4g")    # Increase executor memory
    .config("spark.driver.memory", "2g")      # Increase driver memory
    .config("spark.memory.fraction", "0.8")   # Allocate more memory for computation
    .getOrCreate()
)
# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
(2, "B", "inactive", 200),
(3, "A", "active", 150),
(4, "C", "active", 120),
(5, "B", "inactive", 300)]
columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)
# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist(StorageLevel.MEMORY_AND_DISK)
df.count() # Forces execution & caching
# Approach 1: Using Temp Views for Step-by-Step ETL
df.createOrReplaceTempView("source_data")
# Step 1: Filter Active Records
filtered_query = """
SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).persist(StorageLevel.MEMORY_AND_DISK)
filtered_df.count() # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")
# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")
# Step 2: Aggregation
aggregated_query = """
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist(StorageLevel.MEMORY_AND_DISK)
aggregated_df.count() # Forces execution
aggregated_df.show()
# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).persist(StorageLevel.MEMORY_AND_DISK)
cte_df.count() # Ensures execution
cte_df.show()
# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).persist(StorageLevel.MEMORY_AND_DISK)
ranked_df.count() # Ensures execution
ranked_df.show()
# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)
# Decide whether to broadcast based on size
if df_small.count() < 1000000: # Example: Broadcast if less than 1 million rows
df_final = df.join(broadcast(df_small), "id", "inner").persist(StorageLevel.MEMORY_AND_DISK)
else:
df_final = df.join(df_small, "id", "inner").repartition("category").persist(StorageLevel.MEMORY_AND_DISK)
df_final.count() # Forces execution before writing
df_final.write.mode("overwrite").parquet("optimized_output.parquet")
# Closing Spark Session
spark.stop()