# system/block_wrapper.py
# (Source-viewer metadata: 276 lines, 10 KiB, Python; timestamp 2025-03-11 00:43:40 +00:00.)
import importlib
import json
import asyncio
import logging
import os
import re
import sys
import requests
from temporalio import activity
from temporalio.exceptions import ApplicationError
from jsonschema import validate, ValidationError
from temporalio.client import Client
from temporalio.worker import Worker
import time
# Configure root logging: INFO and above, formatted as
# "<timestamp> [<LEVEL>] <logger name> - <message>". basicConfig is a no-op
# if the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Logger used by every function in this module.
logger = logging.getLogger(__name__)
# Treat the process as a test run whenever unittest has been imported; functions
# in this module pick their exception types based on this flag.
IS_TEST_ENVIRONMENT = "unittest" in sys.modules
# Environment variables
REPO_NAME = os.getenv('REPO_NAME')
BRANCH_NAME = os.getenv('BRANCH_NAME')
COMMIT_ID = os.getenv('VERSION')
NAMESPACE = os.getenv('NAMESPACE')
FLOWX_ENGINE_ADDRESS = os.getenv('FLOWX_ENGINE_ADDRESS')
SQLPAD_API_URL = os.getenv('SQLPAD_API_URL')
# REPO_NAME is required as well: it is concatenated into the task-queue name,
# which would otherwise fail with an opaque TypeError when it is unset.
_REQUIRED_ENV = {
    'REPO_NAME': REPO_NAME,
    'BRANCH_NAME': BRANCH_NAME,
    'VERSION': COMMIT_ID,
    'NAMESPACE': NAMESPACE,
    'FLOWX_ENGINE_ADDRESS': FLOWX_ENGINE_ADDRESS,
}
_missing_env = [env_name for env_name, env_value in _REQUIRED_ENV.items() if not env_value]
if _missing_env:
    raise ValueError(
        f"Missing required environment variables: {', '.join(_missing_env)}"
    )
# First 10 characters of the VERSION value, used in the task-queue name.
COMMIT_ID_SHORT = COMMIT_ID[:10]
# Sanitize name function
def sanitize_name(name):
    r"""Reduce *name* to word characters joined by single underscores.

    Every non-word character (\W) acts as a separator. Empty segments are
    dropped, so runs of separators collapse and leading/trailing ones vanish.

    NOTE(review): the ^(?=\d) alternative inserts a leading underscore before
    a digit, but that underscore is discarded along with the other empty
    segments, so a name starting with a digit still starts with one. This is
    left unchanged on purpose because the result feeds the task-queue name.
    """
    segments = re.sub(r'\W|^(?=\d)', '_', name).split('_')
    return '_'.join(segment for segment in segments if segment)
# Task queue name: sanitize("<REPO_NAME>_<BRANCH_NAME>") + "_" + sanitize(COMMIT_ID_SHORT).
BLOCK_NAME = REPO_NAME + "_" + BRANCH_NAME
block_name_safe, commit_id_safe = (
    sanitize_name(BLOCK_NAME),
    sanitize_name(COMMIT_ID_SHORT),
)
TASK_QUEUE = "_".join((block_name_safe, commit_id_safe))
# Load JSON schema
def load_schema(schema_path):
    """Read *schema_path* and return its parsed JSON content.

    Used as a generic JSON loader, including for /app/config.json.

    Raises:
        ApplicationError: if the file cannot be opened or parsed (normal runs).
        ValueError: for the same failure when IS_TEST_ENVIRONMENT is true.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the container's
        # locale encoding.
        with open(schema_path, 'r', encoding='utf-8') as schema_file:
            return json.load(schema_file)
    except Exception as e:
        logger.error("Failed to load schema from %s: %s", schema_path, e)
        error_cls = ValueError if IS_TEST_ENVIRONMENT else ApplicationError
        raise error_cls(f"Schema loading failed: {e}") from e
# Validate input against request schema
def validate_input(input_data):
    """Check *input_data* against the JSON schema in /app/request_schema.json.

    Returns None on success. On a schema violation the error is logged and
    raised as ValueError when IS_TEST_ENVIRONMENT is set, otherwise as
    ApplicationError. Errors from load_schema propagate unchanged.
    """
    schema = load_schema("/app/request_schema.json")
    try:
        validate(instance=input_data, schema=schema)
    except ValidationError as exc:
        logger.error("Input validation failed: %s", exc)
        error_type = ValueError if IS_TEST_ENVIRONMENT else ApplicationError
        raise error_type(f"Input validation error: {exc}")
    logger.info("Input data validated successfully")
# Validate output against response schema
def validate_output(output_data):
    """Check *output_data* against the JSON schema in /app/response_schema.json.

    Returns None on success. A schema violation is logged and raised as
    ValueError under IS_TEST_ENVIRONMENT, otherwise as ApplicationError.
    Errors from load_schema propagate unchanged.
    """
    schema = load_schema("/app/response_schema.json")
    try:
        validate(instance=output_data, schema=schema)
    except ValidationError as exc:
        logger.error("Output validation failed: %s", exc)
        raise (ValueError if IS_TEST_ENVIRONMENT else ApplicationError)(
            f"Output validation error: {exc}"
        )
    else:
        logger.info("Output data validated successfully")
# Get the connection ID from config.json
def get_connection_id(namespace):
    """Return the "connectionId" of the first config entry whose "namespace" matches.

    Entries come from /app/config.json. The value is read with .get, so it is
    None if the matching entry has no "connectionId".
    NOTE(review): assumes config.json holds a list of objects; confirm.

    Raises:
        ValueError: if no entry has the given namespace.
    """
    config_entries = load_schema("/app/config.json")
    matched_entry = next(
        (entry for entry in config_entries if entry.get("namespace") == namespace),
        None,
    )
    if matched_entry is None:
        logger.error("Provided Namespace not found.")
        raise ValueError(f"Namespace '{namespace}' not found")
    logger.info("Got the connectionID")
    return matched_entry.get("connectionId")
def _sql_literal(value):
    """Render *value* as the SQL text that replaces its placeholder.

    None -> NULL, bool -> TRUE/FALSE, str -> single-quoted literal with
    embedded quotes doubled, anything else -> str(value).
    """
    if value is None:
        return "NULL"
    # bool must be checked before the generic branch: str(True) is "True".
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, str):
        # Standard SQL escaping: double embedded single quotes so a value such
        # as "O'Brien" cannot close the literal early or inject SQL.
        # NOTE(review): backends that also treat backslash as an escape inside
        # string literals need extra handling; confirm the target database.
        return "'" + value.replace("'", "''") + "'"
    return str(value)


# Read SQL file and replace placeholders
def construct_sql(input_data):
    """Build the SQL text from /app/main.sql by substituting $<key> placeholders.

    Each "$<key>" for a key of *input_data* is replaced with
    _sql_literal(value). Substitution is a single regex pass, so:
      * a placeholder matches only a whole name ("$id" does not touch
        "$id_type" or "$idx");
      * text produced from one value is never re-scanned for placeholders.
    Placeholders with no matching key are left as they are.

    Returns:
        The substituted SQL with surrounding whitespace stripped.

    Raises:
        ApplicationError: if the template cannot be read or substitution fails.
    """
    try:
        with open("/app/main.sql", "r", encoding="utf-8") as sql_file:
            sql_template = sql_file.read()
        # An empty mapping would produce an empty alternation that matches a
        # bare "$", so substitution only runs when there are keys.
        if input_data:
            replacements = {
                f"${key}": _sql_literal(value) for key, value in input_data.items()
            }
            alternatives = "|".join(
                re.escape(placeholder)
                for placeholder in sorted(replacements, key=len, reverse=True)
            )
            # (?!\w): the placeholder must not be followed by more name characters.
            pattern = re.compile("(?:" + alternatives + r")(?!\w)")
            sql_template = pattern.sub(
                lambda match: replacements[match.group(0)], sql_template
            )
        logger.info("SQL query constructed.")
        return sql_template.strip()
    except Exception as e:
        logger.error("Error processing SQL template: %s", e)
        raise ApplicationError(f"SQL template error: {e}") from e
def get_batch_results(batch_id, retry_interval=0.05, max_retries=5, request_timeout=30):
    """Poll SQLPad until batch *batch_id* reaches a terminal status.

    Makes up to *max_retries* GET requests to /api/batches/<batch_id>,
    sleeping *retry_interval* seconds after each response whose status is not
    yet "finished" or "error". Each HTTP call is bounded by *request_timeout*
    seconds.

    Returns:
        (status, statement_id, error, columns, is_select_query) for the batch's
        first statement. error is always falsy here, because any statement
        error is raised instead.

    Raises:
        ApplicationError: on HTTP failure, an empty statement list, an error
            reported by any statement in the batch, a batch status of "error",
            a SELECT without columns, or when polling runs out.
    """
    for _ in range(max_retries):
        try:
            response = requests.get(
                f"{SQLPAD_API_URL}/api/batches/{batch_id}",
                # Without a timeout requests can wait forever on a stalled server.
                timeout=request_timeout,
            )
            response.raise_for_status()
            batch_status = response.json()
        except requests.RequestException as e:
            logger.error("Failed to fetch batch results: %s", e)
            raise ApplicationError(f"Failed to fetch batch results: {e}") from e

        status = batch_status.get("status")
        if status not in ("finished", "error"):
            time.sleep(retry_interval)
            continue

        statements = batch_status.get("statements", [])
        if not statements:
            raise ApplicationError("No statements found in batch response.")
        statement = statements[0]
        statement_id = statement.get("id")
        columns = statement.get("columns", None)
        # Inspect every statement: in a multi-statement batch the failing one
        # is not necessarily the first.
        error = next((stmt.get("error") for stmt in statements if stmt.get("error")), None)
        sql_text = (batch_status.get("batchText") or "").strip().lower()
        logger.info(f"statements: {statements}")
        logger.info(f"error from batches result {error}, statement: {statement_id}, columns: {columns}")
        if error:
            raise ApplicationError(f"SQL execution failed: {error}")
        if status == "error":
            # Batch failed, but no statement carried an error payload.
            raise ApplicationError("SQL execution failed: SQLPad reported batch status 'error'.")
        # Check if query is WITH + SELECT or plain SELECT
        is_select_query = sql_text.startswith("select") or (
            sql_text.startswith("with") and "select" in sql_text
        )
        # Ensure SELECT queries always return columns
        if is_select_query and not columns:
            raise ApplicationError("SELECT query did not return columns, cannot process data.")
        return status, statement_id, error, columns, is_select_query
    raise ApplicationError("SQLPad batch execution timed out.")
# Execute SQL via SQLPad APIs
def execute_sqlpad_query(connection_id, sql_query, request_timeout=30):
    """Run *sql_query* on SQLPad connection *connection_id* and return its outcome.

    Steps:
      1. POST /api/batches to create a batch.
      2. Poll the batch with get_batch_results.
      3. For SELECT / WITH ... SELECT batches only, GET
         /api/statements/<id>/results and convert each value using its
         column datatype.

    Args:
        connection_id: SQLPad connection id the batch runs on.
        sql_query: SQL text sent as batchText.
        request_timeout: seconds allowed for each HTTP call made here.

    Returns:
        {"status": ..., "error": ...} for non-SELECT batches, otherwise
        {"results": [{column_name: value, ...}, ...]}.

    Raises:
        ApplicationError: on any requests failure, a missing batch id, or an
            error raised by get_batch_results.
    """
    payload = {
        "connectionId": connection_id,
        "name": "",
        "batchText": sql_query,
        "selectedText": "",
    }
    try:
        # Step 1: Create batch
        response = requests.post(
            f"{SQLPAD_API_URL}/api/batches", json=payload, timeout=request_timeout
        )
        response.raise_for_status()
        batch_response = response.json()
        # The batch id is read from the first returned statement. "or [{}]"
        # also covers an empty or null "statements" value, so that case reaches
        # the "Batch ID not found" error instead of an IndexError/TypeError.
        created_statements = batch_response.get("statements") or [{}]
        batch_id = created_statements[0].get("batchId")
        logger.info(f"Batch ID from the batches API response {batch_id}")
        if not batch_id:
            raise ApplicationError("Batch ID not found in SQLPad response.")
        # Step 2: Retrieve batch statement ID and determine if it's a SELECT query
        status, statement_id, error, columns, is_select_query = get_batch_results(batch_id)
        if not is_select_query:
            return {"status": status, "error": error}
        # Step 3: Fetch statement results only for SELECT/CTE queries
        result_response = requests.get(
            f"{SQLPAD_API_URL}/api/statements/{statement_id}/results",
            timeout=request_timeout,
        )
        result_response.raise_for_status()
        result_data = result_response.json()
        # "number" maps to float, so integer columns come back as floats (5 -> 5.0).
        # Unknown datatypes fall back to str.
        type_mapping = {
            "number": float,
            "string": str,
            "date": str,
            "boolean": bool,
            "timestamp": str,
        }
        column_names_list = [col["name"] for col in columns]
        column_types_list = [col["datatype"] for col in columns]
        # Assumes result_data is a list of rows, each ordered like `columns`
        # (TODO confirm against the SQLPad API).
        converted_data = [
            [
                type_mapping.get(dtype, str)(value) if value is not None else None
                for dtype, value in zip(column_types_list, row)
            ]
            for row in result_data
        ]
        results_dict_list = [dict(zip(column_names_list, row)) for row in converted_data]
        logger.info(f"results_dict_list: {results_dict_list}")
        return {"results": results_dict_list}
    except requests.RequestException as e:
        logger.error("SQLPad API request failed: %s", e)
        raise ApplicationError(f"SQLPad API request failed: {e}") from e
# Registering activity
@activity.defn
async def block_main_activity(input_data):
    """Temporal activity: validate input, run the templated SQL via SQLPad, validate output.

    Returns:
        The dict produced by execute_sqlpad_query.

    Raises:
        Whatever validate_input raises, unchanged (it runs before the try).
        Every later failure is logged and re-raised, chained to its cause, as
        ApplicationError("Error during query execution"), or as RuntimeError
        when IS_TEST_ENVIRONMENT is set.
    """
    # Validate the input
    validate_input(input_data)
    try:
        sql_query = construct_sql(input_data)
        logger.info(f"constructed sql query: {sql_query}")
        connection_id = get_connection_id(NAMESPACE)
        if not connection_id:
            logger.error("connection id not exists, please add the connection id according to the namespace.")
            raise ApplicationError("connection id not exists, please add the connection id according to the namespace.")
        logger.info(f"connection id exists {connection_id}")
        # execute_sqlpad_query makes blocking HTTP calls (requests) and, through
        # get_batch_results, calls time.sleep. Calling it directly from this
        # coroutine would stall the worker's event loop for the whole query, so
        # it runs in the loop's default thread pool instead.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, execute_sqlpad_query, connection_id, sql_query
        )
        validate_output(result)
        logger.info(f"final result for the query: {result}")
        return result
    except Exception as e:
        logger.error("Error executing query execution: %s", e)
        if not IS_TEST_ENVIRONMENT:
            raise ApplicationError("Error during query execution") from e
        raise RuntimeError("Error during query execution") from e
# Worker function
async def main():
    """Connect to FLOWX_ENGINE_ADDRESS and serve block_main_activity on TASK_QUEUE.

    Runs a temporalio Worker in NAMESPACE until it returns. Any exception,
    whether raised while connecting or while running, is logged at CRITICAL
    level and re-raised.
    """
    try:
        engine_client = await Client.connect(FLOWX_ENGINE_ADDRESS, namespace=NAMESPACE)
        activity_worker = Worker(
            engine_client, task_queue=TASK_QUEUE, activities=[block_main_activity]
        )
        logger.info("Worker starting, listening to task queue: %s", TASK_QUEUE)
        await activity_worker.run()
    except Exception as exc:
        logger.critical("Worker failed to start: %s", exc)
        raise


# Script entry point: run the worker on a fresh asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())