How would you find customers who have bought a product at least 3 times consecutively? Table name: Orders. Columns: CustomerID, ProductID, OrderDate. Please provide solutions in both Spark SQL and PySpark.
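The question is a bit vague. Most people will read it as 3 orders on consecutive days, but since that is not stated, I will first treat it as 3 consecutive orders, even if some of them happened on the same day.
We will use window functions, but which one?
The question asks for consecutive orders but puts no constraint on the dates. Several orders might fall on the same day, so the choice of ranking function matters. ROW_NUMBER gives every order its own position, so three orders on the same day count as three. DENSE_RANK gives same-day orders the same rank without leaving gaps, so it effectively counts distinct order dates. The solutions below use DENSE_RANK, which means they look for purchases on at least 3 distinct dates. Use ROW_NUMBER instead if every order row should count.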
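The following small plain-Python model (not Spark code) shows why the choice matters. It numbers the rows of one partition that is already ordered by date, the way each of the three ranking functions would. The helper name `rank_partition` and the dates are hypothetical. They were chosen only to show what happens when three orders fall on the same day.

```python
from typing import Dict, List


def rank_partition(order_dates: List[str]) -> Dict[str, List[int]]:
    """Model ROW_NUMBER, RANK and DENSE_RANK over one partition ordered by date.

    ISO 'YYYY-MM-DD' strings sort chronologically, so plain sorting is enough here.
    """
    ordered = sorted(order_dates)
    row_number: List[int] = []
    rank: List[int] = []
    dense_rank: List[int] = []
    for i, day in enumerate(ordered):
        row_number.append(i + 1)  # every row gets its own position
        if i > 0 and day == ordered[i - 1]:
            rank.append(rank[-1])  # ties share the rank...
            dense_rank.append(dense_rank[-1])  # ...in both RANK and DENSE_RANK
        else:
            rank.append(i + 1)  # RANK jumps past the tied rows
            dense_rank.append(dense_rank[-1] + 1 if dense_rank else 1)  # DENSE_RANK does not
    return {"row_number": row_number, "rank": rank, "dense_rank": dense_rank}


# Three orders on the same day, then one more a week later
ranks = rank_partition(["2022-01-01", "2022-01-01", "2022-01-01", "2022-01-08"])
print(ranks)
```

Here the highest DENSE_RANK is 2. A filter such as OrderRank >= 3 would therefore skip this customer even though there are four orders, while ROW_NUMBER reaches 4.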
Here are the solutions in Spark SQL and PySpark:
Sample Data
Python
data = [
(1, 101, "2022-01-01"),
(1, 101, "2022-01-15"),
(1, 101, "2022-02-01"),
(1, 102, "2022-03-01"),
(2, 101, "2022-01-01"),
(2, 101, "2022-01-20"),
(3, 101, "2022-01-01"),
(3, 101, "2022-01-10"),
(3, 101, "2022-01-20"),
(3, 101, "2022-02-01")
]
columns = ["CustomerID", "ProductID", "OrderDate"]
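df = spark.createDataFrame(data).toDF(*columns)
# Register the DataFrame so the Spark SQL query below can refer to it as Orders
df.createOrReplaceTempView("Orders")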
Spark SQL Solution
WITH RankedOrders AS (
SELECT CustomerID, ProductID, OrderDate,
DENSE_RANK() OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS OrderRank
FROM Orders
)
SELECT DISTINCT CustomerID
FROM RankedOrders
WHERE OrderRank >= 3
PySpark Solution
from pyspark.sql import Window
from pyspark.sql.functions import dense_rank, col
window = Window.partitionBy("CustomerID", "ProductID").orderBy("OrderDate")
ranked_orders = df.withColumn("OrderRank", dense_rank().over(window))
consecutive_customers = ranked_orders.filter(col("OrderRank") >= 3).select("CustomerID").distinct()
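In both solutions, a window function ranks each customer's orders of a given product by OrderDate. A rank of 3 or more means the customer bought that product on at least 3 distinct dates, so we keep the distinct CustomerIDs that reach it. With the sample data above, this returns customers 1 and 3. This approach does not check whether the customer bought other products in between. It only looks within each (CustomerID, ProductID) partition.
Note that the Spark SQL solution uses the DENSE_RANK function, while the PySpark solution uses the equivalent dense_rank function from the pyspark.sql.functions module. In the PySpark solution, the withColumn method adds the OrderRank column to the DataFrame.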
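As a quick sanity check, the query's logic can be replayed in plain Python on the same sample rows. This is a hypothetical model, not part of the Spark solution. It relies on the fact that the largest DENSE_RANK in a partition equals the number of distinct order dates in that partition.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

# Same sample rows as the PySpark DataFrame: (CustomerID, ProductID, OrderDate)
data: List[Tuple[int, int, str]] = [
    (1, 101, "2022-01-01"), (1, 101, "2022-01-15"), (1, 101, "2022-02-01"),
    (1, 102, "2022-03-01"), (2, 101, "2022-01-01"), (2, 101, "2022-01-20"),
    (3, 101, "2022-01-01"), (3, 101, "2022-01-10"), (3, 101, "2022-01-20"),
    (3, 101, "2022-02-01"),
]


def customers_with_rank_at_least(rows: List[Tuple[int, int, str]], threshold: int = 3) -> Set[int]:
    """Plain-Python model of DENSE_RANK() ... WHERE OrderRank >= threshold."""
    partitions: Dict[Tuple[int, int], List[str]] = defaultdict(list)
    for customer, product, order_date in rows:
        partitions[(customer, product)].append(order_date)  # PARTITION BY CustomerID, ProductID
    result: Set[int] = set()
    for (customer, _product), dates in partitions.items():
        # The maximum DENSE_RANK in a partition is its number of distinct order dates
        if len(set(dates)) >= threshold:
            result.add(customer)
    return result


print(sorted(customers_with_rank_at_least(data)))
```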
A more interesting version of the problem arises if the 3 orders must fall on consecutive days. Here is a sample table for that case:
CREATE TABLE Orders (
CustomerID INT,
ProductID INT,
OrderDate DATE
);
INSERT INTO Orders (CustomerID, ProductID, OrderDate)
VALUES
(1, 101, '2022-01-01'),
(1, 101, '2022-01-02'),
(1, 101, '2022-01-03'),
(1, 102, '2022-01-10'),
(2, 101, '2022-01-01'),
(2, 101, '2022-01-05'),
(3, 101, '2022-01-01'),
(3, 101, '2022-01-02'),
(3, 101, '2022-01-03'),
(3, 101, '2022-01-04');
WITH OrderedOrders AS (
SELECT CustomerID, ProductID, OrderDate,
DENSE_RANK() OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS DenseRank,
LAG(OrderDate, 1) OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS PrevOrderDate,
LAG(OrderDate, 2) OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS PrevPrevOrderDate
FROM Orders
)
SELECT *
FROM OrderedOrders
ORDER BY CustomerID, ProductID, OrderDate;
Result
CustomerID | ProductID | OrderDate | DenseRank | PrevOrderDate | PrevPrevOrderDate |
---|---|---|---|---|---|
1 | 101 | 2022-01-01 | 1 | NULL | NULL |
1 | 101 | 2022-01-02 | 2 | 2022-01-01 | NULL |
1 | 101 | 2022-01-03 | 3 | 2022-01-02 | 2022-01-01 |
1 | 102 | 2022-01-10 | 1 | NULL | NULL |
2 | 101 | 2022-01-01 | 1 | NULL | NULL |
2 | 101 | 2022-01-05 | 2 | 2022-01-01 | NULL |
3 | 101 | 2022-01-01 | 1 | NULL | NULL |
3 | 101 | 2022-01-02 | 2 | 2022-01-01 | NULL |
3 | 101 | 2022-01-03 | 3 | 2022-01-02 | 2022-01-01 |
3 | 101 | 2022-01-04 | 4 | 2022-01-03 | 2022-01-02 |
Note that:
- The DENSE_RANK function assigns a rank to each row within a partition. Rows with the same ORDER BY value share a rank, and no rank numbers are skipped.
- The LAG function returns the value from a previous row within the partition: 1 row back by default, or N rows back with LAG(column, N). It returns NULL when there is no such row.
- The PARTITION BY clause divides the result set into partitions to which the function is applied.
- The ORDER BY clause specifies the order of the rows within each partition.
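The LAG columns in the result table can be reproduced with a short plain-Python model. The helper name `lag` is hypothetical. The list must already be sorted the way ORDER BY sorts the partition, and None stands in for SQL NULL.

```python
from typing import List, Optional, Sequence


def lag(values: Sequence[str], offset: int = 1) -> List[Optional[str]]:
    """Model LAG(value, offset) over one partition that is already in ORDER BY order.

    Rows with fewer than `offset` predecessors get None, just as LAG returns NULL.
    """
    return [values[i - offset] if i >= offset else None for i in range(len(values))]


# Customer 3, product 101 from the consecutive-days sample, sorted by OrderDate
dates = ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-04"]
for current, prev, prev_prev in zip(dates, lag(dates, 1), lag(dates, 2)):
    print(current, prev, prev_prev)
```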
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window
# Create a SparkSession
spark = SparkSession.builder.appName("Example").getOrCreate()
# Create a DataFrame
data = [
(1, 101, "2022-01-01"),
(1, 101, "2022-01-02"),
(1, 101, "2022-01-03"),
(1, 102, "2022-01-10"),
(2, 101, "2022-01-01"),
(2, 101, "2022-01-05"),
(3, 101, "2022-01-01"),
(3, 101, "2022-01-02"),
(3, 101, "2022-01-03"),
(3, 101, "2022-01-04")
]
columns = ["CustomerID", "ProductID", "OrderDate"]
df = spark.createDataFrame(data).toDF(*columns)
# Define the window
window = Window.partitionBy("CustomerID", "ProductID").orderBy("OrderDate")
# Calculate the dense rank, lag, and date differences
df = df.withColumn("DenseRank", F.dense_rank().over(window))
df = df.withColumn("PrevOrderDate", F.lag("OrderDate", 1).over(window))
df = df.withColumn("PrevPrevOrderDate", F.lag("OrderDate", 2).over(window))
df = df.withColumn("OrderDate_PrevOrderDate", F.datediff("OrderDate", F.col("PrevOrderDate")))
df = df.withColumn("OrderDate_PrevPrevOrderDate", F.datediff("OrderDate", F.col("PrevPrevOrderDate")))
# Show the results
df.show()
Here are the solutions for the consecutive-days version in Spark SQL and PySpark:
Spark SQL Solution
SQL
WITH DailyOrders AS (
-- One row per customer, product and day, so same-day duplicates cannot break a run
SELECT DISTINCT CustomerID, ProductID, OrderDate
FROM Orders
),
OrderedOrders AS (
SELECT CustomerID, ProductID, OrderDate,
LAG(OrderDate, 1) OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS PrevOrderDate,
LAG(OrderDate, 2) OVER (PARTITION BY CustomerID, ProductID ORDER BY OrderDate) AS PrevPrevOrderDate
FROM DailyOrders
)
SELECT DISTINCT CustomerID
FROM OrderedOrders
-- Previous order exactly 1 day earlier and the one before that exactly 2 days earlier
WHERE DATEDIFF(OrderDate, PrevOrderDate) = 1
AND DATEDIFF(OrderDate, PrevPrevOrderDate) = 2
PySpark Solution
from pyspark.sql import Window
from pyspark.sql.functions import col, datediff, lag, to_date
# One row per customer, product and day, with OrderDate converted to a DATE
daily_orders = df.select("CustomerID", "ProductID", to_date("OrderDate").alias("OrderDate")).distinct()
window = Window.partitionBy("CustomerID", "ProductID").orderBy("OrderDate")
ordered_orders = daily_orders.withColumn("PrevOrderDate", lag("OrderDate", 1).over(window)) \
    .withColumn("PrevPrevOrderDate", lag("OrderDate", 2).over(window))
# Previous order exactly 1 day earlier and the one before that exactly 2 days earlier
consecutive_orders = ordered_orders.filter(
    (datediff(col("OrderDate"), col("PrevOrderDate")) == 1)
    & (datediff(col("OrderDate"), col("PrevPrevOrderDate")) == 2)
)
consecutive_customers = consecutive_orders.select("CustomerID").distinct()
consecutive_customers.show()
In both solutions, the data is first reduced to one row per customer, product and day. We then use the LAG window function to look 1 and 2 rows back within each (CustomerID, ProductID) partition, ordered by OrderDate. A row qualifies when the previous order was exactly 1 day earlier and the one before that exactly 2 days earlier. That means the customer ordered that product on 3 consecutive days. With the sample data above, this returns customers 1 and 3.
Note that the Spark SQL solution uses the LAG and DATEDIFF functions, while the PySpark solution uses the equivalent lag and datediff functions from the pyspark.sql.functions module. Because datediff returns a whole number of days, the differences can be compared with 1 and 2 directly.
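The same rule can be checked in plain Python with the standard datetime module. This sketch is a hypothetical model of the Spark logic, not a replacement for it, and the function name is made up for illustration. Subtracting two `date` objects gives a `timedelta`, whose `.days` plays the role of DATEDIFF.

```python
from collections import defaultdict
from datetime import date
from typing import Dict, List, Set, Tuple

# Consecutive-days sample rows: (CustomerID, ProductID, OrderDate)
orders: List[Tuple[int, int, str]] = [
    (1, 101, "2022-01-01"), (1, 101, "2022-01-02"), (1, 101, "2022-01-03"),
    (1, 102, "2022-01-10"), (2, 101, "2022-01-01"), (2, 101, "2022-01-05"),
    (3, 101, "2022-01-01"), (3, 101, "2022-01-02"), (3, 101, "2022-01-03"),
    (3, 101, "2022-01-04"),
]


def customers_with_three_consecutive_days(rows: List[Tuple[int, int, str]]) -> Set[int]:
    """Model the LAG/DATEDIFF rule: previous order 1 day back, the one before 2 days back."""
    partitions: Dict[Tuple[int, int], Set[date]] = defaultdict(set)
    for customer, product, order_date in rows:
        # Keep one entry per day so same-day duplicates cannot break a run
        partitions[(customer, product)].add(date.fromisoformat(order_date))
    result: Set[int] = set()
    for (customer, _product), days in partitions.items():
        ordered = sorted(days)
        for i in range(2, len(ordered)):
            # (ordered[i] - ordered[i - k]).days mirrors DATEDIFF(OrderDate, LAG(OrderDate, k))
            if (ordered[i] - ordered[i - 1]).days == 1 and (ordered[i] - ordered[i - 2]).days == 2:
                result.add(customer)
                break
    return result


print(sorted(customers_with_three_consecutive_days(orders)))
```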
How do you find employees whose salary is greater than the average salary of employees in their location?
Table name: Employee
Columns: EmpID (employee ID), Emp_name (employee name), Manager_id (manager ID), Salary (employee salary), Location (employee location)
Solve it using Spark SQL or PySpark.
Complete PySpark Solution with Sample Data
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.functions import col, avg
# Initialize Spark session
spark = SparkSession.builder.master("local").appName("EmployeeSalaryExample").getOrCreate()
# Sample data
data = [
(1, "Alice", 10, 8000, "New York"),
(2, "Bob", 11, 12000, "New York"),
(3, "Charlie", 10, 7000, "Chicago"),
(4, "David", 12, 11000, "New York"),
(5, "Eve", 13, 6000, "Chicago")
]
# Define schema
columns = ["EmpID", "Emp_name", "Manager_id", "Salary", "Location"]
# Create DataFrame
employee_df = spark.createDataFrame(data, columns)
# Define a window partitioned by location
location_window = Window.partitionBy("Location")
# Step 1: Calculate the average salary for each location
employee_with_avg_salary = employee_df.withColumn(
"avg_salary", avg("Salary").over(location_window)
)
# Step 2: Filter employees whose salary is greater than the average salary
result = employee_with_avg_salary.filter(col("Salary") > col("avg_salary"))
# Show the result
result.select("EmpID", "Emp_name", "Manager_id", "Salary", "Location", "avg_salary").show()
Explanation of the Process:
- Spark Session: The Spark session is initialized with spark = SparkSession.builder.master("local").appName("EmployeeSalaryExample").getOrCreate().
- Sample Data: A list of tuples is created, where each tuple represents an employee and contains:
  - EmpID: Employee ID
  - Emp_name: Employee Name
  - Manager_id: Manager ID
  - Salary: Employee Salary
  - Location: Employee Location
- Schema: The column names are defined, and the DataFrame employee_df is created from the sample data and column names.
- Window Function: The window is partitioned by Location. As a result, avg("Salary").over(location_window) computes the average salary of each location and attaches it to every row of that location.
- Filter: We keep only the employees whose salary is greater than the average salary for their location.
- Show Result: Finally, the result is displayed with the EmpID, Emp_name, Manager_id, Salary, Location, and avg_salary columns.
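The window logic can be modelled in a few lines of plain Python, which also makes it easy to check the numbers. New York averages (8000 + 12000 + 11000) / 3 ≈ 10333.33, and Chicago averages (7000 + 6000) / 2 = 6500. The helper name `above_location_average` is hypothetical. This is a sketch of what the window computes, not Spark code.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Same sample rows: (EmpID, Emp_name, Manager_id, Salary, Location)
employees: List[Tuple[int, str, int, int, str]] = [
    (1, "Alice", 10, 8000, "New York"),
    (2, "Bob", 11, 12000, "New York"),
    (3, "Charlie", 10, 7000, "Chicago"),
    (4, "David", 12, 11000, "New York"),
    (5, "Eve", 13, 6000, "Chicago"),
]


def above_location_average(rows: List[Tuple[int, str, int, int, str]]) -> List[Tuple[str, str, float]]:
    """Model AVG(Salary) OVER (PARTITION BY Location) followed by Salary > avg_salary."""
    salaries: Dict[str, List[int]] = defaultdict(list)
    for _emp_id, _name, _manager, salary, location in rows:
        salaries[location].append(salary)
    # One average per partition, attached back to every row of that partition
    avg_by_location = {loc: sum(vals) / len(vals) for loc, vals in salaries.items()}
    return [
        (name, location, avg_by_location[location])
        for _emp_id, name, _manager, salary, location in rows
        if salary > avg_by_location[location]
    ]


for name, location, avg_salary in above_location_average(employees):
    print(name, location, round(avg_salary, 2))
```

Note that Charlie qualifies too, because 7,000 is above the Chicago average of 6,500.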
Example Output (row order may vary):
+-----+--------+----------+------+--------+------------------+
|EmpID|Emp_name|Manager_id|Salary|Location|        avg_salary|
+-----+--------+----------+------+--------+------------------+
|    3| Charlie|        10|  7000| Chicago|            6500.0|
|    2|     Bob|        11| 12000|New York|10333.333333333334|
|    4|   David|        12| 11000|New York|10333.333333333334|
+-----+--------+----------+------+--------+------------------+
Spark SQL Version
If you’d prefer to use Spark SQL, the following steps will also work. You can first register the DataFrame as a temporary table and then run the SQL query.
# Register the DataFrame as a temporary SQL table
employee_df.createOrReplaceTempView("Employee")
# Spark SQL to find employees with salary greater than average salary in their location
query = """
WITH LocationAvgSalary AS (
SELECT
EmpID,
Emp_name,
Manager_id,
Salary,
Location,
AVG(Salary) OVER (PARTITION BY Location) AS avg_salary
FROM Employee
)
SELECT
EmpID,
Emp_name,
Manager_id,
Salary,
Location,
avg_salary
FROM LocationAvgSalary
WHERE Salary > avg_salary
"""
# Run the SQL query and show the results
result_sql = spark.sql(query)
result_sql.show()
Both methods (PySpark with Window functions and Spark SQL) produce the same result.
Aggregate functions can also be applied over windows in PySpark. Used this way, they compute an aggregate value for each window partition (or frame) and attach it to every row. This avoids the explicit grouping that groupBy requires, and no rows are collapsed.
Supported Aggregate Functions for Window Operations
Some commonly used aggregate functions include:
- avg: Computes the average value.
- sum: Computes the sum of values.
- count: Counts the number of values.
- min: Finds the minimum value.
- max: Finds the maximum value.
- stddev: Computes the standard deviation.
- first: Gets the first value in the window frame.
- last: Gets the last value in the window frame. With an orderBy and the default frame, the frame ends at the current row, so last usually returns the current row's value.
- collect_list: Collects all values in the window frame into an array.
- collect_set: Collects the distinct values in the window frame into an array with duplicates removed.
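Unlike groupBy, a window aggregate keeps every input row. The following plain-Python model (a hypothetical helper `window_aggregates`, not a Spark API) computes sum, count, min and max over a window partitioned by department with no orderBy. It uses the department salaries from the examples below.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (Emp_name, Department, Salary) from the department examples
rows: List[Tuple[str, str, int]] = [
    ("Alice", "HR", 6000), ("Bob", "HR", 7000),
    ("Charlie", "IT", 8000), ("David", "IT", 12000),
    ("Eve", "Finance", 9000),
]


def window_aggregates(data: List[Tuple[str, str, int]]) -> List[Dict[str, object]]:
    """Model sum/count/min/max over Window.partitionBy("Department") with no orderBy."""
    by_department: Dict[str, List[int]] = defaultdict(list)
    for _name, department, salary in data:
        by_department[department].append(salary)
    result: List[Dict[str, object]] = []
    for name, department, _salary in data:
        values = by_department[department]  # with no orderBy, the whole partition is the frame
        result.append({"name": name, "sum": sum(values), "count": len(values),
                       "min": min(values), "max": max(values)})
    return result


for row in window_aggregates(rows):
    print(row)
```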
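General Syntax for Window Aggregate Functions
from pyspark.sql.window import Window
from pyspark.sql.functions import avg
# Define a window: rows are grouped by partition_column and sorted by order_column
window_spec = Window.partitionBy("partition_column").orderBy("order_column")
# Apply an aggregate function (avg, sum, count, min, max, ...) over the window
df = df.withColumn("new_column", avg("column").over(window_spec))
# Because this window has an orderBy, the default frame runs from the start of the
# partition to the current row, so the result is a running average. Leave out
# .orderBy(...) to get one value for the whole partition.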
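When a window has an orderBy and no explicit frame, Spark uses a frame that runs from the start of the partition to the current row (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), so aggregates become running values. Without orderBy, the frame is the whole partition. The plain-Python model below contrasts the two for a descending salary order, where rows with equal salaries are peers and share one result. The helper names are hypothetical, and this is not Spark code.

```python
from typing import List, Tuple


def partition_sum(salaries: List[int]) -> List[int]:
    """No orderBy: the frame is the whole partition, so every row sees the same total."""
    total = sum(salaries)
    return [total] * len(salaries)


def running_sum_desc(salaries: List[int]) -> List[Tuple[int, int]]:
    """orderBy(desc) with the default RANGE frame from the partition start to the current row.

    In descending order, the frame holds every row with a salary >= the current one,
    which includes tied peers, so tied rows share one result.
    """
    ordered = sorted(salaries, reverse=True)
    return [(s, sum(v for v in ordered if v >= s)) for s in ordered]


it_salaries = [8000, 12000]
print(partition_sum(it_salaries))     # every row gets the department total
print(running_sum_desc(it_salaries))  # each row gets the total up to and including itself
```

In PySpark, you can get the whole-partition result either by leaving out orderBy or by setting the frame explicitly with .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).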
Examples
1. Average Salary by Department
Compute the average salary for each department and append it as a new column.
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import avg
# Initialize Spark session
spark = SparkSession.builder.master("local").appName("WindowFunctionsExample").getOrCreate()
# Sample data
data = [
(1, "Alice", "HR", 6000),
(2, "Bob", "HR", 7000),
(3, "Charlie", "IT", 8000),
(4, "David", "IT", 12000),
(5, "Eve", "Finance", 9000)
]
columns = ["EmpID", "Emp_name", "Department", "Salary"]
# Create DataFrame
df = spark.createDataFrame(data, columns)
# Define the window specification
window_spec = Window.partitionBy("Department")
# Add a column with the average salary per department
df = df.withColumn("avg_salary", avg("Salary").over(window_spec))
df.show()
Output:
+-----+--------+----------+------+----------+
|EmpID|Emp_name|Department|Salary|avg_salary|
+-----+--------+----------+------+----------+
|    1|   Alice|        HR|  6000|    6500.0|
|    2|     Bob|        HR|  7000|    6500.0|
|    3| Charlie|        IT|  8000|   10000.0|
|    4|   David|        IT| 12000|   10000.0|
|    5|     Eve|   Finance|  9000|    9000.0|
+-----+--------+----------+------+----------+
2. Rank Employees by Salary Within Departments
Calculate rank based on salary within each department.
from pyspark.sql.functions import rank, col
# Define a window specification
window_spec = Window.partitionBy("Department").orderBy(col("Salary").desc())
# Add a rank column
df = df.withColumn("rank", rank().over(window_spec))
df.show()
Output:
+-----+--------+----------+------+----------+----+
|EmpID|Emp_name|Department|Salary|avg_salary|rank|
+-----+--------+----------+------+----------+----+
|    2|     Bob|        HR|  7000|    6500.0|   1|
|    1|   Alice|        HR|  6000|    6500.0|   2|
|    4|   David|        IT| 12000|   10000.0|   1|
|    3| Charlie|        IT|  8000|   10000.0|   2|
|    5|     Eve|   Finance|  9000|    9000.0|   1|
+-----+--------+----------+------+----------+----+
3. Total Salary by Department
Calculate the total salary for each department.
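from pyspark.sql.functions import sum
# Use a window without orderBy: the window_spec from the rank example is ordered by Salary,
# which would turn sum into a running total instead of a department total
dept_window = Window.partitionBy("Department")
# Add a column with the total salary per department
df = df.withColumn("total_salary", sum("Salary").over(dept_window))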
df.show()
Output:
+-----+--------+----------+------+----------+----+------------+
|EmpID|Emp_name|Department|Salary|avg_salary|rank|total_salary|
+-----+--------+----------+------+----------+----+------------+
|    2|     Bob|        HR|  7000|    6500.0|   1|       13000|
|    1|   Alice|        HR|  6000|    6500.0|   2|       13000|
|    4|   David|        IT| 12000|   10000.0|   1|       20000|
|    3| Charlie|        IT|  8000|   10000.0|   2|       20000|
|    5|     Eve|   Finance|  9000|    9000.0|   1|        9000|
+-----+--------+----------+------+----------+----+------------+
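Points to Remember
- Window Specification: You can define partitioning (partitionBy) and ordering (orderBy) for your window.
- Partition vs Aggregate:
  - Without a window, aggregates collapse rows. df.agg(...) returns one row for the whole DataFrame, and df.groupBy(...).agg(...) returns one row per group.
  - With a window, aggregates are calculated for each partition and attached to every row, so no rows are lost.
- Ordering and Frames: Adding orderBy to a window makes the default frame run from the start of the partition to the current row, so sum or avg become running totals or running averages. Leave out orderBy (or set the frame explicitly with rowsBetween) when you want one value for the whole partition.
- Aggregate Functions: The same aggregate functions used with groupBy (avg, sum, count, min, max and so on) can be applied over any window specification.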