ETL Framework for Dynamic PySpark SQL API Code Execution

The framework below is metadata-driven: each ETL step is defined as a row in a Hive table (etl_metadata), executed as dynamic Spark SQL with optional broadcast, cache, persist, and repartition handling, and logged to a Hive audit table (etl_log). Final outputs can be written in overwrite, snapshot, or append mode.

import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

from pyspark.sql import SparkSession
from pyspark.sql.types import (StructType, StructField, StringType,
                               IntegerType, BooleanType, DoubleType, LongType)

spark = SparkSession.builder.enableHiveSupport().getOrCreate()
completed_steps = set()  # process_names of steps that finished successfully
failed_steps = set()     # process_names of steps that raised an error
df_registry = {}         # Store DataFrames for reference by later steps

# Explicit schema for the log table; rows containing None values would
# otherwise break schema inference in createDataFrame
LOG_SCHEMA = StructType([
    StructField("step_id", IntegerType()),
    StructField("process_name", StringType()),
    StructField("start_time", DoubleType()),
    StructField("end_time", DoubleType()),
    StructField("status", StringType()),
    StructField("error_message", StringType()),
    StructField("count_rows", LongType()),
])

def log_execution(step_id, process_name, start_time, end_time, status, error_message=None, count_rows=None):
    """Logs execution details to the Hive log table"""
    log_data = [(step_id, process_name, float(start_time), float(end_time), status, error_message, count_rows)]
    log_df = spark.createDataFrame(log_data, LOG_SCHEMA)
    log_df.write.mode("append").saveAsTable("etl_log")

def write_final_table(df, table_name, mode):
    """Writes the final output in overwrite, snapshot, or append mode"""
    if mode == "overwrite":
        df.write.mode("overwrite").saveAsTable(table_name)
    elif mode == "snapshot":
        # Write a timestamped snapshot table and repoint a view at it;
        # table_name is therefore expected to be a view in snapshot mode
        timestamp = datetime.now().strftime("%Y%m%d%H%M")
        snapshot_table = f"{table_name}_{timestamp}"
        df.write.mode("overwrite").saveAsTable(snapshot_table)
        spark.sql(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM {snapshot_table}")
    elif mode == "append":
        df.write.mode("append").saveAsTable(table_name)
    else:
        raise ValueError(f"Unsupported write mode: {mode}")

def execute_step(row):
    """Executes a single ETL step with caching and persistence"""
    step_id, process_name, sql_query, depends_on, broadcast, cache, persist, repartition, final_table, write_mode = row
    start_time = time.time()
    count_rows = None

    try:
        # Wait for dependencies to complete; abort if any dependency failed
        if depends_on:
            dependencies = [dep.strip() for dep in depends_on.split(",")]
            while not all(dep in completed_steps for dep in dependencies):
                if any(dep in failed_steps for dep in dependencies):
                    raise RuntimeError(f"A dependency of step {process_name} failed")
                time.sleep(1)

        df = spark.sql(sql_query)

        # Optional optimizations driven by the metadata flags
        if broadcast:
            df = df.hint("broadcast")  # Mark the result for broadcast in downstream joins
        if cache:
            df = df.cache()
            count_rows = df.count()  # Force evaluation to materialize the cache
        if persist:
            df = df.persist()
            count_rows = df.count()  # Force evaluation to materialize the persisted DataFrame
        if repartition:
            df = df.repartition(int(repartition))

        # Expose the result to later steps by process_name
        df.createOrReplaceTempView(process_name)
        df_registry[process_name] = df

        if count_rows is None:
            count_rows = df.count()  # Count only if not already forced by cache/persist

        # Write to the final table if applicable
        if final_table:
            write_final_table(df, final_table, write_mode)

        # Mark the step complete only after its outputs (temp view and final
        # table) are visible to dependent steps
        completed_steps.add(process_name)

        log_execution(step_id, process_name, start_time, time.time(), "SUCCESS", count_rows=count_rows)
        return f"Step {process_name} executed successfully."
    except Exception as e:
        failed_steps.add(process_name)  # Let dependent steps abort instead of waiting forever
        log_execution(step_id, process_name, start_time, time.time(), "FAILED", error_message=str(e))
        return f"Step {process_name} failed. Error: {str(e)}"

def get_metadata():
    """Fetch metadata from Hive table"""
    return [(row.step_id, row.process_name, row.sql_query, row.depends_on, row.broadcast, row.cache, row.persist, row.repartition, row.final_table, row.write_mode)
            for row in spark.table("etl_metadata").collect()]

def execute_etl():
    """Execute ETL steps with parallel execution"""
    metadata = sorted(get_metadata(), key=lambda x: x[0])  # Order by step_id

    # One worker per step so a step waiting on its dependencies cannot starve
    # the steps it is waiting for
    with ThreadPoolExecutor(max_workers=max(len(metadata), 1)) as executor:
        results = executor.map(execute_step, metadata)
        for result in results:
            print(result)

# Create Metadata Table Using a DataFrame with Sample Data
def create_metadata_table():
    # An explicit schema is required because the repartition column is None in
    # every sample row, which would break schema inference
    metadata_schema = StructType([
        StructField("step_id", IntegerType()), StructField("process_name", StringType()),
        StructField("sql_query", StringType()), StructField("depends_on", StringType()),
        StructField("broadcast", BooleanType()), StructField("cache", BooleanType()),
        StructField("persist", BooleanType()), StructField("repartition", IntegerType()),
        StructField("final_table", StringType()), StructField("write_mode", StringType()),
    ])
    metadata_data = [
        (1, "load_data", "SELECT * FROM source_table", None, False, False, False, None, "processed_data", "overwrite"),
        (2, "aggregate_data", "SELECT col1, COUNT(*) AS cnt FROM processed_data GROUP BY col1", "load_data", True, True, False, None, "aggregated_data", "snapshot"),
        (3, "filter_data", "SELECT * FROM aggregated_data WHERE col1 > 10", "aggregate_data", False, True, False, None, "final_output", "append")
    ]
    metadata_df = spark.createDataFrame(metadata_data, metadata_schema)
    metadata_df.write.mode("overwrite").saveAsTable("etl_metadata")
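
# Sketch (hypothetical step): further steps are added purely through metadata,
# e.g. appending a row such as
#   (4, "top_rows", "SELECT * FROM final_output LIMIT 100", "filter_data",
#    False, False, False, None, None, None)
# to etl_metadata with the same schema as above; no framework code changes needed.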

# Create Log Table Using a DataFrame
def create_log_table():
    log_df = spark.createDataFrame([], LOG_SCHEMA)  # Empty table with the explicit log schema
    log_df.write.mode("overwrite").saveAsTable("etl_log")

# Version Control Metadata Table
def version_metadata_table():
    timestamp = datetime.now().strftime("%Y%m%d%H%M")
    snapshot_table = f"etl_metadata_{timestamp}"
    spark.sql(f"CREATE TABLE {snapshot_table} AS SELECT * FROM etl_metadata")
    print(f"Metadata snapshot created: {snapshot_table}")

create_metadata_table()
create_log_table()
version_metadata_table()
execute_etl()
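
# A quick post-run sanity check (not part of the framework itself): inspect the
# execution log, assuming the run above completed and etl_log was populated
spark.table("etl_log").orderBy("step_id", "start_time").show(truncate=False)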
