šŸ” Code Extractor

class TwoPassSqlWorkflow

Maturity: 26

Two-pass SQL generation workflow with iteration and error correction

File:
/tf/active/vicechatdev/full_smartstat/two_pass_sql_workflow.py
Lines:
47 - 856
Complexity:
moderate

Purpose

Two-pass SQL generation workflow with iteration and error correction

Source Code

class TwoPassSqlWorkflow:
    """Two-pass SQL generation workflow with iteration and error correction"""
    
    def __init__(self, schema_discovery, data_processor, statistical_agent=None):
        self.schema_discovery = schema_discovery
        self.data_processor = data_processor
        self.statistical_agent = statistical_agent
        self.max_iterations = 3
        
    def generate_sql_with_iterations(self, user_request: str, max_rows: int = 50000) -> Dict[str, Any]:
        """
        Main workflow: Generate SQL through two-pass approach with iterations
        """
        workflow_results = {
            'user_request': user_request,
            'iterations': [],
            'final_success': False,
            'final_sql': None,
            'final_data': None,
            'workflow_summary': None
        }
        
        logger.info(f"Starting two-pass SQL workflow for request: {user_request}")
        
        # Initialize user preferences (not supported in this method signature yet)
        preferred_tables = None
        preferred_columns = None
        
        # Get current schema
        discovered_schema = self.schema_discovery.discover_schema()
        
        # Extract table names from schema - handle both formats
        available_tables = []
        if hasattr(discovered_schema, 'tables') and discovered_schema.tables:
            if isinstance(discovered_schema.tables, list):
                # Tables is a list of TableInfo objects
                available_tables = [table.name for table in discovered_schema.tables]
            elif isinstance(discovered_schema.tables, dict):
                # Tables is a dictionary
                available_tables = list(discovered_schema.tables.keys())
        
        if not available_tables:
            # Fallback to columns_by_table format
            schema_dict = discovered_schema.to_dict()
            available_tables = list(schema_dict.get('columns_by_table', {}).keys())
        
        logger.info(f"Available tables for selection: {len(available_tables)}")
        
        previous_errors = []
        
        for iteration in range(1, self.max_iterations + 1):
            logger.info(f"Starting iteration {iteration}")
            
            try:
                # Pass 1: Table Selection (with user preferences)
                table_selection = self._pass1_select_tables(
                    user_request, available_tables, previous_errors, preferred_tables
                )
                
                # Pass 2: Generate SQL query (with user preferences)
                sql_generation = self._pass2_generate_sql(
                    user_request, table_selection, discovered_schema, max_rows, previous_errors, preferred_columns
                )
                
                # Test SQL execution
                execution_result = self._test_sql_execution(sql_generation.sql_query)
                
                iteration_result = IterationResult(
                    iteration_number=iteration,
                    table_selection=table_selection,
                    sql_generation=sql_generation,
                    execution_success=execution_result['success'],
                    error_message=execution_result.get('error'),
                    execution_time=execution_result.get('execution_time'),
                    row_count=execution_result.get('row_count')
                )
                
                workflow_results['iterations'].append(iteration_result)
                
                if execution_result['success']:
                    # Success! Finalize results
                    workflow_results['final_success'] = True
                    workflow_results['final_sql'] = sql_generation.sql_query
                    workflow_results['final_data'] = execution_result.get('data')
                    workflow_results['workflow_summary'] = self._create_workflow_summary(workflow_results)
                    
                    logger.info(f"Workflow completed successfully in iteration {iteration}")
                    return workflow_results
                else:
                    # Add error to previous errors for next iteration
                    error_info = {
                        'iteration': iteration,
                        'error': execution_result['error'],
                        'sql_query': sql_generation.sql_query,
                        'tables_used': sql_generation.tables_used,
                        'columns_used': sql_generation.columns_used
                    }
                    previous_errors.append(error_info)
                    logger.warning(f"Iteration {iteration} failed: {execution_result['error']}")
                    
            except Exception as e:
                logger.error(f"Iteration {iteration} failed with exception: {str(e)}")
                previous_errors.append({
                    'iteration': iteration,
                    'error': str(e),
                    'sql_query': None,
                    'tables_used': [],
                    'columns_used': {}
                })
        
        # All iterations failed
        workflow_results['workflow_summary'] = self._create_workflow_summary(workflow_results)
        logger.error(f"Workflow failed after {self.max_iterations} iterations")
        return workflow_results
    
    def _pass1_select_tables(self, user_request: str, available_tables: List[str], previous_errors: List[Dict], 
                           preferred_tables: List[str] = None) -> TableSelectionResult:
        """
        Pass 1: Use LLM to select relevant tables based on user request
        """
        logger.info("Pass 1: Selecting relevant tables")
        
        # Create context for table selection with user preferences
        table_context = self._create_table_selection_context(available_tables, previous_errors, preferred_tables)
        
        prompt = f"""You are a database expert analyzing a laboratory information management system (LIMS).

TASK: Select the most relevant tables needed to answer this user request.

USER REQUEST: {user_request}

AVAILABLE TABLES:
{table_context}

PREVIOUS ITERATION ERRORS (if any):
{self._format_previous_errors(previous_errors)}

INSTRUCTIONS:
1. Analyze the user request to understand what data is needed
2. Select 3-8 most relevant tables that contain the required data
3. **PRIORITIZE HIGH-VOLUME TABLES** - Tables with millions of rows contain the main data
4. For laboratory results, prefer tables with >1M rows over small lookup tables
5. Avoid selecting too many tables (causes complex joins)
6. Consider typical laboratory workflow relationships between tables
7. Read the row counts carefully - select tables with substantial data

RESPOND WITH JSON:
{{
    "selected_tables": ["Table1", "Table2", "Table3"],
    "reasoning": "Explanation of why these tables were selected based on the user request",
    "confidence": 0.85,
    "suggested_joins": [
        "Description of likely join relationships between selected tables"
    ]
}}"""

        if self.statistical_agent:
            try:
                response = self.statistical_agent.query_llm(prompt, model='claude-sonnet-4-5-20250929')
                logger.debug(f"Raw LLM response for table selection: {response[:200]}...")
                
                # Try to extract JSON from response using robust parsing
                result_data = self._extract_json_from_response(response)
                
                if result_data and 'selected_tables' in result_data:
                    return TableSelectionResult(
                        selected_tables=result_data.get('selected_tables', []),
                        reasoning=result_data.get('reasoning', ''),
                        confidence=result_data.get('confidence', 0.5),
                        suggested_joins=result_data.get('suggested_joins', [])
                    )
                else:
                    logger.warning("No valid JSON found in LLM response, using fallback")
                    return self._fallback_table_selection(user_request, available_tables)
            except Exception as e:
                logger.error(f"LLM table selection failed: {e}")
                return self._fallback_table_selection(user_request, available_tables)
        else:
            return self._fallback_table_selection(user_request, available_tables)
    
    def _pass2_generate_sql(self, user_request: str, table_selection: TableSelectionResult, 
                           discovered_schema, max_rows: int, previous_errors: List[Dict], 
                           preferred_columns: List[str] = None) -> SqlGenerationResult:
        """
        Pass 2: Generate SQL with detailed schema information for selected tables
        """
        logger.info(f"Pass 2: Generating SQL for selected tables: {table_selection.selected_tables}")
        
        # Get detailed schema information for selected tables only
        detailed_schema = self._get_detailed_schema_for_tables(table_selection.selected_tables, discovered_schema)
        
        # Get discovered relationships for selected tables
        discovered_relationships = self._get_relationships_for_tables(table_selection.selected_tables, discovered_schema)
        
        prompt = f"""You are a SQL expert generating queries for a laboratory information management system (LIMS).

TASK: Generate a precise SQL query to answer the user request using ONLY the provided tables and columns.

USER REQUEST: {user_request}

SELECTED TABLES FROM PASS 1:
{table_selection.reasoning}

COMPLETE SCHEMA FOR SELECTED TABLES (ALL AVAILABLE COLUMNS):
{detailed_schema}

āš ļø CRITICAL CONSTRAINT: You MUST use ONLY the column names listed above. Every column name in your query MUST appear exactly as shown in the schema above. Do NOT assume any column exists that is not explicitly listed.

🔗 DISCOVERED RELATIONSHIPS (Use these for JOINs):
{discovered_relationships}

🔗 PROVEN WORKING RELATIONSHIPS (Fallback if discovered relationships insufficient):
- Requests → Samples: Samples.Sample_Request = Requests.Id
- Samples → Results: Results.Result_Sample = Samples.Id  
- Results → Analyses: Results.Result_Analysis = Analyses.Id

PREVIOUS ITERATION ERRORS (if any):
{self._format_previous_errors(previous_errors)}

SUGGESTED JOINS FROM PASS 1:
{chr(10).join(table_selection.suggested_joins)}

{self._format_preferred_columns_context(preferred_columns, detailed_schema)}

CRITICAL REQUIREMENTS:
1. šŸ” COLUMN VERIFICATION: Before writing any column name, verify it exists in the schema above
2. 🚫 NO ASSUMPTIONS: Do not use any column name not explicitly shown in the schema
3. šŸ·ļø USE ALIASES: Use table aliases (r for Requests, s for Samples, res for Results, a for Analyses)
4. šŸ“Š LIMIT RESULTS: Include TOP {max_rows} to limit results
5. šŸ”— PROPER JOINS: Use the proven relationships listed above
6. āœ… VALIDATED DATA: Include WHERE conditions for TechValidated = 1 AND BioValidated = 1 (if Results table is used)
7. šŸ“… DATE FILTERING: Use DATEADD(MONTH, -N, GETDATE()) for date filtering
8. šŸ“‹ MEANINGFUL ORDER: Include ORDER BY for logical result ordering

SQL SERVER SYNTAX:
- Use single quotes for string literals
- Use proper CASE statements
- Use COALESCE for null handling
- Use DATEADD for date arithmetic

RESPOND WITH JSON (verify each column name against the schema before including it):
{{
    "sql_query": "SELECT TOP {max_rows} ... FROM ... WHERE ...",
    "explanation": "Detailed explanation of the query logic",
    "confidence": 0.90,
    "tables_used": ["Table1", "Table2"],
    "columns_used": {{"Table1": ["col1", "col2"], "Table2": ["col3", "col4"]}}
}}"""

        if self.statistical_agent:
            try:
                response = self.statistical_agent.query_llm(prompt, model='claude-sonnet-4-5-20250929')
                logger.debug(f"Raw LLM response for SQL generation: {response[:200]}...")
                
                # Try to extract JSON from response using robust parsing
                result_data = self._extract_json_from_response(response)
                
                if result_data and 'sql_query' in result_data:
                    return SqlGenerationResult(
                        sql_query=result_data.get('sql_query', ''),
                        explanation=result_data.get('explanation', ''),
                        confidence=result_data.get('confidence', 0.5),
                        tables_used=result_data.get('tables_used', []),
                        columns_used=result_data.get('columns_used', {})
                    )
                else:
                    logger.warning("No valid JSON found in SQL generation response, attempting SQL extraction")
                    # Try to extract SQL directly from response even if JSON failed
                    extracted_sql = self._extract_sql_from_response(response, table_selection.selected_tables)
                    if extracted_sql:
                        logger.info("Successfully extracted SQL from non-JSON response")
                        return SqlGenerationResult(
                            sql_query=extracted_sql,
                            explanation="SQL extracted from response despite JSON parsing failure",
                            confidence=0.7,  # Good confidence since we extracted actual SQL
                            tables_used=table_selection.selected_tables,
                            columns_used={}
                        )
                    else:
                        logger.warning("SQL extraction also failed, using fallback")
                        return self._fallback_sql_generation(user_request, table_selection, max_rows)
            except Exception as e:
                logger.error(f"LLM SQL generation failed: {e}")
                return self._fallback_sql_generation(user_request, table_selection, max_rows)
        else:
            return self._fallback_sql_generation(user_request, table_selection, max_rows)
    
    def _create_table_selection_context(self, available_tables: List[str], previous_errors: List[Dict], 
                                       preferred_tables: List[str] = None) -> str:
        """Create context for table selection with table descriptions"""
        context_lines = []
        
        # Get table descriptions from schema discovery
        table_descriptions = {}
        try:
            discovered_schema = self.schema_discovery.discover_schema()
            for table in discovered_schema.tables:
                if table.description:
                    table_descriptions[table.name] = table.description
        except Exception as e:
            logger.warning(f"Could not get table descriptions: {e}")
        
        # Group tables by category with descriptions
        categories = {
            'Core Laboratory Operations': [],
            'Customer & Partner Management': [],
            'Laboratory Configuration': [],
            'Specialized Testing': [],
            'System & Reference': []
        }
        
        for table in available_tables:
            description = table_descriptions.get(table, f"Table: {table}")
            table_lower = table.lower()
            
            # Categorize based on table description keywords
            if any(keyword in description.lower() for keyword in ['request', 'sample', 'result', 'analys']):
                categories['Core Laboratory Operations'].append((table, description))
            elif any(keyword in description.lower() for keyword in ['customer', 'company', 'veterinarian', 'client']):
                categories['Customer & Partner Management'].append((table, description))
            elif any(keyword in description.lower() for keyword in ['type', 'category', 'configuration', 'template']):
                categories['Laboratory Configuration'].append((table, description))
            elif any(keyword in description.lower() for keyword in ['bacteriology', 'pcr', 'serology', 'antibiogram', 'parasitology']):
                categories['Specialized Testing'].append((table, description))
            else:
                categories['System & Reference'].append((table, description))
        
        # Create formatted output
        for category, tables in categories.items():
            if tables:
                context_lines.append(f"\nšŸ“‚ {category.upper()}:")
                # Add row counts and sort by volume (highest first)
                tables_with_counts = []
                for table_name, description in tables:
                    row_count = 0
                    try:
                        for schema_table in discovered_schema.tables:
                            if schema_table.name == table_name:
                                row_count = schema_table.row_count
                                break
                    except:
                        pass
                    tables_with_counts.append((table_name, description, row_count))
                
                # Sort by row count (descending) to prioritize high-volume tables
                tables_with_counts.sort(key=lambda x: x[2], reverse=True)
                
                for table_name, description, row_count in tables_with_counts[:8]:  # Show up to 8 tables per category
                    # Truncate long descriptions
                    short_desc = description[:60] + "..." if len(description) > 60 else description
                    
                    # Format row count for readability
                    if row_count > 1000000:
                        count_str = f" ({row_count/1000000:.1f}M rows)"
                    elif row_count > 1000:
                        count_str = f" ({row_count/1000:.0f}K rows)"
                    elif row_count > 0:
                        count_str = f" ({row_count} rows)"
                    else:
                        count_str = " (empty)"
                    
                    context_lines.append(f"  • {table_name}{count_str}: {short_desc}")
                
                if len(tables_with_counts) > 8:
                    context_lines.append(f"  ... and {len(tables_with_counts) - 8} more tables in this category")
        
        # Add user preferred tables section if specified
        if preferred_tables:
            context_lines.append(f"\nšŸŽÆ USER PREFERRED TABLES:")
            context_lines.append("The user has specifically selected these tables for focus:")
            for table in preferred_tables:
                if table in available_tables:
                    # Get table info
                    description = table_descriptions.get(table, f"User-selected table: {table}")
                    row_count = 0
                    try:
                        for schema_table in discovered_schema.tables:
                            if schema_table.name == table:
                                row_count = schema_table.row_count
                                break
                    except:
                        pass
                    
                    if row_count > 1000000:
                        count_str = f" ({row_count/1000000:.1f}M rows)"
                    elif row_count > 1000:
                        count_str = f" ({row_count/1000:.0f}K rows)"
                    elif row_count > 0:
                        count_str = f" ({row_count} rows)"
                    else:
                        count_str = " (empty)"
                    
                    context_lines.append(f"  ⭐ {table}{count_str}: {description[:80]}")
                else:
                    context_lines.append(f"  āŒ {table}: NOT FOUND in available tables")
            
            context_lines.append("")
            context_lines.append("PRIORITIZE these user-selected tables when making your selection!")
        
        # Add discovered relationships for better table selection
        try:
            relationships = discovered_schema.to_dict().get('relationships', [])
            if relationships:
                context_lines.append("\n🔗 DISCOVERED TABLE RELATIONSHIPS:")
                # Show top relationships for context
                for rel in relationships[:10]:  # Limit to top 10 relationships
                    from_table = rel.get('from_table')
                    to_table = rel.get('to_table')
                    from_column = rel.get('from_column')
                    to_column = rel.get('to_column')
                    confidence = rel.get('confidence', 0)
                    
                    context_lines.append(f"  • {from_table}.{from_column} → {to_table}.{to_column} (conf: {confidence:.2f})")
                
                if len(relationships) > 10:
                    context_lines.append(f"  ... and {len(relationships) - 10} more relationships")
        except Exception as e:
            logger.warning(f"Could not add relationship context: {e}")
        
        return "\n".join(context_lines)
    
    def _format_preferred_columns_context(self, preferred_columns: List[str], detailed_schema: str) -> str:
        """Format preferred columns context for SQL generation"""
        if not preferred_columns:
            return ""
        
        context_lines = [
            "",
            "šŸŽÆ USER PREFERRED COLUMNS:",
            "The user has specifically selected these columns to include:"
        ]
        
        # Verify which preferred columns appear in the schema text (plain substring check, so it can over-match)
        available_columns = []
        unavailable_columns = []
        
        for col in preferred_columns:
            if col in detailed_schema:
                available_columns.append(col)
            else:
                unavailable_columns.append(col)
        
        if available_columns:
            context_lines.append("āœ… AVAILABLE PREFERRED COLUMNS:")
            for col in available_columns:
                context_lines.append(f"  ⭐ {col}")
            context_lines.append("")
            context_lines.append("PRIORITIZE including these columns in your SELECT statement!")
        
        if unavailable_columns:
            context_lines.append("āŒ UNAVAILABLE PREFERRED COLUMNS (not in selected tables):")
            for col in unavailable_columns:
                context_lines.append(f"  āŒ {col}")
            context_lines.append("These columns cannot be used - they don't exist in the selected tables.")
        
        context_lines.append("")
        return "\n".join(context_lines)
    
    def _get_detailed_schema_for_tables(self, selected_tables: List[str], discovered_schema) -> str:
        """Get detailed column information for selected tables only"""
        schema_lines = []
        
        schema_dict = discovered_schema.to_dict()
        columns_by_table = schema_dict.get('columns_by_table', {})
        
        for table_name in selected_tables:
            if table_name in columns_by_table:
                table_columns = columns_by_table[table_name]
                schema_lines.append(f"\nšŸ“‹ {table_name} (Total: {len(table_columns)} columns):")
                
                if isinstance(table_columns, list):
                    # For the LLM to have exact column names, we need to show them all
                    # But format them in a readable way
                    if len(table_columns) <= 50:
                        # Show all columns for smaller tables
                        for i, col in enumerate(table_columns):
                            if isinstance(col, dict):
                                col_name = col.get('name', col.get('COLUMN_NAME', str(col)))
                                col_type = col.get('data_type', col.get('DATA_TYPE', 'unknown'))
                                is_fk = col.get('is_foreign_key', False)
                                ref_table = col.get('referenced_table', '')
                                
                                fk_info = f" -> {ref_table}" if is_fk and ref_table else ""
                                fk_marker = " (FK)" if is_fk else ""
                                
                                schema_lines.append(f"    {col_name} ({col_type}){fk_marker}{fk_info}")
                            else:
                                # Handle case where columns are just strings
                                schema_lines.append(f"    {col} (unknown type)")
                    else:
                        # For larger tables, show columns in groups but still show ALL column names
                        column_names = []
                        for col in table_columns:
                            if isinstance(col, dict):
                                col_name = col.get('name', col.get('COLUMN_NAME', str(col)))
                            else:
                                col_name = str(col)
                            column_names.append(col_name)
                        
                        # Show all column names in groups of 10 for readability
                        for i in range(0, len(column_names), 10):
                            group = column_names[i:i+10]
                            schema_lines.append(f"    {', '.join(group)}")
                        
                        # Add special note for foreign key relationships
                        fk_info = []
                        for col in table_columns[:50]:  # Check first 50 for FK info
                            if isinstance(col, dict) and col.get('is_foreign_key'):
                                col_name = col.get('name', col.get('COLUMN_NAME', ''))
                                ref_table = col.get('referenced_table', '')
                                if col_name and ref_table:
                                    fk_info.append(f"{col_name} -> {ref_table}")
                        
                        if fk_info:
                            schema_lines.append(f"    🔗 Foreign Keys: {', '.join(fk_info)}")
                else:
                    schema_lines.append("    [Column information not available]")
        
        return "\n".join(schema_lines)
    
    def _get_relationships_for_tables(self, selected_tables: List[str], discovered_schema) -> str:
        """Get discovered relationships between selected tables"""
        schema_dict = discovered_schema.to_dict()
        relationships = schema_dict.get('relationships', [])
        relevant_relationships = []
        
        for rel in relationships:
            from_table = rel.get('from_table')
            to_table = rel.get('to_table')
            
            # Include relationship if both tables are selected
            if from_table in selected_tables and to_table in selected_tables:
                from_column = rel.get('from_column')
                to_column = rel.get('to_column')
                confidence = rel.get('confidence', 0)
                relationship_type = rel.get('type', 'unknown')
                
                relevant_relationships.append(
                    f"- {from_table}.{from_column} → {to_table}.{to_column} "
                    f"({relationship_type}, confidence: {confidence:.2f})"
                )
        
        if relevant_relationships:
            return '\n'.join(relevant_relationships)
        else:
            return "No discovered relationships between selected tables"
    
    def _format_previous_errors(self, previous_errors: List[Dict]) -> str:
        """Format previous errors for context"""
        if not previous_errors:
            return "None"
        
        error_lines = []
        invalid_columns = []
        
        for error_info in previous_errors[-2:]:  # Only show last 2 errors
            error_lines.append(f"Iteration {error_info['iteration']}:")
            error_lines.append(f"  Error: {error_info['error']}")
            
            # Extract invalid column names from error messages
            error_msg = error_info['error']
            if 'Invalid column name' in error_msg:
                import re
                col_match = re.search(r"Invalid column name '([^']+)'", error_msg)
                if col_match:
                    invalid_columns.append(col_match.group(1))
            
            if error_info.get('columns_used'):
                error_lines.append(f"  Columns used: {error_info['columns_used']}")
        
        if invalid_columns:
            error_lines.append(f"\nāŒ INVALID COLUMNS TO AVOID: {', '.join(set(invalid_columns))}")
            error_lines.append("   These column names do NOT exist in the database!")
        
        return "\n".join(error_lines)
    
    def _test_sql_execution(self, sql_query: str) -> Dict[str, Any]:
        """Test SQL execution and return results"""
        try:
            start_time = datetime.now()
            
            # Use data processor to execute query
            from models import DataSource, DataSourceType
            from sql_query_generator import get_default_connection_config
            
            conn_config = get_default_connection_config()
            data_source = DataSource(
                source_type=DataSourceType.SQL_QUERY,
                sql_connection=conn_config.to_connection_string(),
                sql_query=sql_query
            )
            
            df, metadata = self.data_processor.load_data(data_source)
            
            execution_time = (datetime.now() - start_time).total_seconds()
            
            return {
                'success': True,
                'data': df,
                'metadata': metadata,
                'row_count': len(df),
                'execution_time': execution_time
            }
            
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'execution_time': None,
                'row_count': None
            }
    
    def _fallback_table_selection(self, user_request: str, available_tables: List[str]) -> TableSelectionResult:
        """Fallback table selection when LLM fails"""
        # Simple keyword-based selection
        keywords = user_request.lower().split()
        selected = []
        
        # Always include core tables
        core_tables = ['Requests', 'Samples', 'Results', 'Analyses']
        for table in core_tables:
            if table in available_tables:
                selected.append(table)
        
        # Add lookup tables if relevant
        if any(keyword in ['company', 'customer', 'client'] for keyword in keywords):
            if 'Companies' in available_tables:
                selected.append('Companies')
        
        if any(keyword in ['type', 'category'] for keyword in keywords):
            for table in ['SampleTypes', 'AnalysisCategories']:
                if table in available_tables and table not in selected:
                    selected.append(table)
        
        return TableSelectionResult(
            selected_tables=selected[:6],  # Limit to 6 tables
            reasoning="Fallback selection based on core laboratory workflow",
            confidence=0.6,
            suggested_joins=[
                "Samples s ON s.Sample_Request = r.Id",
                "Results res ON res.Result_Sample = s.Id",
                "Analyses a ON a.Id = res.Result_Analysis"
            ]
        )
    
    def _extract_sql_from_response(self, response: str, expected_tables: List[str]) -> Optional[str]:
        """Extract SQL query from LLM response even when JSON parsing fails"""
        import re
        
        # Look for SQL patterns in the response
        sql_patterns = [
            r'SELECT\s+.*?(?=\n\n|\n$|$)',  # SQL until double newline or end
            r'```sql\s*(.*?)\s*```',        # SQL in code blocks
            r'```\s*(SELECT.*?)\s*```',     # SQL in generic code blocks
            r'(SELECT\s+(?:TOP\s+\d+\s+)?.*?(?:FROM|JOIN).*?)(?:\n\n|\Z)',  # Complete SELECT statements
        ]
        
        for pattern in sql_patterns:
            matches = re.findall(pattern, response, re.IGNORECASE | re.DOTALL)
            for match in matches:
                sql_candidate = match.strip()
                
                # Basic validation: should contain SELECT and at least one expected table
                if ('SELECT' in sql_candidate.upper() and 
                    any(table in sql_candidate for table in expected_tables)):
                    
                    # Clean up the SQL
                    sql_candidate = sql_candidate.replace('\n', ' ').strip()
                    if sql_candidate.endswith(';'):
                        sql_candidate = sql_candidate[:-1]
                    
                    logger.info(f"Extracted SQL from response: {sql_candidate[:100]}...")
                    return sql_candidate
        
        return None

    def _fallback_sql_generation(self, user_request: str, table_selection: TableSelectionResult, max_rows: int) -> SqlGenerationResult:
        """Context-aware fallback SQL generation - generates queries based on selected tables when LLM JSON parsing fails"""
        
        # Use the selected tables from Pass 1 to create a more comprehensive fallback
        selected_tables = table_selection.selected_tables
        
        # Check if this looks like an enhancement request
        is_enhancement = "TASK: Enhance" in user_request
        
        if is_enhancement and any(table in ['Results', 'Samples', 'SampleAnalyseGroups'] for table in selected_tables):
            # Enhancement fallback - try to create a comprehensive query with selected tables
            sql = f"""SELECT TOP {max_rows}
    s.Id AS SampleId, s.SampleNr, s.Identification, s.DateSampling, s.Sample_SpeciesType,
    sag.Id AS SampleAnalyseGroupId, sag.RequestAnalyseGroup_AnalysisGroup, sag.Germ,
    res.Id AS ResultId, res.Result_Analysis, res.Result_AnalysisGroup, res.Result_Value, res.TechValidated, res.BioValidated"""
            
            # Add species table if selected
            if 'tblSpecies' in selected_tables:
                sql += ",\n    sp.SpeciesName, sp.SpeciesType AS SpeciesTypeName"
            if 'Flock_Ras' in selected_tables:
                sql += ",\n    fr.Species, fr.RasID, fr.FlockID"
                
            sql += "\nFROM Samples s"
            
            # Add JOINs for selected tables
            if 'SampleAnalyseGroups' in selected_tables:
                sql += "\n    LEFT JOIN SampleAnalyseGroups sag ON sag.SampleAnalyseGroup_Sample = s.Id"
            if 'Results' in selected_tables:
                sql += "\n    LEFT JOIN Results res ON res.SampleAnalyseGroupId = sag.Id AND res.Result_Sample = s.Id"
            if 'Flock_Ras' in selected_tables:
                sql += "\n    LEFT JOIN Flock_Ras fr ON s.Flock_Sample = fr.FlockID"
            if 'tblSpecies' in selected_tables:
                sql += "\n    LEFT JOIN tblSpecies sp ON s.Sample_SpeciesType = sp.SpeciesID"
                
            sql += f"""
WHERE s.Sample_SpeciesType IS NOT NULL
    AND s.DateSampling >= DATEADD(MONTH, -3, GETDATE())
ORDER BY s.DateSampling DESC, s.SampleNr"""
            
            explanation = f"Enhanced fallback query using selected tables: {', '.join(selected_tables)}"
            confidence = 0.6  # Better confidence for enhancement fallback
            
        else:
            # Regular fallback - use core tables from selection
            core_table = selected_tables[0] if selected_tables else 'Requests'
            
            sql = f"""SELECT TOP {max_rows}
    t.Id,
    t.DateCreated
FROM {core_table} t
WHERE t.DateCreated >= DATEADD(MONTH, -6, GETDATE())
ORDER BY t.DateCreated DESC"""
            
            explanation = f"Basic fallback query using {core_table} when LLM parsing failed"
            confidence = 0.3
        
        return SqlGenerationResult(
            sql_query=sql,
            explanation=explanation,
            confidence=confidence,
            tables_used=selected_tables,
            columns_used={table: [] for table in selected_tables}  # Empty since we don't track specific columns in fallback
        )
    
    def _extract_json_from_response(self, response: str) -> Dict[str, Any]:
        """
        Robust JSON extraction from LLM response that may contain extra text
        """
        import re
        
        if not response or not response.strip():
            logger.warning("Empty response from LLM")
            return {}
        
        # Try direct JSON parsing first
        try:
            return json.loads(response.strip())
        except json.JSONDecodeError:
            pass
        
        # Try to find JSON within the response using regex
        json_patterns = [
            r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}',  # Simple nested JSON
            r'\{.*?\}',  # Any content between braces
        ]
        
        for pattern in json_patterns:
            matches = re.findall(pattern, response, re.DOTALL)
            for match in matches:
                try:
                    result = json.loads(match.strip())
                    if isinstance(result, dict):
                        logger.debug(f"Successfully extracted JSON: {list(result.keys())}")
                        return result
                except json.JSONDecodeError:
                    continue
        
        # Try to find JSON-like structure and clean it
        try:
            # Look for content between first { and last }
            start = response.find('{')
            end = response.rfind('}')
            if start != -1 and end != -1 and end > start:
                json_candidate = response[start:end+1]
                return json.loads(json_candidate)
        except json.JSONDecodeError:
            pass
        
        logger.error(f"Could not extract valid JSON from response: {response[:200]}...")
        return {}
    
    def _create_workflow_summary(self, workflow_results: Dict[str, Any]) -> str:
        """Create a summary of the workflow execution"""
        iterations_count = len(workflow_results['iterations'])
        success = workflow_results['final_success']
        
        summary_lines = [
            f"Two-Pass SQL Workflow Summary:",
            f"  Total Iterations: {iterations_count}",
            f"  Final Success: {success}",
        ]
        
        if success:
            final_iteration = workflow_results['iterations'][-1]
            summary_lines.extend([
                f"  Successful Iteration: {final_iteration.iteration_number}",
                f"  Tables Selected: {', '.join(final_iteration.table_selection.selected_tables)}",
                f"  Final Row Count: {final_iteration.row_count}",
                f"  Execution Time: {final_iteration.execution_time:.2f}s"
            ])
        else:
            summary_lines.append("  All iterations failed")
            
        return "\n".join(summary_lines)

Parameters

Name Type Default Kind
schema_discovery - - positional
data_processor - - positional
statistical_agent - None keyword

Parameter Details

schema_discovery: Service that discovers tables, columns, row counts, and relationships in the target database
data_processor: Service that executes the generated SQL and loads results into a DataFrame
statistical_agent: Optional LLM agent used for table selection and SQL generation; when None, keyword-based fallbacks are used

Return Value

Instantiation returns a TwoPassSqlWorkflow instance; generate_sql_with_iterations returns the workflow result dictionary

Class Interface

Methods

__init__(self, schema_discovery, data_processor, statistical_agent=None)

Purpose: Initialize the workflow with its collaborators and set max_iterations to 3

Parameters:

  • schema_discovery: Schema discovery service (must provide discover_schema())
  • data_processor: Data processor (must provide load_data())
  • statistical_agent: Optional LLM agent (must provide query_llm()); defaults to None, in which case keyword-based fallbacks are used

Returns: None

generate_sql_with_iterations(self, user_request, max_rows) -> Dict[str, Any]

Purpose: Main workflow: Generate SQL through two-pass approach with iterations

Parameters:

  • user_request: Type: str
  • max_rows: Type: int

Returns: Returns Dict[str, Any]
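
A hedged usage sketch, assuming schema_discovery, data_processor, and statistical_agent are already-constructed collaborators from this codebase; the result keys are taken directly from the method body:

workflow = TwoPassSqlWorkflow(schema_discovery, data_processor, statistical_agent)
results = workflow.generate_sql_with_iterations(
    "Monthly count of validated results per sample type", max_rows=10000
)

if results['final_success']:
    print(results['final_sql'])    # the SQL that executed successfully
    df = results['final_data']     # DataFrame loaded by the data processor
else:
    for it in results['iterations']:   # IterationResult objects
        print(it.iteration_number, it.error_message)
print(results['workflow_summary'])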

_pass1_select_tables(self, user_request, available_tables, previous_errors, preferred_tables) -> TableSelectionResult

Purpose: Pass 1: Use LLM to select relevant tables based on user request

Parameters:

  • user_request: Type: str
  • available_tables: Type: List[str]
  • previous_errors: Type: List[Dict]
  • preferred_tables: Type: List[str]

Returns: Returns TableSelectionResult
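
TableSelectionResult is defined elsewhere in two_pass_sql_workflow.py; a minimal sketch inferred from how this method constructs it (field types are assumptions):

from dataclasses import dataclass
from typing import List

@dataclass
class TableSelectionResult:
    selected_tables: List[str]   # tables chosen for Pass 2
    reasoning: str               # LLM explanation of the selection
    confidence: float            # self-reported confidence, e.g. 0.85
    suggested_joins: List[str]   # join hints forwarded to the Pass 2 prompt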

_pass2_generate_sql(self, user_request, table_selection, discovered_schema, max_rows, previous_errors, preferred_columns) -> SqlGenerationResult

Purpose: Pass 2: Generate SQL with detailed schema information for selected tables

Parameters:

  • user_request: Type: str
  • table_selection: Type: TableSelectionResult
  • discovered_schema: Parameter
  • max_rows: Type: int
  • previous_errors: Type: List[Dict]
  • preferred_columns: Type: List[str]

Returns: Returns SqlGenerationResult
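
SqlGenerationResult is likewise defined elsewhere in the same file; a sketch inferred from usage (field types are assumptions):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class SqlGenerationResult:
    sql_query: str                       # generated T-SQL text
    explanation: str                     # LLM description of the query logic
    confidence: float                    # self-reported confidence
    tables_used: List[str]
    columns_used: Dict[str, List[str]]   # table -> referenced columns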

_create_table_selection_context(self, available_tables, previous_errors, preferred_tables) -> str

Purpose: Create context for table selection with table descriptions

Parameters:

  • available_tables: Type: List[str]
  • previous_errors: Type: List[Dict]
  • preferred_tables: Type: List[str]

Returns: Returns str
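
Illustrative output (the category grouping and row-count formatting follow the method; the table descriptions and counts here are hypothetical):

šŸ“‚ CORE LABORATORY OPERATIONS:
  • Results (14.2M rows): Laboratory test results with validation flags
  • Samples (3.1M rows): Samples submitted for analysis
  • Requests (890K rows): Customer analysis requests

🔗 DISCOVERED TABLE RELATIONSHIPS:
  • Samples.Sample_Request → Requests.Id (conf: 0.95)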

_format_preferred_columns_context(self, preferred_columns, detailed_schema) -> str

Purpose: Format preferred columns context for SQL generation

Parameters:

  • preferred_columns: Type: List[str]
  • detailed_schema: Type: str

Returns: Returns str
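
Note that the availability check in this method is a plain substring test against the formatted schema text, so it can over-match; a small demonstration:

detailed_schema = "šŸ“‹ Samples (Total: 2 columns):\n    Id (int)\n    SampleNr (nvarchar)"
print('SampleNr' in detailed_schema)  # True -> treated as available
print('Nr' in detailed_schema)        # True -> false positive, also "available"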

_get_detailed_schema_for_tables(self, selected_tables, discovered_schema) -> str

Purpose: Get detailed column information for selected tables only

Parameters:

  • selected_tables: Type: List[str]
  • discovered_schema: Parameter

Returns: Returns str
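
Illustrative output for a small table (the line format is taken from the method; the column list is hypothetical):

šŸ“‹ Samples (Total: 3 columns):
    Id (int)
    SampleNr (nvarchar)
    Sample_Request (int) (FK) -> Requests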

_get_relationships_for_tables(self, selected_tables, discovered_schema) -> str

Purpose: Get discovered relationships between selected tables

Parameters:

  • selected_tables: Type: List[str]
  • discovered_schema: Parameter

Returns: Returns str
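
Example of the line format this method emits (the relationship type and confidence shown are hypothetical):

- Samples.Sample_Request → Requests.Id (foreign_key, confidence: 0.95)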

_format_previous_errors(self, previous_errors) -> str

Purpose: Format previous errors for context

Parameters:

  • previous_errors: Type: List[Dict]

Returns: Returns str
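
The method scans prior error messages for SQL Server's "Invalid column name" pattern and surfaces the offending names; the regex is taken directly from the method body:

import re

error_msg = "Invalid column name 'SampleDate'."  # hypothetical driver error
match = re.search(r"Invalid column name '([^']+)'", error_msg)
if match:
    print(match.group(1))  # -> SampleDate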

_test_sql_execution(self, sql_query) -> Dict[str, Any]

Purpose: Test SQL execution and return results

Parameters:

  • sql_query: Type: str

Returns: Returns Dict[str, Any]
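
A hedged sketch of consuming this private helper (internal API, subject to change; the key names come from the method body):

outcome = workflow._test_sql_execution("SELECT TOP 10 Id FROM Requests")
if outcome['success']:
    df = outcome['data']  # DataFrame returned by the data processor
    print(f"{outcome['row_count']} rows in {outcome['execution_time']:.2f}s")
else:
    print("Execution failed:", outcome['error'])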

_fallback_table_selection(self, user_request, available_tables) -> TableSelectionResult

Purpose: Fallback table selection when LLM fails

Parameters:

  • user_request: Type: str
  • available_tables: Type: List[str]

Returns: Returns TableSelectionResult
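
A sketch of the keyword fallback in action (private method, shown for illustration): the core tables are always included, and 'Companies' is added because the request mentions a customer keyword.

selection = workflow._fallback_table_selection(
    "validated results per customer company",
    ['Requests', 'Samples', 'Results', 'Analyses', 'Companies'],
)
print(selection.selected_tables)
# -> ['Requests', 'Samples', 'Results', 'Analyses', 'Companies']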

_extract_sql_from_response(self, response, expected_tables) -> Optional[str]

Purpose: Extract SQL query from LLM response even when JSON parsing fails

Parameters:

  • response: Type: str
  • expected_tables: Type: List[str]

Returns: Returns Optional[str]
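
One of the patterns the method tries matches SQL inside a ```sql fenced block; a standalone demonstration of that pattern:

import re

response = "Here is the query:\n```sql\nSELECT TOP 10 Id FROM Requests\n```"
match = re.search(r'```sql\s*(.*?)\s*```', response, re.IGNORECASE | re.DOTALL)
if match:
    print(match.group(1))  # -> SELECT TOP 10 Id FROM Requests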

_fallback_sql_generation(self, user_request, table_selection, max_rows) -> SqlGenerationResult

Purpose: Context-aware fallback SQL generation - generates queries based on selected tables when LLM JSON parsing fails

Parameters:

  • user_request: Type: str
  • table_selection: Type: TableSelectionResult
  • max_rows: Type: int

Returns: Returns SqlGenerationResult
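
With no enhancement marker and 'Requests' as the first (or default) selected table, the basic fallback resolves to this query (shown with max_rows = 50000, the workflow default):

SELECT TOP 50000
    t.Id,
    t.DateCreated
FROM Requests t
WHERE t.DateCreated >= DATEADD(MONTH, -6, GETDATE())
ORDER BY t.DateCreated DESC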

_extract_json_from_response(self, response) -> Dict[str, Any]

Purpose: Robust JSON extraction from LLM response that may contain extra text

Parameters:

  • response: Type: str

Returns: Returns Dict[str, Any]
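
The method's last-resort step slices from the first '{' to the last '}' and parses that span; a standalone demonstration:

import json

response = 'Sure! {"selected_tables": ["Requests"], "confidence": 0.8} Hope that helps.'
start, end = response.find('{'), response.rfind('}')
data = json.loads(response[start:end + 1])
print(data['selected_tables'])  # -> ['Requests']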

_create_workflow_summary(self, workflow_results) -> str

Purpose: Create a summary of the workflow execution

Parameters:

  • workflow_results: Type: Dict[str, Any]

Returns: Returns str
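
Example of a successful summary (the format follows the method; the values shown are hypothetical):

Two-Pass SQL Workflow Summary:
  Total Iterations: 2
  Final Success: True
  Successful Iteration: 2
  Tables Selected: Requests, Samples, Results
  Final Row Count: 4200
  Execution Time: 1.37s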

Required Imports

import logging
import json
from datetime import datetime
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

(re is imported locally where needed; models.DataSource, models.DataSourceType, and sql_query_generator.get_default_connection_config are imported inside _test_sql_execution.)
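
The listing also constructs three result dataclasses defined earlier in the same file. A hedged sketch of IterationResult, with fields inferred from how generate_sql_with_iterations builds it (the defaults are assumptions):

from dataclasses import dataclass
from typing import Optional

@dataclass
class IterationResult:
    iteration_number: int
    table_selection: 'TableSelectionResult'
    sql_generation: 'SqlGenerationResult'
    execution_success: bool
    error_message: Optional[str] = None
    execution_time: Optional[float] = None
    row_count: Optional[int] = None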

Usage Example

# Example usage:
# workflow = TwoPassSqlWorkflow(schema_discovery, data_processor, statistical_agent)
# results = workflow.generate_sql_with_iterations("Validated results from the last 3 months")

Similar Components

AI-powered semantic similarity - components with related functionality:

  • class EnhancedSQLWorkflow 71.0% similar

    Enhanced SQL workflow with iterative optimization

    From: /tf/active/vicechatdev/full_smartstat/enhanced_sql_workflow.py
  • function demonstrate_sql_workflow 52.1% similar

    Demonstrates the enhanced SQL workflow for the SmartStat system by loading configurations, initializing SQL query generator, testing natural language to SQL conversion, and displaying schema analysis.

    From: /tf/active/vicechatdev/smartstat/demo_enhanced_sql_workflow.py
  • function demonstrate_sql_workflow_v1 52.0% similar

    Demonstrates the enhanced SQL workflow for the SmartStat system by loading configurations, initializing the SQL query generator, testing natural language to SQL conversion, and displaying schema analysis.

    From: /tf/active/vicechatdev/full_smartstat/demo_enhanced_sql_workflow.py
  • function enhanced_sql_workflow 51.4% similar

    Flask route handler that initiates an enhanced SQL workflow with iterative optimization, executing data extraction and analysis in a background thread while providing real-time progress tracking.

    From: /tf/active/vicechatdev/full_smartstat/app.py
  • class IterationResult 49.6% similar

    A dataclass that encapsulates the complete results of a single iteration in a two-pass process, including table selection, SQL generation, and execution outcomes.

    From: /tf/active/vicechatdev/full_smartstat/two_pass_sql_workflow.py