From 5eb04354a13b8dd8f9685a910f30c22d44857b35 Mon Sep 17 00:00:00 2001 From: admin user Date: Tue, 11 Mar 2025 00:43:40 +0000 Subject: [PATCH] Upload files to "/" --- Dockerfile | 17 +++ README.md | 8 +- block_wrapper.py | 265 ++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 + test_block_wrapper.py | 196 +++++++++++++++++++++++++++++++ 5 files changed, 488 insertions(+), 1 deletion(-) create mode 100644 Dockerfile create mode 100644 block_wrapper.py create mode 100644 requirements.txt create mode 100644 test_block_wrapper.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ef40d55 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +# Use Python slim image as base +FROM python:3.10-slim AS base + +# Set up a directory for the application code +WORKDIR /app + +# Copy only the requirements file initially for better caching +COPY requirements.txt . + +# Install Workflow SDK and other dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the application code +COPY . . + +# Set entrypoint for the worker +ENTRYPOINT ["python", "/app/block_wrapper.py"] diff --git a/README.md b/README.md index 59a3efc..62166db 100644 --- a/README.md +++ b/README.md @@ -1 +1,7 @@ -**Hello world!!!** +# Activity Block Wrapper + +### Example Usage with Docker +1. **Build the image**: + ```bash + docker build -f Dockerfile -t activity_block_wrapper:latest . 
+ diff --git a/block_wrapper.py b/block_wrapper.py new file mode 100644 index 0000000..818a07d --- /dev/null +++ b/block_wrapper.py @@ -0,0 +1,265 @@ +import importlib +import json +import asyncio +import logging +import os +import re +import sys +import requests +from temporalio import activity +from temporalio.exceptions import ApplicationError +from jsonschema import validate, ValidationError +from temporalio.client import Client +from temporalio.worker import Worker +import time + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", +) +logger = logging.getLogger(__name__) + +# Automatically determine if in a test environment +IS_TEST_ENVIRONMENT = "unittest" in sys.modules + +# Environment variables +REPO_NAME = os.getenv('REPO_NAME') +BRANCH_NAME = os.getenv('BRANCH_NAME') +COMMIT_ID = os.getenv('VERSION') +NAMESPACE = os.getenv('NAMESPACE') +FLOWX_ENGINE_ADDRESS = os.getenv('FLOWX_ENGINE_ADDRESS') +SQLPAD_API_URL = os.getenv('SQLPAD_API_URL') + + +if not BRANCH_NAME or not COMMIT_ID or not NAMESPACE or not FLOWX_ENGINE_ADDRESS: + raise ValueError("Missing required environment variables.") + +COMMIT_ID_SHORT = COMMIT_ID[:10] + +# Sanitize name function +def sanitize_name(name): + sanitized = re.sub(r'\W|^(?=\d)', '_', name) + sanitized = re.sub(r'_+', '_', sanitized) + return sanitized.strip('_') + +BLOCK_NAME = REPO_NAME + "_" + BRANCH_NAME +block_name_safe = sanitize_name(BLOCK_NAME) +commit_id_safe = sanitize_name(COMMIT_ID_SHORT) + +# Construct the task queue name +TASK_QUEUE = f"{block_name_safe}_{commit_id_safe}" + +# Load JSON schema +def load_schema(schema_path): + try: + with open(schema_path, 'r') as schema_file: + return json.load(schema_file) + except Exception as e: + logger.error("Failed to load schema from %s: %s", schema_path, e) + if not IS_TEST_ENVIRONMENT: + raise ApplicationError(f"Schema loading failed: {e}") + else: + raise ValueError(f"Schema loading failed: 
{e}") + +# Validate input against request schema +def validate_input(input_data): + request_schema = load_schema("/app/request_schema.json") + try: + validate(instance=input_data, schema=request_schema) + logger.info("Input data validated successfully") + except ValidationError as e: + logger.error("Input validation failed: %s", e) + if not IS_TEST_ENVIRONMENT: + raise ApplicationError(f"Input validation error: {e}") + else: + raise ValueError(f"Input validation error: {e}") + +# Validate output against response schema +def validate_output(output_data): + response_schema = load_schema("/app/response_schema.json") + try: + validate(instance=output_data, schema=response_schema) + logger.info("Output data validated successfully") + except ValidationError as e: + logger.error("Output validation failed: %s", e) + if not IS_TEST_ENVIRONMENT: + raise ApplicationError(f"Output validation error: {e}") + else: + raise ValueError(f"Output validation error: {e}") +# Get the connection ID from config.json +def get_connection_id(namespace): + response_schema = load_schema("/app/config.json") + for item in response_schema: + if item.get("namespace") == namespace: + logger.info("Got the connectionID") + return item.get("connectionId") + logger.error("Provided Namespace not found.") + raise ValueError(f"Namespace '{namespace}' not found") + +# Read SQL file and replace placeholders +def construct_sql(input_data): + try: + with open("/app/main.sql", "r") as sql_file: + sql_template = sql_file.read() + for key, value in input_data.items(): + placeholder = f"${key}" + if isinstance(value, str): + value = f"'{value}'" + elif isinstance(value, bool): + value = "TRUE" if value else "FALSE" + sql_template = sql_template.replace(placeholder, str(value)) + logger.info(f"SQL query constructed.") + return sql_template.strip() + except Exception as e: + logger.error("Error processing SQL template: %s", e) + raise ApplicationError(f"SQL template error: {e}") + +def get_batch_results(batch_id, 
retry_interval=0.05, max_retries=5): + retries = 0 + + while retries < max_retries: + try: + response = requests.get(f"{SQLPAD_API_URL}/api/batches/{batch_id}") + response.raise_for_status() + batch_status = response.json() + + status = batch_status.get("status") + if status in ["finished", "error"]: + statements = batch_status.get("statements", []) + + if not statements: + raise ApplicationError("No statements found in batch response.") + + statement = statements[0] + statement_id = statement.get("id") + error = statement.get("error") + columns = statement.get("columns", None) + sql_text = batch_status.get("batchText", "").strip().lower() + logger.info(f"statements: {statements}") + logger.info(f"error from batches result {error}, statement: {statement_id}, columns: {columns}") + + if error: + raise ApplicationError(f"SQL execution failed: {error}") + + # Check if query is WITH + SELECT or plain SELECT + is_select_query = sql_text.startswith("select") or ( + sql_text.startswith("with") and "select" in sql_text + ) + + # Ensure SELECT queries always return columns + if is_select_query and not columns: + raise ApplicationError("SELECT query did not return columns, cannot process data.") + + return status, statement_id, error, columns, is_select_query + + time.sleep(retry_interval) + retries += 1 + except requests.RequestException as e: + logger.error("Failed to fetch batch results: %s", e) + raise ApplicationError(f"Failed to fetch batch results: {e}") + + raise ApplicationError("SQLPad batch execution timed out.") + +# Execute SQL via SQLPad APIs +def execute_sqlpad_query(connection_id, sql_query): + payload = { + "connectionId": connection_id, + "name": "", + "batchText": sql_query, + "selectedText": "" + } + + try: + # Step 1: Create batch + response = requests.post(f"{SQLPAD_API_URL}/api/batches", json=payload) + response.raise_for_status() + batch_response = response.json() + + # Extract batch ID from the response + batch_id = batch_response.get("statements", 
[{}])[0].get("batchId") + logger.info(f"Batch ID from the batches API response {batch_id}") + if not batch_id: + raise ApplicationError("Batch ID not found in SQLPad response.") + + # Step 2: Retrieve batch statement ID and determine if it's a SELECT query + status, statement_id, error, columns, is_select_query = get_batch_results(batch_id) + + # If it's a non-SELECT query + if not is_select_query: + return {"status": status, "error": error} + + # Step 3: Fetch statement results only for SELECT/CTE queries + result_response = requests.get(f"{SQLPAD_API_URL}/api/statements/{statement_id}/results") + result_response.raise_for_status() + result_data = result_response.json() + + type_mapping = { + "number": float, + "string": str, + "date": str, + "boolean": bool, + "timestamp": str, + } + + column_names_list = [col["name"] for col in columns] + column_types_list = [col["datatype"] for col in columns] + converted_data = [ + [ + type_mapping.get(dtype, str)(value) if value is not None else None + for dtype, value in zip(column_types_list, row) + ] + for row in result_data + ] + results_dict_list = [dict(zip(column_names_list, row)) for row in converted_data] + logger.info(f"results_dict_list: {results_dict_list}") + + return {"results": results_dict_list} + + except requests.RequestException as e: + logger.error("SQLPad API request failed: %s", e) + raise ApplicationError(f"SQLPad API request failed: {e}") + +# Registering activity +@activity.defn +async def block_main_activity(input_data): + # Validate the input + validate_input(input_data) + try: + sql_query = construct_sql(input_data) + logger.info(f"constructed sql query: {sql_query}") + connection_id = get_connection_id(NAMESPACE) + if connection_id: + logger.info(f"connection id exists {connection_id}") + result = execute_sqlpad_query(connection_id, sql_query) + validate_output(result) + logger.info(f"final result for the query: {result}") + return result + else: + logger.error(f"connection id not exists, please 
add the connection id according to the namespace.") + raise ApplicationError("connection id not exists, please add the connection id according to the namespace.") + + except Exception as e: + logger.error("Error executing query execution: %s", e) + if not IS_TEST_ENVIRONMENT: + raise ApplicationError("Error during query execution") from e + else: + raise RuntimeError("Error during query execution") from e + +# Worker function +async def main(): + try: + client = await Client.connect(FLOWX_ENGINE_ADDRESS, namespace=NAMESPACE) + worker = Worker( + client, + task_queue=TASK_QUEUE, + activities=[block_main_activity], + ) + logger.info("Worker starting, listening to task queue: %s", TASK_QUEUE) + await worker.run() + except Exception as e: + logger.critical("Worker failed to start: %s", e) + raise + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b495ad6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +temporalio==1.6.0 +jsonschema==4.23.0 +requests==2.32.3 \ No newline at end of file diff --git a/test_block_wrapper.py b/test_block_wrapper.py new file mode 100644 index 0000000..aed8192 --- /dev/null +++ b/test_block_wrapper.py @@ -0,0 +1,196 @@ +import unittest +from unittest.mock import patch, MagicMock, mock_open +import json +import asyncio +from jsonschema import ValidationError +import os +with patch.dict('os.environ', { + "REPO_NAME": "test_repo", + "BRANCH_NAME": "test_branch", + "VERSION": "test_version", + "NAMESPACE": "test_namespace", + "FLOWX_ENGINE_ADDRESS": "test_address" +}): + from block_wrapper import block_main_activity, validate_input, validate_output , construct_sql, get_connection_id + + +class TestBlockWrapper(unittest.TestCase): + + def setUp(self): + # Mock schemas to use for testing + self.mock_request_schema = { + "type": "object", + "properties": { + "salary": {"type": "number"}, + "department": {"type": "string"} + }, + 
"required": ["salary", "department"] + } + + self.mock_response_schema = { + "type": "object", + "$schema": "http://json-schema.org/draft-07/schema", + "properties": { + "id": { + "type": "integer" + }, + "first_name": { + "type": "string" + }, + "last_name": { + "type": "string" + }, + "email": { + "type": "string", + "format": "email" + }, + "phone_number": { + "type": "string" + }, + "hire_date": { + "type": "string", + "format": "date-time" + }, + "job_title": { + "type": "string" + }, + "salary": { + "type": "number" + }, + "department": { + "type": "string" + } + }, + "required": ["id","first_name","last_name","email","phone_number","hire_date","job_title","salary","department"] + } + + self.mock_config_schema = [ + { + "namespace": "staging", + "connectionId": "8d7341b4-53a5-41b8-8c9d-5133fafb5d7b" + }, + { + "namespace": "production", + "connectionId": "4b1437d8-53a5-41b8-8c9d-5133fafbtyuu" + } + ] + + self.mock_main_sql = "SELECT * FROM public.employee WHERE salary=$salary and department=$department;" + + + # Mock the contents of request_schema.json and response_schema.json using different patchers + self.mock_open_request = mock_open(read_data=json.dumps(self.mock_request_schema)) + self.mock_open_response = mock_open(read_data=json.dumps(self.mock_response_schema)) + self.mock_open_config = mock_open(read_data=json.dumps(self.mock_config_schema)) + self.mock_open_main_sql = mock_open(read_data=self.mock_main_sql) + + self.open_main_sql_patcher = patch("builtins.open", self.mock_open_main_sql) + self.open_main_sql_patcher.start() + + # Mock load_block_main to return a mock main function + self.load_block_main_patcher = patch("block_wrapper.execute_sqlpad_query", return_value=MagicMock(return_value={"id": 4, "first_name": "Bob", "last_name": "Brown", "email": "bob.brown@example.com", "phone_number": "444-222-1111", "hire_date": "2020-07-25", "job_title": "Marketing Specialist", "salary": 60000.00, "department": "Marketing"})) + self.mock_load_block_main = 
self.load_block_main_patcher.start() + + def tearDown(self): + # Stop all patches + self.load_block_main_patcher.stop() + + @patch("block_wrapper.load_schema") + def test_validate_input_success(self, mock_load_schema): + # Set up load_schema to return request schema for validate_input + mock_load_schema.return_value = self.mock_request_schema + input_data = {"salary": 20000.0, "department": "Marketing"} + validate_input(input_data) # Should pass without errors + + @patch("block_wrapper.load_schema") + def test_validate_input_failure(self, mock_load_schema): + # Set up load_schema to return request schema for validate_input + mock_load_schema.return_value = self.mock_request_schema + input_data = {"salary": 20000.00} # Missing 'department' + with self.assertRaises(ValueError): + validate_input(input_data) + + @patch("block_wrapper.load_schema") + def test_validate_output_success(self, mock_load_schema): + # Set up load_schema to return response schema for validate_output + mock_load_schema.return_value = self.mock_response_schema + output_data = {"id": 4, "first_name": "Bob", "last_name": "Brown", "email": "bob.brown@example.com", "phone_number": "444-222-1111", "hire_date": "2020-07-25", "job_title": "Marketing Specialist", "salary": 60000.00, "department": "Marketing"} + validate_output(output_data) # Should pass without errors + + @patch("block_wrapper.load_schema") + def test_validate_output_failure(self, mock_load_schema): + # Set up load_schema to return response schema for validate_output + mock_load_schema.return_value = self.mock_response_schema + output_data = {"id": 4, "first_name": "Bob", "last_name": "Brown", "email": "bob.brown@example.com", "phone_number": "444-222-1111", "hire_date": "2020-07-25", "job_title": "Marketing Specialist", "salary": 60000.00} # Missing 'department' + with self.assertRaises(ValueError): + validate_output(output_data) + + @patch("block_wrapper.load_schema") + async def test_block_main_activity_success(self, 
mock_load_schema): + # Set up load_schema to return request and response schemas in order + mock_load_schema.side_effect = [self.mock_request_schema, self.mock_response_schema] + input_data = {"salary": 20000.0, "department": "Marketing"} + result = await block_main_activity(input_data) + self.assertEqual(result, {"id": 4, "first_name": "Bob", "last_name": "Brown", "email": "bob.brown@example.com", "phone_number": "444-222-1111", "hire_date": "2020-07-25", "job_title": "Marketing Specialist", "salary": 60000.00, "department": "Marketing"}) + + @patch("block_wrapper.load_schema") + async def test_block_main_activity_failure(self, mock_load_schema): + # Set up load_schema to return request and response schemas in order + mock_load_schema.side_effect = [self.mock_request_schema, self.mock_response_schema] + # Cause an exception in main function + self.mock_load_block_main.side_effect = Exception("Unexpected error") + input_data = {"salary": 20000.0, "department": "Marketing"} + with self.assertRaises(RuntimeError): + await block_main_activity(input_data) + + @patch("block_wrapper.load_schema") + async def test_block_main_activity_input_validation_failure(self, mock_load_schema): + # Mock validate_input to raise ValidationError + with patch("block_wrapper.validate_input", side_effect=ValidationError("Invalid input")): + input_data = {"salary": 20000.00} # Missing 'department' + with self.assertRaises(ValueError): + await block_main_activity(input_data) + + + @patch("block_wrapper.load_schema") + async def test_block_main_activity_output_validation_failure(self, mock_load_schema): + # Mock validate_output to raise ValidationError + with patch("block_wrapper.validate_output", side_effect=ValidationError("Invalid output")): + input_data = {"salary": 20000.0, "department": "Marketing"} + with self.assertRaises(ValueError): + await block_main_activity(input_data) + + @patch.dict(os.environ, {"NAMESPACE": "staging"}) + @patch("block_wrapper.load_schema") + def 
test_get_connection_id_staging(self, mock_load_schema): + """Test fetching connectionId for 'staging' namespace""" + mock_load_schema.return_value = self.mock_config_schema + connection_id = get_connection_id(os.environ["NAMESPACE"]) + self.assertEqual(connection_id, "8d7341b4-53a5-41b8-8c9d-5133fafb5d7b") + + @patch.dict(os.environ, {"NAMESPACE": "production"}) + @patch("block_wrapper.load_schema") + def test_get_connection_id_production(self, mock_load_schema): + """Test fetching connectionId for 'production' namespace""" + mock_load_schema.return_value = self.mock_config_schema + connection_id = get_connection_id(os.environ["NAMESPACE"]) + self.assertEqual(connection_id, "4b1437d8-53a5-41b8-8c9d-5133fafbtyuu") + + @patch("block_wrapper.load_schema") + def test_get_connection_id_invalid_namespace(self, mock_load_schema): + """Test handling of invalid namespace""" + mock_load_schema.return_value = self.mock_config_schema + with self.assertRaises(ValueError) as context: + get_connection_id("development") + self.assertIn("Namespace 'development' not found", str(context.exception)) + + @patch("block_wrapper.load_schema") + def test_valid_sql_replacement(self, mock_load_schema): + mock_load_schema.return_value = self.mock_main_sql + input_data = {"salary": 20000.0, "department": "Marketing"} + expected_sql = "SELECT * FROM public.employee WHERE salary=20000.0 and department='Marketing';" + result = construct_sql(input_data) + self.assertEqual(result, expected_sql) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file