📌 How to Decide Which Table to Broadcast in PySpark?
Broadcasting a table ships a full copy of it to every executor, so the join runs locally on each node and the large table never has to be shuffled across the cluster. You can confirm which strategy Spark actually chose with explain(), as sketched after the rule list below.
💡 General Rule:
📌 Broadcast the smaller table if:
✔ It is small enough to fit comfortably in executor and driver memory (Spark refuses to broadcast tables larger than 8GB; in practice a few hundred MB is a safer ceiling)
✔ Used in a JOIN operation
✔ Frequently accessed in multiple queries
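A quick way to make that check concrete is to look at the physical plan: a broadcast join appears as BroadcastHashJoin, while a shuffle-based join appears as SortMergeJoin. A minimal sketch, assuming two hypothetical Parquet inputs joined on an id column:
from pyspark.sql.functions import broadcast

df_large = spark.read.parquet("large_data.parquet")   # hypothetical input paths
df_small = spark.read.parquet("small_data.parquet")
df_large.join(broadcast(df_small), "id").explain()    # look for BroadcastHashJoin in the plan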
🔍 1️⃣ Automatically Decide Which Table to Broadcast
Instead of manually checking table size, PySpark can auto-broadcast using:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "524288000") # 500MB
- If a table is smaller than this limit, Spark automatically broadcasts it.
- The default is 10MB; raise it only if the smaller table is bigger than that but still fits comfortably in memory.
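You can also read the active threshold back at runtime, which helps confirm which limit is actually in force before relying on auto-broadcast; a minimal sketch:
print(spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))                 # current limit, in bytes
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 500 * 1024 * 1024)     # raise to 500MB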
⚙ 2️⃣ Manually Broadcast a Table in PySpark SQL API
If Spark does not auto-broadcast, use an explicit hint:
query = """
SELECT /*+ BROADCAST(small_table) */ a.*, b.extra_info
FROM large_table a
JOIN small_table b
ON a.id = b.id
"""
df = spark.sql(query)
✔ Forces PySpark to broadcast small_table
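Note that the hint only works if both names resolve in the SQL catalog; when the inputs are plain DataFrames, register them as temporary views first. A minimal sketch, reusing the hypothetical Parquet paths from the next section:
spark.read.parquet("large_data.parquet").createOrReplaceTempView("large_table")
spark.read.parquet("small_data.parquet").createOrReplaceTempView("small_table")
df = spark.sql(query)   # the hinted query above now resolves both table names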
⚡ 3️⃣ Using DataFrame API: broadcast()
If using DataFrame API, manually broadcast:
from pyspark.sql.functions import broadcast
df_large = spark.read.parquet("large_data.parquet")
df_small = spark.read.parquet("small_data.parquet")
df_final = df_large.join(broadcast(df_small), "id", "inner")
✔ Reduces shuffle
✔ Faster execution
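An equivalent spelling of the same intent is the DataFrame hint() method, which is convenient when the small side is produced by a chain of transformations; a minimal sketch:
df_final = df_large.join(df_small.hint("broadcast"), "id", "inner")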
🛠 4️⃣ How to Automatically Detect & Broadcast the Right Table?
Dynamically check size and decide:
from pyspark.sql.functions import broadcast

def should_broadcast(df, threshold=500 * 1024 * 1024):  # 500MB
    """Rough estimate: sum the on-disk size of the DataFrame's input files (classic PySpark, via the JVM gateway)."""
    try:
        jvm = spark._jvm
        hadoop_conf = spark._jsc.hadoopConfiguration()
        total_bytes = 0
        for file_path in df.inputFiles():
            hdfs_path = jvm.org.apache.hadoop.fs.Path(file_path)
            total_bytes += hdfs_path.getFileSystem(hadoop_conf).getFileStatus(hdfs_path).getLen()
        return total_bytes < threshold
    except Exception:  # e.g. sources that expose no input files
        return False
# Load Data
df_large = spark.read.parquet("large_data.parquet")
df_small = spark.read.parquet("small_data.parquet")
# Decide to Broadcast
if should_broadcast(df_small):
    df_final = df_large.join(broadcast(df_small), "id", "inner")
else:
    df_final = df_large.join(df_small, "id", "inner")
✔ Automates broadcasting decision
✔ Prevents unnecessary memory use
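As a cross-check on any homegrown estimate, Spark 3.x can print the optimizer's own size statistics per plan node; a minimal sketch:
df_small.explain(mode="cost")   # prints Catalyst's sizeInBytes estimate (Spark 3.0+)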
🔧 5️⃣ Memory Adjustment for Broadcast Joins
Since broadcasting loads the entire small table into executor (and driver) memory, the following settings matter. Note that memory sizes are fixed at JVM startup, so set them when the application launches (via spark-submit or the session builder, as sketched after this list) rather than with spark.conf.set on a running session:
🔹 Increase memory for executors and the driver:
--conf spark.executor.memory=4g   # default: 1g
--conf spark.driver.memory=2g     # default: 1g
🔹 Increase the unified memory fraction (execution + storage):
--conf spark.memory.fraction=0.8   # default: 0.6
- Gives execution and storage 80% of the JVM heap (minus a ~300MB reserve); this is also a launch-time setting.
🔹 Adjust the broadcast threshold (a runtime SQL setting, safe to change with spark.conf.set):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # disable auto-broadcasting
- Use -1 when you prefer full manual control and only broadcast via explicit hints.
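Because executor and driver memory are fixed at JVM startup, those two settings belong in spark-submit or the session builder rather than in spark.conf.set on a running session. A minimal builder sketch (the app name is hypothetical):
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("broadcast_join_tuning")    # hypothetical app name
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "2g")
    .config("spark.memory.fraction", "0.8")
    .getOrCreate()
)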
🚀 Summary: Best Practices for Broadcast Joins
✅ Use broadcast(df) in the DataFrame API
✅ Use /*+ BROADCAST(table) */ in the SQL API
✅ Adjust spark.sql.autoBroadcastJoinThreshold for automatic control
✅ Tune spark.executor.memory for large tables
✅ Automate the broadcast decision using Python logic
# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # 500MB auto-broadcast threshold
    .config("spark.executor.memory", "4g")    # increase executor memory
    .config("spark.driver.memory", "2g")      # increase driver memory
    .config("spark.memory.fraction", "0.8")   # allocate more memory for computation
    .getOrCreate()
)
# Sample Data (Creating a DataFrame)
data = [
    (1, "A", "active", 100),
    (2, "B", "inactive", 200),
    (3, "A", "active", 150),
    (4, "C", "active", 120),
    (5, "B", "inactive", 300),
]
columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)
# Approach 1: Using Temp Views for Step-by-Step ETL
df.createOrReplaceTempView("source_data")
# Step 1: Filter Active Records
filtered_query = """
SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query)
filtered_df.createOrReplaceTempView("filtered_data")
# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")
# Step 2: Aggregation
aggregated_query = """
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query)
aggregated_df.show()
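# Release the cached view once the aggregation that needed it has run
spark.sql("UNCACHE TABLE filtered_data")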
# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query)
cte_df.show()
# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query)
ranked_df.show()
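# For comparison, a sketch of the same "top amount per category" logic in the
# DataFrame API (column names assumed to match the CTE above)
from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank

category_window = Window.partitionBy("category").orderBy(col("amount").desc())
top_per_category = (
    df.filter(col("status") == "active")
      .withColumn("rank", rank().over(category_window))
      .filter(col("rank") == 1)
)
top_per_category.show()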
# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)
# Decide whether to broadcast; here a simple row-count heuristic is used
# (note: count() triggers a Spark job, and Spark's own check is based on estimated bytes)
if df_small.count() < 1000000:  # example: broadcast if fewer than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner")
else:
    df_final = df.join(df_small, "id", "inner")
df_final.show()
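# Sanity check: the physical plan should list BroadcastHashJoin when the broadcast was applied
df_final.explain()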
# Closing Spark Session
spark.stop()