import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
from pyspark.sql.types import (StructType, StructField, IntegerType, StringType,
                               BooleanType, DoubleType, LongType)

spark = SparkSession.builder.enableHiveSupport().getOrCreate()
completed_steps = set()
df_registry = {} # Store DataFrames for reference
def log_execution(step_id, process_name, start_time, end_time, status, error_message=None, count_rows=None):
    """Logs execution details in the Hive log table"""
    log_data = [(step_id, process_name, start_time, end_time, status, error_message, count_rows)]
    # Reuse the etl_log table's schema (created by create_log_table before the pipeline runs)
    # so that None values in error_message/count_rows do not break schema inference
    log_df = spark.createDataFrame(log_data, schema=spark.table("etl_log").schema)
    log_df.write.mode("append").saveAsTable("etl_log")
def write_final_table(df, table_name, mode):
"""Writes the final table in overwrite, snapshot, or append mode"""
if mode == "overwrite":
df.write.mode("overwrite").saveAsTable(table_name)
elif mode == "snapshot":
timestamp = datetime.now().strftime("%Y%m%d%H%M")
snapshot_table = f"{table_name}_{timestamp}"
df.write.mode("overwrite").saveAsTable(snapshot_table)
spark.sql(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM {snapshot_table}")
elif mode == "append":
df.write.mode("append").saveAsTable(table_name)
def execute_step(row):
"""Executes a single ETL step with caching and persistence"""
    step_id, process_name, sql_query, depends_on, broadcast_flag, cache, persist, repartition, final_table, write_mode = row
start_time = time.time()
count_rows = None
try:
# Wait for dependencies to complete
if depends_on:
            dependencies = [dep.strip() for dep in depends_on.split(",")]
while not all(dep in completed_steps for dep in dependencies):
time.sleep(1) # Wait for dependencies to complete
df = spark.sql(sql_query)
# Optimizations
        if broadcast_flag:
            df = broadcast(df)  # Hint Spark to broadcast this DataFrame in later joins
if cache:
df = df.cache()
count_rows = df.count() # Force evaluation to persist cache
if persist:
df = df.persist()
count_rows = df.count() # Force evaluation to persist dataframe
if repartition:
df = df.repartition(repartition)
        df.createOrReplaceTempView(process_name)
        df_registry[process_name] = df
        if count_rows is None:
            count_rows = df.count()  # Count only if not cached or persisted
        # Write to final table if applicable
        if final_table:
            write_final_table(df, final_table, write_mode)
        # Mark the step as complete only after its output is available to dependent steps
        completed_steps.add(process_name)
        log_execution(step_id, process_name, start_time, time.time(), "SUCCESS", count_rows=count_rows)
return f"Step {process_name} executed successfully."
except Exception as e:
log_execution(step_id, process_name, start_time, time.time(), "FAILED", str(e))
return f"Step {process_name} failed. Error: {str(e)}"
def get_metadata():
"""Fetch metadata from Hive table"""
return [(row.step_id, row.process_name, row.sql_query, row.depends_on, row.broadcast, row.cache, row.persist, row.repartition, row.final_table, row.write_mode)
for row in spark.table("etl_metadata").collect()]
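# For ad-hoc debugging, a single metadata row can be run outside the thread pool, e.g.:
# execute_step(get_metadata()[0])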
def execute_etl():
"""Execute ETL steps with parallel execution"""
    metadata = sorted(get_metadata(), key=lambda x: x[0])  # Order by step_id
with ThreadPoolExecutor() as executor:
results = executor.map(execute_step, metadata)
for result in results:
print(result)
# Create Metadata Table Using DataFrame with Sample Data
def create_metadata_table():
    # The sample rows hold None in every repartition value, so the schema must be explicit
    metadata_schema = StructType([
        StructField("step_id", IntegerType()), StructField("process_name", StringType()),
        StructField("sql_query", StringType()), StructField("depends_on", StringType()),
        StructField("broadcast", BooleanType()), StructField("cache", BooleanType()),
        StructField("persist", BooleanType()), StructField("repartition", IntegerType()),
        StructField("final_table", StringType()), StructField("write_mode", StringType())])
    metadata_data = [
        (1, "load_data", "SELECT * FROM source_table", None, False, False, False, None, "processed_data", "overwrite"),
        (2, "aggregate_data", "SELECT col1, COUNT(*) FROM processed_data GROUP BY col1", "load_data", True, True, False, None, "aggregated_data", "snapshot"),
        (3, "filter_data", "SELECT * FROM aggregated_data WHERE col1 > 10", "aggregate_data", False, True, False, None, "final_output", "append")
    ]
    metadata_df = spark.createDataFrame(metadata_data, metadata_schema)
    metadata_df.write.mode("overwrite").saveAsTable("etl_metadata")
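# Additional steps can be registered later without code changes, e.g. (hypothetical fourth step):
# spark.sql("""INSERT INTO etl_metadata VALUES
#     (4, 'dedupe_data', 'SELECT DISTINCT * FROM final_output', 'filter_data',
#      false, false, false, NULL, 'deduped_output', 'overwrite')""")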
# Create Log Table Using DataFrame
def create_log_table():
    # An empty dataset cannot be used for schema inference, so define the schema explicitly
    log_schema = StructType([
        StructField("step_id", IntegerType()), StructField("process_name", StringType()),
        StructField("start_time", DoubleType()), StructField("end_time", DoubleType()),
        StructField("status", StringType()), StructField("error_message", StringType()),
        StructField("count_rows", LongType())])
    log_df = spark.createDataFrame([], log_schema)  # Empty initial structure
    log_df.write.mode("overwrite").saveAsTable("etl_log")
# Version Control Metadata Table
def version_metadata_table():
timestamp = datetime.now().strftime("%Y%m%d%H%M")
snapshot_table = f"etl_metadata_{timestamp}"
spark.sql(f"CREATE TABLE {snapshot_table} AS SELECT * FROM etl_metadata")
print(f"Metadata snapshot created: {snapshot_table}")
create_metadata_table()
create_log_table()
version_metadata_table()
execute_etl()
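# Optionally inspect the run history once the pipeline has finished, e.g.:
# spark.table("etl_log").orderBy("start_time").show(truncate=False)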