ETL Framework for Dynamic PySpark SQL API Code Execution

This framework covers:

Parallel Execution Handling using depends_on and execution_order
Metadata Generation for dynamic ETL execution
Optimizations: Broadcast Join, Caching, Repartitioning
Error Handling & Logging
Hive Table for Metadata & Logging
Version Control on Metadata Table
Row Count & Table Size in Log Table
Retry Logic for Failed Steps

This framework reads metadata from Hive, executes steps sequentially or in parallel, applies optimizations, and logs execution details.
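
As a small illustration of the scheduling idea, the sketch below groups metadata rows by execution_order and runs each group in a thread pool. It assumes rows are dicts with an execution_order key and a hypothetical run_step(row) callable, and that steps sharing an execution_order value do not depend on each other; the full framework further down waits on depends_on instead.

from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def run_in_levels(metadata_rows, run_step):
    """Sketch: execute steps level by level, parallelising within a level."""
    levels = defaultdict(list)
    for row in metadata_rows:
        levels[row["execution_order"]].append(row)

    for order in sorted(levels):
        # Steps with the same execution_order are assumed independent,
        # so they run concurrently; the levels themselves run in sequence.
        with ThreadPoolExecutor() as executor:
            list(executor.map(run_step, levels[order]))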


📌 1. Metadata Table Schema (Stored in Hive)

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BooleanType, LongType

spark = SparkSession.builder.enableHiveSupport().getOrCreate()

metadata_schema = StructType([
    StructField("step_id", IntegerType(), False),
    StructField("process_name", StringType(), False),
    StructField("sql_query", StringType(), False),
    StructField("depends_on", StringType(), True),
    StructField("execution_order", IntegerType(), False),
    StructField("broadcast", BooleanType(), True),  # For optimization
    StructField("cache", BooleanType(), True),      # For caching
    StructField("repartition", IntegerType(), True) # Repartitioning
])

metadata_data = [
    (1, "extract_data", "SELECT * FROM source_table", None, 1, False, False, None),
    (2, "clean_data", "SELECT * FROM extract_data WHERE status = 'active'", "extract_data", 2, False, True, 4),
    (3, "aggregate_data", "SELECT category, SUM(amount) FROM clean_data GROUP BY category", "clean_data", 3, False, True, None),
    (4, "customer_stats", "SELECT customer_id, COUNT(*) FROM clean_data GROUP BY customer_id", "clean_data", 3, False, False, None),
    (5, "enrich_data", "SELECT a.*, b.extra_info FROM aggregate_data a JOIN lookup_table b ON a.category = b.category", "aggregate_data", 4, True, True, None),
    (6, "final_output", "SELECT * FROM enrich_data WHERE total > 500", "enrich_data", 5, False, False, None),
    (7, "report_data", "SELECT * FROM customer_stats WHERE count > 10", "customer_stats", 5, False, False, None)
]

metadata_df = spark.createDataFrame(metadata_data, schema=metadata_schema)
metadata_df.write.mode("overwrite").saveAsTable("etl_metadata")
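
A quick read-back confirms the metadata landed in Hive and shows the planned ordering:

# Inspect the planned execution order of the registered steps.
spark.sql("""
    SELECT step_id, process_name, depends_on, execution_order
    FROM etl_metadata
    ORDER BY execution_order, step_id
""").show(truncate=False)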

📌 2. Log Table Schema (Stored in Hive)

log_schema = StructType([
    StructField("step_id", IntegerType(), False),
    StructField("process_name", StringType(), False),
    StructField("start_time", StringType(), False),
    StructField("end_time", StringType(), False),
    StructField("status", StringType(), False),
    StructField("error_message", StringType(), True),
    StructField("count_rows", IntegerType(), True),
    StructField("table_size", IntegerType(), True)
])

log_df = spark.createDataFrame([], log_schema)
log_df.write.mode("overwrite").saveAsTable("etl_log")
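
Once runs have been logged, the same table supports simple monitoring queries, for example pulling every failed step:

# List failed steps recorded by the framework.
spark.sql("""
    SELECT step_id, process_name, status, error_message, count_rows
    FROM etl_log
    WHERE status = 'FAILED'
    ORDER BY step_id
""").show(truncate=False)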

📌 3. ETL Execution Framework (With Parallel Execution & Optimizations)

import time
from pyspark.sql.functions import col
from concurrent.futures import ThreadPoolExecutor

spark = SparkSession.builder.enableHiveSupport().getOrCreate()
completed_steps = set()
df_registry = {}  # Store DataFrames for reference

def log_execution(step_id, process_name, start_time, end_time, status, error_message=None, df=None):
    """Logs execution details in the Hive log table"""
    count_rows = df.count() if df is not None else None
    # Approximate size via pandas; this collects the result to the driver,
    # so it is only suitable for reasonably small outputs.
    table_size = int(df.toPandas().memory_usage(deep=True).sum()) if df is not None else None

    log_data = [(step_id, process_name,
                 time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time)),
                 time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end_time)),
                 status, error_message, count_rows, table_size)]
    log_df = spark.createDataFrame(log_data, log_schema)
    log_df.write.mode("append").saveAsTable("etl_log")

def execute_step(row):
    """Executes a single ETL step"""
    step_id, process_name, sql_query, depends_on, execution_order, broadcast, cache, repartition = row
    start_time = time.time()

    try:
        # Wait for dependencies to complete
        if depends_on:
            dependencies = [dep.strip() for dep in depends_on.split(",")]
            while not all(dep in completed_steps for dep in dependencies):
                time.sleep(1)  # Wait for dependencies to complete

        df = spark.sql(sql_query)

        # Optimizations
        if broadcast:
            df = df.hint("broadcast")  # mark this result for broadcasting in downstream joins
        if cache:
            df = df.cache()
        if repartition:
            df = df.repartition(repartition)

        df.createOrReplaceTempView(process_name)
        df_registry[process_name] = df
        completed_steps.add(process_name)

        log_execution(step_id, process_name, start_time, time.time(), "SUCCESS", df=df)
        return f"Step {process_name} executed successfully."
    except Exception as e:
        log_execution(step_id, process_name, start_time, time.time(), "FAILED", str(e))
        return f"Step {process_name} failed. Error: {str(e)}"

def get_metadata():
    """Fetch metadata from Hive table"""
    return [(row.step_id, row.process_name, row.sql_query, row.depends_on,
             row.execution_order, row.broadcast, row.cache, row.repartition)
            for row in spark.sql("SELECT * FROM etl_metadata").collect()]

def execute_etl():
    """Execute ETL steps with parallel execution"""
    metadata = sorted(get_metadata(), key=lambda x: x[4])  # Order by execution_order

    # One worker per step so a step waiting on its dependencies never blocks
    # a dependency from being scheduled.
    with ThreadPoolExecutor(max_workers=len(metadata)) as executor:
        results = executor.map(execute_step, metadata)
        for result in results:
            print(result)

execute_etl()
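
The broadcast flag above hints the step's own result for downstream joins. Broadcasting the small side of a specific join can also be expressed directly with the DataFrame API; the sketch below rebuilds the enrich_data step that way, assuming the run above has registered the aggregate_data temp view and that lookup_table is a small dimension table.

from pyspark.sql.functions import broadcast, col

# Explicit broadcast join for the enrich_data step: the (assumed small)
# lookup_table side is shipped to every executor.
aggregate_df = spark.table("aggregate_data")   # temp view registered by the run above
lookup_df = spark.table("lookup_table")        # small dimension table
enriched_df = (
    aggregate_df.alias("a")
    .join(broadcast(lookup_df.alias("b")), col("a.category") == col("b.category"))
    .select("a.*", "b.extra_info")
)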

📌 4. Version Control for Metadata Table

spark.sql("ALTER TABLE etl_metadata ADD COLUMNS (version INT DEFAULT 1)")

from pyspark.sql.functions import lit

def update_metadata(new_data):
    """Appends new metadata rows with an incremented version number"""
    existing_df = spark.sql("SELECT * FROM etl_metadata")
    max_version = existing_df.selectExpr("MAX(version)").collect()[0][0] or 1

    # metadata_schema has no version column, so attach it explicitly here.
    new_df = spark.createDataFrame(new_data, metadata_schema).withColumn("version", lit(max_version + 1))
    new_df.write.mode("append").saveAsTable("etl_metadata")

update_metadata([(8, "new_step", "SELECT * FROM final_output", "final_output", 6, False, True, None)])
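
To drive a run from the newest definitions only, the latest version can be selected before execution (rows written before the version column existed read back as NULL and are excluded here):

# Fetch only the most recent metadata version.
latest_metadata_df = spark.sql("""
    SELECT * FROM etl_metadata
    WHERE version = (SELECT MAX(version) FROM etl_metadata)
""")
latest_metadata_df.show(truncate=False)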

📌 Enhancements in This Version

Parallel Execution Based on depends_on
Optimizations (Broadcast, Caching, Repartitioning)
Error Handling & Logging in Hive (etl_log Table)
Version Control for Metadata (version column)
Dynamic Row Count & Table Size Logging
Retry Mechanism (Re-execution of Failed Steps; a sketch follows the code below)

The enhanced framework below adds persist support; it assumes the metadata table carries an additional persist (BooleanType) column alongside cache.

import time
from pyspark.sql.functions import col
from concurrent.futures import ThreadPoolExecutor

spark = SparkSession.builder.enableHiveSupport().getOrCreate()
completed_steps = set()
df_registry = {}  # Store DataFrames for reference

def log_execution(step_id, process_name, start_time, end_time, status, error_message=None, count_rows=None):
    """Logs execution details in the Hive log table"""
    # table_size is not tracked in this version, so it is logged as NULL
    # to keep the row compatible with the etl_log schema.
    log_data = [(step_id, process_name,
                 time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time)),
                 time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end_time)),
                 status, error_message, count_rows, None)]
    log_df = spark.createDataFrame(log_data, log_schema)
    log_df.write.mode("append").saveAsTable("etl_log")

def execute_step(row):
    """Executes a single ETL step with caching and persistence"""
    # Assumes the metadata table also carries a persist (BooleanType) column.
    step_id, process_name, sql_query, depends_on, execution_order, broadcast, cache, persist, repartition = row
    start_time = time.time()
    count_rows = None

    try:
        # Wait for dependencies to complete
        if depends_on:
            dependencies = [dep.strip() for dep in depends_on.split(",")]
            while not all(dep in completed_steps for dep in dependencies):
                time.sleep(1)  # Wait for dependencies to complete

        df = spark.sql(sql_query)

        # Optimizations
        if broadcast:
            df = df.hint("broadcast")  # mark this result for broadcasting in downstream joins
        if cache:
            df = df.cache()
            count_rows = df.count()  # Force evaluation to persist cache
        if persist:
            df = df.persist()
            count_rows = df.count()  # Force evaluation to persist dataframe
        if repartition:
            df = df.repartition(repartition)

        df.createOrReplaceTempView(process_name)
        df_registry[process_name] = df
        completed_steps.add(process_name)

        if count_rows is None:
            count_rows = df.count()  # Count only if not cached or persisted

        log_execution(step_id, process_name, start_time, time.time(), "SUCCESS", count_rows=count_rows)
        return f"Step {process_name} executed successfully."
    except Exception as e:
        log_execution(step_id, process_name, start_time, time.time(), "FAILED", str(e))
        return f"Step {process_name} failed. Error: {str(e)}"

def get_metadata():
    """Fetch metadata from Hive table (including execution_order and persist)"""
    return [(row.step_id, row.process_name, row.sql_query, row.depends_on, row.execution_order,
             row.broadcast, row.cache, row.persist, row.repartition)
            for row in spark.sql("SELECT * FROM etl_metadata").collect()]

def execute_etl():
    """Execute ETL steps with parallel execution"""
    metadata = sorted(get_metadata(), key=lambda x: x[4])  # Order by execution_order

    # One worker per step so a step waiting on its dependencies never blocks
    # a dependency from being scheduled.
    with ThreadPoolExecutor(max_workers=len(metadata)) as executor:
        results = executor.map(execute_step, metadata)
        for result in results:
            print(result)

execute_etl()
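
The retry mechanism listed among the enhancements is not shown on this page. A minimal sketch of one possible approach is below: it wraps execute_step with a fixed number of attempts (execute_step_with_retry, max_retries and delay_seconds are illustrative names, not part of the framework code above).

import time

def execute_step_with_retry(row, max_retries=3, delay_seconds=10):
    """Re-runs a failed step up to max_retries times before giving up."""
    for attempt in range(1, max_retries + 1):
        result = execute_step(row)
        if "executed successfully" in result:
            return result
        print(f"Attempt {attempt} failed for step {row[1]}; retrying in {delay_seconds}s...")
        time.sleep(delay_seconds)
    return f"Step {row[1]} failed after {max_retries} attempts."

Swapping execute_step for execute_step_with_retry in executor.map inside execute_etl applies the retry behaviour to every step.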