🔍 Code Extractor

class OneCo_hybrid_RAG_v4

Maturity: 15

A class named OneCo_hybrid_RAG

File: /tf/active/vicechatdev/datacapture_backup_16072025/OneCo_hybrid_RAG.py
Lines: 1165 - 2806
Complexity: moderate

Purpose

A hybrid retrieval-augmented generation (RAG) chat class. It analyzes the user query, gathers context from registered data handles (text, dataframes, Neo4j graph queries, ChromaDB vector collections) and, optionally, Serper web search, tracks numbered "[block N]" references for citation formatting, and serves the result through a Panel chat interface with simple chat memory.
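
Usage (illustrative)

The extract contains no usage notes; the sketch below is an illustrative guess at how the class is driven, based on its constructor and response_callback. The config.DB_ADDR / config.DB_AUTH settings and the 'vice_chroma' Chroma host come from the source; serving the chat interface with Panel's .servable() is an assumption.

    # Minimal sketch, assuming the module's imports and the `config` module are available
    rag = OneCo_hybrid_RAG()

    # Flow-control switches are exposed as a plain dict and can be adjusted before use
    rag.flow_control["enable_search"] = True            # add a Serper web-search pass
    rag.flow_control["model"] = ["OpenAi", "gpt-4o", 0]

    # The Panel chat interface routes user messages through response_callback()
    rag.chat_interface.servable()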

Source Code

class OneCo_hybrid_RAG:


    def __init__(self):
        ## Set API keys
        self.set_api_keys()
        ## Define the flow control variables to be exposed and set default values
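        ## Each model entry is [provider, model_name, temperature]; when the provider is
        ## "Azure", the "model" entry additionally carries the api_version as a fourth
        ## element (see the Azure branch in response_callback).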
        self.flow_control = {
            "pre_model" : ["OpenAi","gpt-4o-mini",0],
            "model" : ["OpenAi","gpt-4o",0],
            "search_engine" : ["Serper","google"],
            "enable_search" : False,
            "enable_memory" : False,
            "memory_max_size" : 3,
            "enable_referencing" : True,
        }
        ## Different type of data can be provided here and will be included in the flow
        self.data_handles = SimpleDataHandle()
        ## Define the UI elements to be exposed
        self.chat_interface=pn.chat.ChatInterface(callback=self.response_callback,width=1200,callback_exception='verbose')
        ## Plan for chat memory
        self.chat_memory = SimpleChatMemory(max_history=self.flow_control["memory_max_size"])
        self.extended_query=None
        # Set up the blocks_dict for references
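        # blocks_dict maps the block numbers used in "[block N]" markers to reference entries;
        # each entry has a "type" ("web", "document", "literature" or "generic") plus
        # type-specific fields (url, path, bibtex or content) consumed by ReferenceManager.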
        self.blocks_dict = {}
        self.block_counter = 1
        
        # Explicitly set OpenAI API type for this class
        os.environ["OPENAI_API_TYPE"] = "openai"
        
        self.init_connections()
        return
    
    def init_connections(self):

        uri = config.DB_ADDR
        user, password = config.DB_AUTH
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.session = self.driver.session(database=config.DB_NAME)
        api_key = "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"  # Replace with your actual API key
        self.chroma_embedder=MyEmbeddingFunction("gpt-4o-mini","text-embedding-3-small",api_key)
        self.chroma_client=chromadb.HttpClient(host='vice_chroma', port=8000)
        self.available_collections = self.chroma_client.list_collections()
        return


    def run_query(self, query, params=None):
        """
        Execute a Cypher query and return the result
        
        Parameters
        ----------
        query : str
            The Cypher query to execute
        params : dict, optional
            Parameters for the query
            
        Returns
        -------
        result
            The query result
        """
        if params is None:
            params = {}
        return self.session.run(query, params)
    
    def evaluate_query(self, query, params=None):
        """
        Execute a Cypher query and return a single result
        
        Parameters
        ----------
        query : str
            The Cypher query to execute
        params : dict, optional
            Parameters for the query
            
        Returns
        -------
        object
            The single result value
        """
        if params is None:
            params = {}
        result = self.session.run(query, params)
        record = result.single()
        if record:
            return record[0]
        return None
    
    def push_changes(self, node):
        """
        Push changes to a node to the database
        
        Parameters
        ----------
        node : dict or node-like object
            Node with properties to update
        """
        # Extract node properties, handling both dict-like and node-like objects
        if hasattr(node, 'items'):
            # Dict-like object
            properties = {k: v for k, v in node.items() if k != 'labels'}
            labels = node.get('labels', [])
            uid = node.get('UID')
        else:
            # Node-like object from previous driver
            properties = {k: node[k] for k in node.keys() if k != 'UID'}
            labels = list(node.labels)
            uid = node['UID']
        
        # Construct labels string for Cypher
        if labels:
            labels_str = ':'.join(labels)
            match_clause = f"MATCH (n:{labels_str} {{UID: $uid}})"
        else:
            match_clause = "MATCH (n {UID: $uid})"
        
        # Update node properties
        if properties:
            set_clauses = [f"n.`{key}` = ${key}" for key in properties]
            query = f"{match_clause} SET {', '.join(set_clauses)}"
            params = {"uid": uid, **properties}
            self.run_query(query, params)
        
        return
    

    def count_tokens(self,text):
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    
    def set_api_keys(self):
        ## Public OpenAI key (placeholder: supply the real key from a secure source, never commit it)
        os.environ["OPENAI_API_KEY"] = "<OPENAI_API_KEY>"
        ## Serper API key (placeholder)
        os.environ["SERPER_API_KEY"] = "<SERPER_API_KEY>"
        ## AzureOpenAI endpoint
        os.environ["AZURE_OPENAI_ENDPOINT"] = "https://vice-llm-2.openai.azure.com/openai/deployments/OneCo-gpt/chat/completions?api-version=2024-08-01-preview"
        ## AzureOpenAI key (placeholder)
        os.environ["AZURE_OPENAI_API_KEY"] = "<AZURE_OPENAI_API_KEY>"
        return
    
    def extract_core_query(self,query_text):
        """
        Extracts the core information-seeking question from a user query that may contain
        both a question and processing instructions for the RAG system.
        
        Args:
            query_text: The original user query text
            
        Returns:
            dict: Contains the extracted information with keys:
                - core_question: The actual information need/question
                - instructions: Any processing instructions found
                - is_complex: Boolean indicating if query contained instructions
        """
        # Use the pre-model (smaller model) for this extraction task
        llm = ChatOpenAI(
            model=self.flow_control['pre_model'][1],
            temperature=self.flow_control['pre_model'][2],
        )
        
        prompt = f"""
        You are an AI query analyzer. Your task is to analyze the following user query and separate it into:
        1. The core information-seeking question (what the user actually wants to know)
        2. Any processing instructions (how the user wants information presented or processed)
        
        User query: {query_text}
        
        Output your analysis in strict JSON format:
        ```json
        {{
            "core_question": "The main question or information need",
            "instructions": "Any processing instructions (or empty string if none)",
            "is_complex": true/false (true if query contains instructions, false if it's just a question)
        }}
        ```
        
        Examples:
        
        Input: "Tell me about mRNA vaccines and format the answer with bullet points"
        Output: 
        ```json
        {{
            "core_question": "Tell me about mRNA vaccines",
            "instructions": "format the answer with bullet points",
            "is_complex": true
        }}
        ```
        
        Input: "What are the main types of vaccine adjuvants?"
        Output: 
        ```json
        {{
            "core_question": "What are the main types of vaccine adjuvants?",
            "instructions": "",
            "is_complex": false
        }}
        ```
        
        Only respond with the JSON output, nothing else.
        """
        
        response = llm.invoke(prompt)
        
        try:
            # Extract JSON from response if needed
            content = response.content
            if '```json' in content:
                content = content.split('```json')[1].split('```')[0].strip()
            elif '```' in content:
                content = content.split('```')[1].split('```')[0].strip()
                
            result = json.loads(content)
            return result
        except Exception as e:
            # Fallback if parsing fails
            print(f"Error parsing LLM response: {e}")
            return {
                "core_question": query_text,
                "instructions": "",
                "is_complex": False
            }
        
    def extract_serper_results(self, serper_response):
        """
        Extract formatted search results and URLs from GoogleSerperAPI response.
        
        Args:
            serper_response: Raw response from GoogleSerperAPI (JSON object or string)
            
        Returns:
            tuple: (formatted_results, extracted_urls)
        """
        search_results = ""
        extracted_urls = []
        
        try:
            # Convert to dict if it's a string
            if isinstance(serper_response, str):
                try:
                    data = json.loads(serper_response)
                except json.JSONDecodeError as e:
                    print(f"Error parsing Serper JSON: {e}")
                    return "Error processing search results.", []
            else:
                # It's already a dict/object
                data = serper_response
            
            # Add search query to the results
            if 'searchParameters' in data and 'q' in data['searchParameters']:
                search_query = data['searchParameters']['q']
                search_results += f"### Search Results for: '{search_query}'\n\n"
            else:
                search_results += "### Search Results\n\n"
            
            # Process organic search results
            if 'organic' in data and isinstance(data['organic'], list):
                for i, result in enumerate(data['organic']):
                    title = result.get('title', 'No title')
                    link = result.get('link', '')
                    snippet = result.get('snippet', '')
                    date = result.get('date', '')
                    
                    # Format the result with block reference
                    block_num = self.block_counter + len(extracted_urls)
                    search_results += f"[block {block_num}] **{title}**\n"
                    if date:
                        search_results += f"*{date}*\n"
                    search_results += f"{snippet}\n"
                    search_results += f"URL: {link}\n\n"
                    
                    # Add to extracted URLs
                    extracted_urls.append({
                        'title': title,
                        'url': link,
                        'snippet': snippet,
                        'date': date
                    })
            
            # Process "People Also Ask" section
            if 'peopleAlsoAsk' in data and isinstance(data['peopleAlsoAsk'], list):
                search_results += "#### People Also Ask\n\n"
                
                for i, qa in enumerate(data['peopleAlsoAsk']):
                    question = qa.get('question', '')
                    snippet = qa.get('snippet', '')
                    title = qa.get('title', '')
                    link = qa.get('link', '')
                    
                    # Format the result with block reference
                    block_num = self.block_counter + len(extracted_urls)
                    search_results += f"[block {block_num}] **{question}**\n"
                    search_results += f"*{title}*\n"
                    search_results += f"{snippet}\n"
                    search_results += f"URL: {link}\n\n"
                    
                    # Add to extracted URLs
                    extracted_urls.append({
                        'title': f"{question} - {title}",
                        'url': link,
                        'snippet': snippet
                    })
            
            # If no results were found
            if not extracted_urls:
                search_results += "No search results were found.\n"
            
        except Exception as e:
            print(f"Error extracting Serper results: {e}")
            search_results = "Error processing search results.\n"
        
        return search_results, extracted_urls

    def response_callback(self, query):
        ## We make a difference between the search enabled or disabled mode  - the first will have 2 separate LLM calls.
        ## Common part - prepare the data
        query_analysis = self.extract_core_query(query)
        print("query analysis", query_analysis)
        search_query = query_analysis["core_question"]
        print("search query", search_query)
        
        # Store the analysis for later use in processing
        self.current_query_analysis = query_analysis
        
        # Parse handler using the core question for retrieval
        data_sections = self.parse_handler(search_query)

        ## prepare LLM following flow control
        if self.flow_control['model'][0]=="OpenAi":
            llm = ChatOpenAI(
                model=self.flow_control['model'][1],
                temperature=self.flow_control['model'][2],
                timeout=None,
                max_retries=2)
        elif self.flow_control['model'][0]=="Azure":
            llm = AzureChatOpenAI(
                azure_deployment=self.flow_control['model'][1],
                api_version=self.flow_control['model'][3],
                temperature=self.flow_control['model'][2],
                max_tokens=2500,
                timeout=None,
                max_retries=2)
        else:
            llm = ChatOpenAI(
                model='gpt-4o',
                temperature=0,
                timeout=None,
                max_retries=2)

        ## Search enabled mode
        self.search_results = ""
        if self.flow_control["enable_search"]:
            ## generate a first response to start the search
            prompt=self.generate_prompt("Vaccine_google",data_sections,query)
            answer = llm.invoke(prompt)
            print("input for web search", answer.content)
            ## The model is expected to return JSON wrapped in a code fence; strip the fence delimiters before parsing
            search_spec = json.loads(answer.content[8:-4])
            search_tool = GoogleSerperAPIWrapper()
            
            # Create a counter for web references
            web_ref_count = self.block_counter
            
            for s in search_spec['search_queries']:
                print("searching with ", s)
                
                # Parse Serper results to extract content and URLs
                search_output=search_tool.results(s)
                #print("search output", search_output)
                search_results, extracted_urls = self.extract_serper_results(search_output)
                self.search_results = self.search_results + "\n" + search_results
                
                # Add extracted URLs to blocks_dict for reference
                for url_info in extracted_urls:
                    title = url_info.get('title', 'Web Page')
                    url = url_info.get('url', '')
                    snippet = url_info.get('snippet', '')
                    
                    # Add reference in blocks_dict
                    self.blocks_dict[self.block_counter] = {
                        "type": "web",
                        "id": f"web_{self.block_counter}",
                        "url": url,
                        "title": title,
                        "snippet": snippet,
                        "content": f"Web search result: {title}. {url}"
                    }
                    self.block_counter += 1

        print("nr blocks dict", len(self.blocks_dict))            
        ## This is the common part for both modes
        prompt=self.generate_prompt("Vaccine_base",data_sections,query)
        answer = llm.invoke(prompt)
        print("output of full prompt", answer.content)
        
        # If reference formatting is enabled, apply it
        if self.flow_control["enable_referencing"]:
            # No need for conversion - use blocks_dict directly
            ref_manager = ReferenceManager(default_style="apa")
            processed_text, references_section = ref_manager.process_references(
                answer.content, 
                self.blocks_dict, 
                style="apa"
            )
            formatted_answer = processed_text + "\n\n" + references_section
        else:
            formatted_answer = answer.content

        self.chat_memory.save_context(
                {"role": "user", "content": query},
                {"role": "assistant", "content": answer.content},
            )

        return pn.pane.Markdown(formatted_answer)


    def get_embedding(self,text):
        """Generate an embedding for the given text using OpenAI's text-embedding-ada-002 model."""
        response = openai.embeddings.create(
                    model="text-embedding-3-small",
                    input=text
                )
        return response.data[0].embedding
    
    def extract_for_queries(self,text, queries, max_tokens=5000, api_key=None):
        """
        Extract information from text based on queries.
        
        Args:
            text: Text to extract from
            queries: List of queries to guide extraction
            max_tokens: Maximum tokens in the output
            api_key: API key for the LLM service
            
        Returns:
            Extracted text relevant to the queries
        """
        api_key = "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"  # Replace with your actual API key
        extractor = QueryBasedExtractor(
            max_output_tokens=max_tokens,
            api_key=api_key,
            model_name="gpt-4o-mini"  # Or another small model
        )
        return extractor.extract(text, queries)

    def parse_handler(self, query):
        data_sections = {}
        # Create blocks_dict directly in the format needed by ReferenceManager
        self.blocks_dict = {}  # Replace self.inline_refs with self.blocks_dict
        self.block_counter = 1  # Start block numbering from 1 to match example
        
        for key in self.data_handles.handlers.keys():
            if self.data_handles.handlers[key]["type"] == "text":
                data_sections[key] = f"[block {self.block_counter}] {self.data_handles.handlers[key]['data']}"
                # Create block entry in proper format
                self.blocks_dict[self.block_counter] = {
                    "type": "generic",
                    "id": f"text_{self.block_counter}",
                    "content": f"Text content: {self.data_handles.handlers[key]['data'][:100]}..."
                }
            elif self.data_handles.handlers[key]["type"] == "dataframe":
                data_sections[key] = f"[block {self.block_counter}] {self.extract_for_queries(self.data_handles.handlers[key]['data'].to_markdown(), [query])}"
                # Create block entry for dataframe
                self.blocks_dict[self.block_counter] = {
                    "type": "generic",
                    "id": f"dataframe_{self.block_counter}",
                    "content": f"Dataframe content from {key}"
                }
            elif self.data_handles.handlers[key]["type"] == "vectorstore":
                data_sections[key] = self.collect_text_blocks(self.data_handles.handlers[key], query)
            elif self.data_handles.handlers[key]["type"] == "db_search":
                data_sections[key] = self.collect_data_from_neo4j(self.data_handles.handlers[key], query)
            elif self.data_handles.handlers[key]["type"] == "chromaDB":
                data_sections[key] = self.collect_data_from_chroma(self.data_handles.handlers[key], query)
            
            self.block_counter += 1
        
        return data_sections
    
    def reformat_data(self, data, min_document_length=30, similarity_threshold=0.95, use_crossencoder=False, inclusions=10):
        """
        Reformat and filter data to be grouped by ID, excluding too-short documents
        and documents that are too similar to each other. Optionally applies crossencoder ranking.
        
        Args:
            data: Original data structure
            min_document_length: Minimum character length for documents to include (default: 30)
            similarity_threshold: Threshold for document similarity (default: 0.95, higher means more similar)
            use_crossencoder: Whether to apply crossencoder reranking (default: False)
            inclusions: Number of documents to return after filtering (default: 10)
            
        Returns:
            List of selected documents (not dictionary)
        """
        from sentence_transformers import CrossEncoder
        import numpy as np
        
        result = {}
        selected_docs = []
        selected_embeddings = []
        
        # Unpack the nested lists for easier access
        ids_list = data['ids'][0]
        documents_list = data['documents'][0]
        metadatas_list = data['metadatas'][0]
        embeddings_array = data['embeddings'][0]
        
        # First pass: filter by document length and organize data
        candidates = []
        for i, id_val in enumerate(ids_list):
            # Check if document meets the minimum length requirement and does not exceed a maximum token length
            if len(documents_list[i]) >= min_document_length and self.count_tokens(documents_list[i]) <= 10000:
                candidates.append({
                    'id': id_val,
                    'document': documents_list[i],
                    'metadata': metadatas_list[i],
                    'embedding': embeddings_array[i].tolist() if embeddings_array is not None else None
                })
        
        # If we don't have enough candidates, return all we have
        if len(candidates) <= inclusions:
            return [(doc['metadata'],doc['document']) for doc in candidates]
        
        # Second pass: filter by similarity
        for candidate in candidates:
            candidate_embedding = np.array(candidate['embedding'])
            # Normalize embedding
            norm = np.linalg.norm(candidate_embedding)
            if norm > 0:
                candidate_embedding = candidate_embedding / norm
                
            # Check if candidate is too similar to any already selected document
            is_redundant = False
            for sel_emb in selected_embeddings:
                similarity = np.dot(candidate_embedding, sel_emb)
                if similarity >= similarity_threshold:
                    is_redundant = True
                    break
                    
            if not is_redundant:
                # Add to result dictionary
                result[candidate['id']] = {
                    'document': candidate['document'],
                    'metadata': candidate['metadata'],
                    'embedding': candidate['embedding']
                }
                # Add to selected lists for similarity checks
                selected_docs.append(candidate)
                selected_embeddings.append(candidate_embedding)
                
                # If we've collected enough documents and don't need crossencoder, we can stop
                #if len(selected_docs) >= inclusions * 2 and not use_crossencoder:
                #    break
        
        # If using crossencoder for reranking
        if use_crossencoder and len(selected_docs) > inclusions:
            query = data.get('query_text', '')
            if not query:  # If no query provided, use a placeholder
                query = "default query"  # Ideally this should be passed in
                
            cross_model = CrossEncoder('BAAI/bge-reranker-base')
            query_doc_pairs = [(query, doc['document']) for doc in selected_docs]
            scores = cross_model.predict(query_doc_pairs)
            
            # Zip documents with their scores and sort by score (highest first)
            doc_score_pairs = list(zip(selected_docs, scores))
            ranked_docs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
            
            # Take the top 'inclusions' documents after reranking
            selected_docs = [doc for doc, _ in ranked_docs[:inclusions]]
        elif len(selected_docs) > inclusions:
            # If not using crossencoder but have too many docs, just take the first 'inclusions'
            selected_docs = selected_docs[:inclusions]
        
        # Return just the document text for further processing
        #print("returning ",[(doc['metadata'],doc['document']) for doc in selected_docs])
        return [(doc['metadata'],self.extract_for_queries(doc['document'],self.extended_query)) for doc in selected_docs]

    def extract_filter_keywords(self, query, n_keywords=2):
        """
        Extract distinguishing keywords from a query for filtering search results.
        
        Args:
            query: The user's query text
            n_keywords: Maximum number of keywords to extract
        
        Returns:
            List of keywords for filtering
        """
        llm = ChatOpenAI(
            model=self.flow_control['pre_model'][1],
            temperature=0
        )
        
        # Make the instruction much more explicit about the exact count
        prompt = f"""
        You are a search optimization expert. Extract EXACTLY {n_keywords} specific distinguishing keyword(s) from this query:
        "{query}"
        
        Guidelines:
        - You MUST return EXACTLY {n_keywords} keyword(s) - no more, no less
        - Focus on proper nouns, company names, technical terms, and specific concepts
        - Select word(s) that would differentiate this topic from related but irrelevant topics
        - Choose word(s) that could filter out incorrect contexts (wrong companies, unrelated domains)
        - Exclude common words like "the", "and", "of"
        - Return ONLY a JSON array of strings with EXACTLY {n_keywords} string(s)
        
        Example:
        For "Impact of Pfizer's mRNA vaccine development on COVID-19 transmission" and n_keywords=1
        Output: ["Pfizer"]
        
        For "Impact of Pfizer's mRNA vaccine development on COVID-19 transmission" and n_keywords=3
        Output: ["Pfizer", "mRNA", "COVID-19"]
        
        Output format:
        ```json
        ["keyword1"{", keyword2" if n_keywords > 1 else ""}{", ..." if n_keywords > 2 else ""}]
        ```
        
        Remember: I need EXACTLY {n_keywords} keyword(s). Count carefully before submitting.
        """
        
        response = llm.invoke(prompt)
        
        try:
            # Extract JSON from response
            content = response.content.strip()
            if '```json' in content:
                content = content.split('```json')[1].split('```')[0].strip()
            elif '```' in content:
                content = content.split('```')[1].split('```')[0].strip()
                
            keywords = json.loads(content)
            
            # Force the correct number of keywords
            if len(keywords) > n_keywords:
                keywords = keywords[:n_keywords]
            elif len(keywords) < n_keywords and len(keywords) > 0:
                # If we got fewer keywords than requested but at least one, duplicate the first one
                while len(keywords) < n_keywords:
                    keywords.append(keywords[0])
                
            return [k.lower() for k in keywords]  # Convert to lowercase for case-insensitive matching
        except Exception as e:
            print(f"Error extracting keywords: {e}")
            # Fall back to simple keyword extraction if LLM fails
            words = query.lower().split()
            stopwords = ['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'with', 'about']
            return [w for w in words if w not in stopwords and len(w) > 3][:n_keywords]
    
    def collect_data_from_chroma(self, data, query):
        """
        Collect relevant documents from ChromaDB based on query with optimized workflow:
        1) Combine results from all extended queries
        2) Apply keyword filters across all results
        3) Remove similar documents and apply cross-encoder if requested
        4) Evaluate against target and add additional documents as needed
        
        Args:
            data: Configuration data for collection and processing
            query: The user query
            
        Returns:
            String with collected document blocks
        """
        # Start with empty collections
        collected_blocks = []
        candidate_docs = {
            'ids': [],
            'documents': [],
            'metadatas': [],
            'embeddings': []
        }
        
        # Extract filter keywords for hybrid search
        filter_keywords = self.extract_filter_keywords(query)
        print(f"Using filter keywords: {filter_keywords}")
        
        # Configure retrieval parameters
        use_crossencoder = "crossencoder" in data["processing_steps"]
        target_docs = data["inclusions"]
        initial_k = target_docs * 10 if use_crossencoder else target_docs * 3
        
        # Generate extended queries if needed
        if "extend_query" in data["processing_steps"]:
            self.extend_query(query)
            self.extended_query.append(query)
        else:
            self.extended_query = [query]
            
        # Get ChromaDB collection
        client = self.chroma_client.get_collection(data["data"], embedding_function=self.chroma_embedder)
        
        # STEP 1: Retrieve candidate documents from all extended queries
        print(f"Retrieving initial candidate documents for {len(self.extended_query)} queries")
        all_ids = set()  # Track IDs to avoid duplicates
        
        for q in self.extended_query:
            print(f"Retrieving documents for query: {q}")
            
            # Retrieve a larger batch of documents
            retrieved_docs = client.query(
                query_texts=[q],
                n_results=initial_k,
                include=["documents", "metadatas", "embeddings"]
            )
            
            # Only process if we got results
            if retrieved_docs['documents'] and len(retrieved_docs['documents'][0]) > 0:
                # Add unique documents to our candidates pool
                for i, doc_id in enumerate(retrieved_docs['ids'][0]):
                    if doc_id not in all_ids:
                        all_ids.add(doc_id)
                        candidate_docs['ids'].append(doc_id)
                        candidate_docs['documents'].append(retrieved_docs['documents'][0][i])
                        candidate_docs['metadatas'].append(retrieved_docs['metadatas'][0][i])
                        if retrieved_docs['embeddings'] and len(retrieved_docs['embeddings'][0]) > i:
                            candidate_docs['embeddings'].append(retrieved_docs['embeddings'][0][i])
        
        print(f"Retrieved {len(candidate_docs['ids'])} unique candidate documents")
        
        # STEP 2: Apply keyword filtering if requested
        filtered_docs = []
        if filter_keywords and "keyword_filter" in data.get("processing_steps", []):
            print(f"Applying keyword filter with keywords: {filter_keywords}")
            
            # First try documents containing ALL keywords
            all_keywords_docs = []
            for i, doc in enumerate(candidate_docs['documents']):
                doc_lower = doc.lower()
                if all(keyword.lower() in doc_lower for keyword in filter_keywords):
                    embedding = candidate_docs['embeddings'][i] if i < len(candidate_docs['embeddings']) else None
                    all_keywords_docs.append({
                        'id': candidate_docs['ids'][i],
                        'document': doc,
                        'metadata': candidate_docs['metadatas'][i],
                        'embedding': embedding
                    })
            
            print(f"Found {len(all_keywords_docs)} documents containing ALL keywords")
            filtered_docs.extend(all_keywords_docs)
            
            # If we don't have enough with all keywords, try documents with ANY keyword
            if len(all_keywords_docs) < target_docs:
                print("Looking for documents with ANY keyword")
                any_keyword_docs = []
                for i, doc in enumerate(candidate_docs['documents']):
                    doc_id = candidate_docs['ids'][i]
                    # Skip if already included
                    if any(d['id'] == doc_id for d in filtered_docs):
                        continue
                        
                    doc_lower = doc.lower()
                    if any(keyword.lower() in doc_lower for keyword in filter_keywords):
                        embedding = candidate_docs['embeddings'][i] if i < len(candidate_docs['embeddings']) else None
                        any_keyword_docs.append({
                            'id': doc_id,
                            'document': doc,
                            'metadata': candidate_docs['metadatas'][i],
                            'embedding': embedding
                        })
                
                print(f"Found {len(any_keyword_docs)} additional documents with ANY keyword")
                filtered_docs.extend(any_keyword_docs)
        else:
            # Without keyword filtering, use all candidates
            for i, doc in enumerate(candidate_docs['documents']):
                embedding = candidate_docs['embeddings'][i] if i < len(candidate_docs['embeddings']) else None
                filtered_docs.append({
                    'id': candidate_docs['ids'][i],
                    'document': doc,
                    'metadata': candidate_docs['metadatas'][i],
                    'embedding': embedding
                })
        
        # STEP 3: Process using similarity threshold to remove near-duplicates
        print("Filtering candidates by similarity threshold")
        min_doc_length = 30
        #max_token_length = 10000  # Avoid documents that are too long
        
        # Apply basic filtering (length, tokens)
        candidates = [doc for doc in filtered_docs 
                     if len(doc['document']) >= min_doc_length ]
        
        print(f"Have {len(candidates)} candidates after basic filtering")
        
        # Apply similarity filtering to remove near-duplicates
        selected_docs = []
        selected_embeddings = []
        similarity_threshold = 0.95
        
        for candidate in candidates:
            # Skip documents without embeddings
            if candidate['embedding'] is None or not isinstance(candidate['embedding'], (list, np.ndarray)):
                continue
                
            candidate_embedding = np.array(candidate['embedding'])
            # Normalize embedding
            norm = np.linalg.norm(candidate_embedding)
            if norm > 0:
                candidate_embedding = candidate_embedding / norm
                
            # Check if candidate is too similar to any already selected document
            is_redundant = False
            for sel_emb in selected_embeddings:
                similarity = np.dot(candidate_embedding, sel_emb)
                if similarity >= similarity_threshold:
                    is_redundant = True
                    break
                    
            if not is_redundant:
                selected_docs.append(candidate)
                selected_embeddings.append(candidate_embedding)
        
        print(f"Selected {len(selected_docs)} documents after similarity filtering")
        
        # STEP 4: Apply cross-encoder reranking if requested
        final_docs = []
        
        if use_crossencoder and len(selected_docs) > target_docs:
            print("Applying cross-encoder reranking")
            from sentence_transformers import CrossEncoder
            cross_model = CrossEncoder('BAAI/bge-reranker-base')
            
            # Create query-document pairs for the reranker
            query_doc_pairs = [(query, doc['document']) for doc in selected_docs]
            scores = cross_model.predict(query_doc_pairs)
            
            # Sort by score (highest first)
            doc_score_pairs = list(zip(selected_docs, scores))
            ranked_docs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
            
            # Select top documents after reranking
            final_docs = [doc for doc, _ in ranked_docs[:target_docs]]
        else:
            # If not using cross-encoder or don't have enough docs, use all selected docs
            final_docs = selected_docs[:target_docs]
        
        # STEP 5: If we still don't have enough documents, try unfiltered search
        if len(final_docs) < target_docs and len(final_docs) < len(candidates):
            print(f"Adding {target_docs - len(final_docs)} more documents to reach target")
            # Find documents not already selected
            remaining_docs = [doc for doc in candidates if doc['id'] not in [d['id'] for d in final_docs]]
            # Add up to the target number
            final_docs.extend(remaining_docs[:target_docs - len(final_docs)])
        
        # STEP 6: Process final documents and create blocks
        print(f"Processing final {len(final_docs)} documents")
        for doc in final_docs:
            # Extract the most relevant content using query-based extractor
            extracted_content = self.extract_for_queries(doc['document'], self.extended_query)
            
            # Add reference in formatted text
            collected_blocks.append(f"[block {self.block_counter}] {extracted_content}")
            
            # Create proper blocks_dict entry
            filepath = doc['metadata'].get('bibtex', doc['metadata'].get('path', ''))
            if filepath and filepath.lower().endswith(('.pptx', '.docx', '.xlsx', '.pdf', '.csv', '.txt')):
                self.blocks_dict[self.block_counter] = {
                    "type": "document",
                    "id": f"doc_{self.block_counter}",
                    "path": filepath,
                    "description": f"Document from ChromaDB collection '{data['data']}'"
                }
            elif '@article' in filepath:
                self.blocks_dict[self.block_counter] = {
                    "type": "literature",
                    "id": f"doc_{self.block_counter}",
                    "bibtex": filepath,
                    "description": f"Document from ChromaDB collection '{data['data']}'"
                }
            else:
                self.blocks_dict[self.block_counter] = {
                    "type": "generic",
                    "id": f"ref_{self.block_counter}",
                    "content": f"ChromaDB: {data['data']} - {filepath}"
                }
            self.block_counter += 1
        
        print(f"Added {len(final_docs)} blocks to the response")
        return "\n".join(collected_blocks)
    
    def get_embedding(self, text):
        """Generate an embedding for the given text using OpenAI's text-embedding model."""
        # Use direct client instead of module-level API to avoid ambiguity errors
        from openai import OpenAI
        
        try:
            client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
            response = client.embeddings.create(
                model="text-embedding-3-small",
                input=text
            )
            embedding = response.data[0].embedding
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            # Return an empty embedding in case of error
            return [0.0] * 1536  # Typical dimension for text-embedding-3-small
        
    def collect_data_from_neo4j(self, data, query, doc_type="literature data"):
        """
        Collect relevant documents from Neo4j using keyword pre-filtering and FAISS vector search
        
        Args:
            data: Dictionary containing search configuration
            query: User's query text
            doc_type: Type of documents to search (default: "literature data")
            
        Returns:
            String with formatted text blocks for LLM context
        """
        
        

        collected_blocks = []
        

            
        # Handle query extension if needed
        if "extend_query" in data["processing_steps"]:
            self.extend_query(query)
            self.extended_query.append(query)
        else:
            self.extended_query = [query]

        partial_inclusions=max(2,data['inclusions']//len(self.extended_query)+1)
        
        # Set up embeddings
        embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
        
        # Set up a collector for documents processed
        docs_added = 0
        #target_docs = data["inclusions"]
        self.litt_nodes_cache = None
        self.litt_VS_cache = None
        
        # Process each query (original and any extensions)
        for q in self.extended_query:
            print(f"Working on query: {q}")
            
            # First, get embeddings for the query
            query_embedding = self.get_embedding(q)
            
            # Process Text_chunks
            text_chunks_processed = self._process_node_type(
                data['data'],  # Base Cypher filter
                query_embedding,
                q,
                "Text_chunk", 
                "Text",
                partial_inclusions,
                embedding_function,
                doc_type
            )
            
            if text_chunks_processed:
                collected_blocks.extend(text_chunks_processed["blocks"])
                docs_added += len(text_chunks_processed["blocks"])
            
            print(f"Added {len(text_chunks_processed.get('blocks', []))} text chunks")
            
            # If we still need more documents, try Table_chunks
            # if docs_added < target_docs:
            #     table_chunks_processed = self._process_node_type(
            #         data['data'],  # Base Cypher filter
            #         query_embedding, 
            #         q,
            #         "Table_chunk",
            #         "Html",  # Use Html for tables
            #         initial_k,
            #         embedding_function,
            #         doc_type,
            #         target_docs - docs_added
            #     )
                
            #     if table_chunks_processed:
            #         collected_blocks.extend(table_chunks_processed["blocks"])
            #         docs_added += len(table_chunks_processed["blocks"])
                    
            #     print(f"Added {len(table_chunks_processed.get('blocks', []))} table chunks")
            
            # # If we have enough documents, stop processing queries
            # if docs_added >= target_docs:
            #     break
        
        print(f"Total blocks added: {len(collected_blocks)}")
        return "\n".join(collected_blocks)
    

    def collect_data_from_neo4j_new(self, data, query, doc_type="literature data"):
        """
        Collect relevant documents from Neo4j using optimized workflow:
        1) Combine results from all extended queries
        2) Apply keyword filters across all results
        3) Remove similar documents and apply cross-encoder if requested
        4) Evaluate against target and add additional documents as needed
        
        Args:
            data: Dictionary containing search configuration
            query: User's query text
            doc_type: Type of documents to search (default: "literature data")
            
        Returns:
            String with formatted text blocks for LLM context
        """
        # Clear cache for new query
        self.litt_nodes_cache = None
        self.litt_VS_cache = None
        
        collected_blocks = []
        
        # Configure retrieval parameters
        use_crossencoder = "crossencoder" in data["processing_steps"]
        target_docs = data["inclusions"]
        initial_k = target_docs * 10 if use_crossencoder else target_docs * 3
        
        # Handle query extension if needed
        if "extend_query" in data["processing_steps"]:
            self.extend_query(query)
            self.extended_query.append(query)
        else:
            self.extended_query = [query]
        
        # Set up embeddings for vector search
        embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
        
        # STEP 1: Get query embedding for vector search
        query_embedding = self.get_embedding(query)
        
        # STEP 2: Retrieve nodes from Neo4j for all queries together
        # This retrieval is based solely on cypher queries from all extended queries
        all_nodes = []
        all_node_ids = set()  # Track retrieved node IDs to avoid duplicates
        
        
        print(f"Retrieving Neo4j nodes")
        
        # Extract keywords for this query to use in Neo4j filtering
        if "keyword_filter" in data.get("processing_steps", []):
            filter_keywords = self.extract_filter_keywords(query, n_keywords=3)
            print(f"Using filter keywords for Neo4j: {filter_keywords}")
        else:
            filter_keywords = []

        
        # Process base query based on type
        if isinstance(data['data'], list):
            # Multiple query variants
            for query_variant in data['data']:
                nodes = self._fetch_neo4j_nodes(
                    query_variant, 
                    filter_keywords,
                    "Text_chunk", 
                    "Text",
                    initial_k
                )
                # Add unique nodes to collection
                for node in nodes:
                    if node["uid"] not in all_node_ids:
                        all_node_ids.add(node["uid"])
                        all_nodes.append(node)
        else:
            # Single query string
            nodes = self._fetch_neo4j_nodes(
                data['data'],
                filter_keywords,
                "Text_chunk", 
                "Text",
                initial_k
            )
            # Add unique nodes to collection
            for node in nodes:
                if node["uid"] not in all_node_ids:
                    all_node_ids.add(node["uid"])
                    all_nodes.append(node)
    

        print(f"Retrieved {len(all_nodes)} unique nodes from Neo4j")
    
        # Cache all retrieved nodes
        self.litt_nodes_cache = all_nodes
        for q in self.extended_query:
            # STEP 3: Filter nodes by basic criteria (length, token count)
            min_doc_length = 30
            max_token_length = 15000  # Avoid very long documents
            
            filtered_nodes = [
                node for node in all_nodes 
                if node["content"] and len(node["content"]) >= min_doc_length ]
            
            print(f"Have {len(filtered_nodes)} nodes after basic filtering")
            
            # If no filtered nodes, return empty result
            if not filtered_nodes:
                return ""
                
            # STEP 4: Apply vector search with similarity threshold to remove near-duplicates
            selected_nodes = []
            selected_embeddings = []
            similarity_threshold = 0.95
            
            # First generate embeddings for all filtered nodes
            node_embeddings = []
            for node in filtered_nodes:
                try:
                    content = node["content"]
                    if self.count_tokens(content) > 8192:
                        # Summarize very long content
                        from openai import OpenAI
                        client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
                        summarized = self.extract_for_queries(content, self.extended_query[:1])
                        embedding = client.embeddings.create(
                            model="text-embedding-3-small",
                            input=summarized
                        ).data[0].embedding
                    else:
                        # Use original content for shorter texts
                        embedding = self.get_embedding(content)
                    
                    node_embeddings.append((node, embedding))
                except Exception as e:
                    print(f"Error generating embedding for node {node['uid']}: {str(e)}")
                    # Skip this node
                    continue
            
            # Apply similarity filtering
            query_embedding_array = np.array(query_embedding)
            for node, embedding in node_embeddings:
                embedding_array = np.array(embedding)
                
                # Normalize embeddings
                norm = np.linalg.norm(embedding_array)
                if norm > 0:
                    embedding_array = embedding_array / norm
                
                # Check if too similar to any already selected node
                is_redundant = False
                for sel_emb in selected_embeddings:
                    similarity = np.dot(embedding_array, sel_emb)
                    if similarity >= similarity_threshold:
                        is_redundant = True
                        break
                
                if not is_redundant:
                    selected_nodes.append(node)
                    selected_embeddings.append(embedding_array)
            
            print(f"Selected {len(selected_nodes)} nodes after similarity filtering")
            
            # STEP 5: Apply cross-encoder reranking if requested
            final_nodes = []
            
            if use_crossencoder and len(selected_nodes) > target_docs:
                print("Applying cross-encoder reranking")
                from sentence_transformers import CrossEncoder
                cross_model = CrossEncoder('BAAI/bge-reranker-base')
                
                # Create query-document pairs for the reranker
                query_doc_pairs = [(query, node["content"]) for node in selected_nodes]
                scores = cross_model.predict(query_doc_pairs)
                
                # Sort by score (highest first)
                node_score_pairs = list(zip(selected_nodes, scores))
                ranked_nodes = sorted(node_score_pairs, key=lambda x: x[1], reverse=True)
                
                # Select top nodes after reranking
                final_nodes = [node for node, _ in ranked_nodes[:target_docs]]
            else:
                # If not using cross-encoder, take top nodes by relevance
                final_nodes = selected_nodes[:target_docs]
        
        # STEP 6: Process final nodes into blocks
        print(f"Processing final {len(final_nodes)} nodes into blocks")
        
        for node in final_nodes:
            content = node["content"]
            uid = node["uid"]
            parent_name = node["parent_name"] if "parent_name" in node else "Unknown"
            bibtex = node["bibtex"] if "bibtex" in node else None
            
            # Extract the most relevant content using query-based extractor
            extracted_content = self.extract_for_queries(content, self.extended_query)
            
            # Add reference in formatted text
            block_text = f"[block {self.block_counter}] {extracted_content}"
            collected_blocks.append(block_text)
            
            # Create reference entry
            if bibtex and '@' in bibtex:
                # Literature reference with BibTeX
                self.blocks_dict[self.block_counter] = {
                    "type": "literature",
                    "id": f"lit_{self.block_counter}",
                    "bibtex": bibtex,
                    "content": f"Neo4j literature: {parent_name}"
                }
            else:
                # Generic document reference
                self.blocks_dict[self.block_counter] = {
                    "type": "generic",
                    "id": f"doc_{self.block_counter}", 
                    "content": f"Document: {parent_name}"
                }
            
            # Increment block counter
            self.block_counter += 1
        
        print(f"Added {len(final_nodes)} blocks to the response")
        return "\n".join(collected_blocks)
    
    def _fetch_neo4j_nodes(self, base_query, keywords, node_type, content_field, limit):
        """
        Helper method to fetch nodes from Neo4j with keyword filtering
        
        Args:
            base_query: Base cypher query for filtering
            keywords: Keywords for filtering results
            node_type: Type of node to fetch
            content_field: Field containing the content
            limit: Maximum number of nodes to retrieve
            
        Returns:
            List of node dictionaries
        """
        # Construct keyword clause for filtering
        keyword_clauses = []
        for keyword in keywords:
            # Escape single quotes in keywords
            safe_keyword = keyword.replace("'", "\\'")
            keyword_clauses.append(f"x.{content_field} CONTAINS '{safe_keyword}'")
        
        # Combine keyword clauses with OR
        if not keyword_clauses:
            keyword_filter = None
        else:
            keyword_filter = " OR ".join(keyword_clauses)
        
        # Construct the final query with keyword filtering
        if keyword_filter:
            if "WHERE" in base_query or "where" in base_query:
                # Add to existing WHERE clause
                query_with_keywords = base_query.replace("where","WHERE").replace("WHERE", f"WHERE ({keyword_filter}) AND ")
            else:
                # Add new WHERE clause
                query_with_keywords = f"{base_query} WHERE {keyword_filter}"
        else:
            # No keywords, use original query
            query_with_keywords = base_query
        
        # Complete the query to fetch publications and other metadata
        cypher_query = f"""
        {query_with_keywords}
        MATCH (p:Publication)-->(x)
        RETURN x.UID AS uid, 
            x.{content_field} AS content,
            p.Name AS parent_name,
            p.BibTex AS bibtex
        LIMIT {limit}
        """
        
        # Execute query and collect results
        results = self.session.run(cypher_query)
        nodes = []
        
        for record in results:
            nodes.append({
                "uid": record["uid"],
                "content": record["content"],
                "parent_name": record["parent_name"],
                "bibtex": record["bibtex"],
                "node_type": node_type
            })
        
        return nodes

    def _process_node_type(self, base_query, query_embedding, query_text, node_type, content_field, k, embedding_function, doc_type):
        """
        Helper method to process a specific node type with FAISS vector search
        
        Args:
            base_query: Base cypher query for pre-filtering
            query_embedding: Embedding vector for the query
            query_text: Text of the query for cross-encoder ranking
            node_type: Type of node to process (Text_chunk or Table_chunk)
            content_field: Field containing the node content
            k: Number of results to retrieve
            embedding_function: Function to generate embeddings
            doc_type: Type of document
            max_results: Maximum number of results to return
            
        Returns:
            Dictionary with blocks added and other metadata
        """
        import numpy as np
        import faiss
        from sentence_transformers import CrossEncoder
        
        processing_steps = self.data_handles.handlers.get(doc_type, {}).get("processing_steps", [])
        if "crossencoder" in processing_steps:
            k_initial=k*10
        else:
            k_initial=k
        # Step 1: Fetch pre-filtered nodes from Neo4j without vector search
        # Instead of using the raw base_query, let's parse it properly
        if not self.litt_nodes_cache:
            if isinstance(base_query, list):
                # If base_query is a list of query strings, run them separately
                all_nodes = []
                for query_variant in base_query:
                    # Use MATCH pattern instead of directly inserting raw query
                    cypher_query = f"""
                    {query_variant}
                    WITH x, count(distinct k) as keyword_count
                    ORDER BY keyword_count DESC
                    LIMIT {k_initial * 10}
                    MATCH (p:Publication)-->(x)
                    RETURN x.UID AS uid, 
                        x.{content_field} AS content,
                        p.Name AS parent_name,
                        p.BibTex AS bibtex
                    """
                    
                    results = self.session.run(cypher_query)
                    for record in results:
                        all_nodes.append({
                            "uid": record["uid"],
                            "content": record["content"],
                            "parent_name": record["parent_name"],
                            "bibtex": record["bibtex"],
                            "node_type": node_type
                        })
                
                nodes = all_nodes
            else:
                # For string base_query, use it as a filter
                cypher_query = f"""
                {base_query}
                MATCH (p:Publication)-->(x) 
                RETURN x.UID AS uid, 
                    x.{content_field} AS content,
                    p.Name AS parent_name,
                    p.BibTex AS bibtex
                LIMIT {k_initial * 10}
                """
                
                results = self.session.run(cypher_query)
                nodes = []
                for record in results:
                    nodes.append({
                        "uid": record["uid"],
                        "content": record["content"],
                        "parent_name": record["parent_name"],
                        "bibtex": record["bibtex"],
                        "node_type": node_type
                    })
            self.litt_nodes_cache = nodes
        else:    
            nodes = self.litt_nodes_cache
        
        # Filter out empty/very short chunks, keeping the filtered node list so that
        # FAISS indices can be mapped back to the correct node later on
        filtered_nodes = [node for node in nodes if node["content"] and len(node["content"]) >= 30]
        contents = [node["content"] for node in filtered_nodes]
        metadata = [{
            "uid": node["uid"],
            "parent_name": node["parent_name"],
            "bibtex": node["bibtex"],
            "node_type": node["node_type"]
        } for node in filtered_nodes]
        
        # If we didn't find any nodes, return empty result
        if not contents:
            return {"blocks": [], "count": 0}
            
        # Build the FAISS index over the candidate contents (or reuse the cached one)
        if len(contents) > 0:
            try:
                if not self.litt_VS_cache:
                    # Use a direct OpenAI client without module-level API to avoid ambiguity
                    from openai import OpenAI
                    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
                    
                    # Generate embeddings for all contents
                    content_embeddings = []
                    for content in contents:
                        try:
                            if self.count_tokens(content) > 8192:
                                # Summarize for very long content
                                content = self.summarize_text(content, os.environ["OPENAI_API_KEY"])
                            
                            # Use direct client instead of module-level functions
                            response = client.embeddings.create(
                                model="text-embedding-3-small",
                                input=content
                            )
                            embedding = response.data[0].embedding
                            content_embeddings.append(embedding)
                        except Exception as e:
                            print(f"Error generating embedding: {str(e)}")
                            content_embeddings.append([0.0] * len(query_embedding))  # zero-vector placeholder
                    
                    # Create FAISS index
                    dimension = len(query_embedding)
                    index = faiss.IndexFlatL2(dimension)
                    
                    # Add embeddings to the index and cache it for later calls
                    if content_embeddings:
                        index.add(np.array(content_embeddings, dtype=np.float32))
                        self.litt_VS_cache = index
                else:
                    index=self.litt_VS_cache
                    
                try:
                    # Search for similar vectors
                    D, I = index.search(np.array([query_embedding], dtype=np.float32), k_initial)
                    
                    # Map FAISS indices back onto the filtered node list (FAISS pads with -1)
                    similar_indices = I[0]
                    similar_nodes = [filtered_nodes[idx] for idx in similar_indices
                                     if 0 <= idx < len(filtered_nodes)]

                    
                    # Step 3: Apply cross-encoder reranking if needed
                    processing_steps = self.data_handles.handlers.get(doc_type, {}).get("processing_steps", [])
                    if "crossencoder" in processing_steps:
                        print("Applying cross-encoder reranking")
                        cross_model = CrossEncoder('BAAI/bge-reranker-base')
                        
                        # Prepare document pairs for reranking
                        query_chunk_pairs = [(query_text, node["content"]) for node in similar_nodes]
                        scores = cross_model.predict(query_chunk_pairs)
                        
                        # Combine nodes with their scores
                        node_score_pairs = list(zip(similar_nodes, scores))
                        
                        # Sort by score (highest first)
                        ranked_nodes = sorted(node_score_pairs, key=lambda x: x[1], reverse=True)
                        
                        # Take top nodes
                        top_nodes = [node for node, _ in ranked_nodes[:k]]
                    else:
                        # Just limit the number if no reranking
                        top_nodes = similar_nodes[:k]
                    
                    # Step 4: Format the results
                    blocks = []
                    
                    for i, node in enumerate(top_nodes):
                        content = node["content"]
                        uid = node["uid"]
                        parent_name = node["parent_name"]
                        bibtex = node["bibtex"]
                        node_type = node["node_type"]
                        
                        # Format the content block
                        content = self.extract_for_queries(content, self.extended_query)
                        block_text = f"[block {self.block_counter}] {content}"
                        blocks.append(block_text)
                        
                        # Create reference entry
                        if bibtex and '@' in bibtex:
                            # Literature reference with BibTeX
                            self.blocks_dict[self.block_counter] = {
                                "type": "literature",
                                "id": f"lit_{self.block_counter}",
                                "bibtex": bibtex,
                                "content": f"Neo4j literature: {parent_name}"
                            }
                        else:
                            # Generic document reference
                            self.blocks_dict[self.block_counter] = {
                                "type": "generic",
                                "id": f"doc_{self.block_counter}", 
                                "content": f"Document: {parent_name}"
                            }
                        
                        # Increment block counter
                        self.block_counter += 1
                    
                    return {"blocks": blocks, "count": len(blocks)}
                except Exception as e:
                    print(f"Error processing block results: {str(e)}")
                    return {"blocks": [], "count": 0}
                    
            except Exception as e:
                print(f"Error in FAISS processing: {str(e)}")
        
        return {"blocks": [], "count": 0}
    

    def collect_text_blocks(self, data,query):
        embedding_function = OpenAIEmbeddings()
        if "crossencoder" in data["processing_steps"]:
            initial_k=data["inclusions"]*10
        else:
            initial_k=data["inclusions"]
        if "extend_query" in data["processing_steps"]:
            self.extend_query(query)
            self.extended_query.append(query)
        else:
            self.extended_query=[query]
        ## The first step is always a similarity search
        collected_blocks = []
        retriever = data['data'].as_retriever(
                search_type="similarity", 
                search_kwargs={"k": initial_k*3}
            )
        for q in self.extended_query:
            print("working on query ",q)
            retrieved_docs = retriever.invoke(q)
            retrieved_texts = [doc.page_content for doc in retrieved_docs if len(doc.page_content)>30]
            # Here we recompute embeddings for each candidate document.
            candidate_embeddings = np.array([self.normalize(embedding_function.embed_query(doc))
                                 for doc in retrieved_texts])

            # Compute and normalize the query embedding
            query_embedding = self.normalize(np.array(embedding_function.embed_query(q)))

            # 4. Run MMR to select a diverse subset of documents
            print("running MMR")
            #retrieved_texts = self.mmr(query_embedding, candidate_embeddings, retrieved_texts, lambda_param=0.5, top_k=initial_k)
            retrieved_texts=self.similarity_threshold_filter(query_embedding, candidate_embeddings, retrieved_texts, similarity_threshold=0.95,top_k=initial_k)
            ## If crossencoder is used, we need to rerank the results
            if "crossencoder" in data["processing_steps"]:
                cross_model = CrossEncoder('BAAI/bge-reranker-base')
                query_chunk_pairs = [(q, chunk) for chunk in retrieved_texts]
                scores = cross_model.predict(query_chunk_pairs)
                chunk_score_pairs = list(zip(retrieved_texts, scores))
                ranked_chunks = sorted(chunk_score_pairs, key=lambda x: x[1], reverse=True)
                retrieved_texts = [chunk for chunk, score in ranked_chunks[:data["inclusions"]//2]]
            #print("blocks from ",q," \n","\n".join(retrieved_texts))
            for block in retrieved_texts:
                collected_blocks.append("[block "+str(self.block_counter)+"] "+block)
                self.inline_refs['block '+str(self.block_counter)]='VStore Block '+str(self.block_counter)
                self.block_counter+=1
        return "\n".join(collected_blocks)
    
    def generate_prompt(self,template,data_sections,query):
        prompt_template=my_prompt_templates.get(template,'')
        if prompt_template=="":
            prompt_template=my_prompt_templates.get("Vaccine_base",'')
        prompt=prompt_template["Instructions"]+"\n"
        i = -1  # keeps the step numbering correct when data_sections is empty
        for i, key in enumerate(data_sections.keys()):
            prompt=prompt+"Step "+str(i+1)+" on section labeled [" +key+"]: "+self.data_handles.handlers[key]['instructions']+ "\n"
        if self.flow_control["enable_search"]:
            prompt=prompt+"Step "+str(i+2)+" on section labeled [web search results] : Provide a summary of the given context data extracted from the web, using summary tables when possible.\n"
        if self.flow_control["enable_memory"]:
            prompt=prompt+"Step "+str(i+3)+" on section labeled [previous chats] : Also take into account your previous answers.\n"
        prompt=prompt+prompt_template["Output Constraints"]+"\n\n"
        i = -1
        for i, key in enumerate(data_sections.keys()):
            prompt=prompt+"Data section "+str(i+1)+"- [" +key+"]\n"+data_sections[key]+ "\n"
        if self.flow_control["enable_search"]:
            prompt=prompt+"Data section "+str(i+2)+"- [web search results] \n"+self.search_results+ "\n"
        if self.flow_control["enable_memory"]:
            prompt=prompt+"Data section "+str(i+3)+"- [previous chats] \n"+self.chat_memory.get_formatted_history()+ "\n"
        prompt=prompt+"User query: "+query
        return prompt

    def extend_query(self,query):
        llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
        prompt = f"""
        You are an AI that enhances a given user search query to improve information retrieval.
    
        ### User Query:
        {query}

        ### Instructions:   
        - Provide exactly 5 expanded queries.
        - Each query should explore a different aspect or perspective of the original query.
        - Use synonyms, related terms, or rephrased versions to cover various dimensions of the topic.
        - Ensure that the expanded queries are relevant and coherent with the original query.
        - Avoid generating queries that are too similar to each other.
        - ONLY return the text of the expanded queries.
    
        ### Expanded Queries:
        """
        answer = llm.invoke(prompt)
        self.extended_query = [x.strip() for x in answer.content.strip().split("\n") if x.strip()]
        print("extended query",self.extended_query)
        return
    
    def normalize(self, vector):
        """Return the vector scaled to unit length."""
        return vector / np.linalg.norm(vector)
    
    def mmr(self, query_embedding, candidate_embeddings, candidate_docs, lambda_param=0.7, top_k=5):
        """Select up to top_k documents with Maximal Marginal Relevance (assumes normalized embeddings)."""
        # Compute similarity between the query and each candidate (dot product assumes normalized vectors)
        candidate_similarities = np.dot(candidate_embeddings, query_embedding)
        
        # Initialize selected and remaining indices
        selected_indices = []
        candidate_indices = list(range(len(candidate_docs)))
        
        # First selection: candidate with highest similarity
        first_idx = int(np.argmax(candidate_similarities))
        selected_indices.append(first_idx)
        candidate_indices.remove(first_idx)
        
        # Iteratively select documents that balance relevance and diversity
        while len(selected_indices) < top_k and candidate_indices:
            best_score = -np.inf
            best_idx = None
            for idx in candidate_indices:
                # Relevance score for candidate idx
                relevance = candidate_similarities[idx]
                # Diversity score: maximum similarity with any already selected document
                diversity = max(np.dot(candidate_embeddings[idx], candidate_embeddings[sel_idx])
                                for sel_idx in selected_indices)
                # Combined MMR score
                score = lambda_param * relevance - (1 - lambda_param) * diversity
                if score > best_score:
                    best_score = score
                    best_idx = idx
            selected_indices.append(best_idx)
            candidate_indices.remove(best_idx)
        
        return [candidate_docs[i] for i in selected_indices]

    def similarity_threshold_filter(self, query_embedding, candidate_embeddings, candidate_docs, similarity_threshold=0.9, top_k=5):
        """Greedily select up to top_k documents by query similarity, skipping near-duplicate candidates."""
        selected_docs = []
        selected_embeddings = []

        # Compute query similarity scores for sorting candidates (highest first)
        candidate_scores = np.dot(candidate_embeddings, query_embedding)
        sorted_indices = np.argsort(candidate_scores)[::-1]

        for idx in sorted_indices:
            if len(selected_docs) >= top_k:
                break
            candidate_embedding = candidate_embeddings[idx]
            # Skip candidates that are too similar to any already selected document
            is_redundant = any(np.dot(candidate_embedding, sel_emb) >= similarity_threshold
                               for sel_emb in selected_embeddings)
            if not is_redundant:
                print("appending ", candidate_docs[idx])
                selected_docs.append(candidate_docs[idx])
                selected_embeddings.append(candidate_embedding)
                
        return selected_docs
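
Note on the two selection strategies above: mmr scores each remaining candidate as lambda_param * sim(query, candidate) - (1 - lambda_param) * max similarity to the already selected set, while collect_text_blocks currently relies on similarity_threshold_filter (the MMR call is commented out). The following toy sketch reproduces the threshold filter on made-up three-dimensional embeddings; the documents and vectors are hypothetical and only illustrate why a near-duplicate chunk gets dropped.

# Toy, self-contained sketch of the redundancy filter (hypothetical data; the class
# applies the same idea to OpenAI embeddings of retrieved chunks).
import numpy as np

def normalize(v):
    return v / np.linalg.norm(v)

docs = ["adjuvant dose response", "adjuvant dose-response curve", "cold-chain storage"]
embs = np.array([normalize(np.array(e, dtype=float)) for e in
                 [[1.0, 0.1, 0.0], [0.98, 0.12, 0.0], [0.0, 0.2, 1.0]]])
query = normalize(np.array([1.0, 0.0, 0.1]))

threshold, top_k = 0.95, 2
order = np.argsort(embs @ query)[::-1]          # rank candidates by query similarity
selected, selected_embs = [], []
for idx in order:
    if len(selected) >= top_k:
        break
    # keep the candidate only if it is not too close to anything already selected
    if all(embs[idx] @ e < threshold for e in selected_embs):
        selected.append(docs[idx])
        selected_embs.append(embs[idx])

print(selected)  # ['adjuvant dose response', 'cold-chain storage'] -- the near-duplicate is dropped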

Parameters

Name Type Default Kind
bases - -

Parameter Details

bases: Parameter of type

Return Value

Returns unspecified type

Class Interface

Methods

__init__(self)

Purpose: Internal method: init

Returns: None

init_connections(self)

Purpose: Performs init connections

Returns: None

run_query(self, query, params)

Purpose: Execute a Cypher query and return the result

Parameters:

  • query (str): The Cypher query to execute
  • params (dict, optional): Parameters for the query

Returns: The query result

evaluate_query(self, query, params)

Purpose: Execute a Cypher query and return a single result

Parameters:

  • query (str): The Cypher query to execute
  • params (dict, optional): Parameters for the query

Returns: The single result value

push_changes(self, node)

Purpose: Push changes to a node to the database

Parameters:

  • node (dict or node-like object): Node with properties to update

Returns: None

count_tokens(self, text)

Purpose: Performs count tokens

Parameters:

  • text: Parameter

Returns: None

set_api_keys(self)

Purpose: Sets api keys

Returns: None

extract_core_query(self, query_text)

Purpose: Extracts the core information-seeking question from a user query that may contain both a question and processing instructions for the RAG system

Parameters:

  • query_text: The original user query text

Returns: dict with keys core_question (the actual information need), instructions (any processing instructions found) and is_complex (whether the query contained instructions)
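
This extract does not include the method body, so the snippet below is only a hedged sketch of how such a question/instruction split could be done with the same gpt-4o-mini model the class uses elsewhere; the prompt wording, the JSON contract and the langchain_openai import path are assumptions, not the actual implementation.

# Hedged sketch only -- not the method's real prompt or parsing logic.
import json
from langchain_openai import ChatOpenAI   # assumed import path

def split_query(query_text: str) -> dict:
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    prompt = (
        "Split the user input into the information-seeking question and any "
        "processing instructions. Reply with JSON containing the keys "
        '"core_question", "instructions" and "is_complex".\n\n'
        f"User input: {query_text}"
    )
    answer = llm.invoke(prompt)
    try:
        return json.loads(answer.content)
    except json.JSONDecodeError:
        # Fall back to treating the whole input as the question
        return {"core_question": query_text, "instructions": "", "is_complex": False}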

extract_serper_results(self, serper_response)

Purpose: Extract formatted search results and URLs from a GoogleSerperAPI response

Parameters:

  • serper_response: Raw response from GoogleSerperAPI (JSON object or string)

Returns: tuple of (formatted_results, extracted_urls)
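
A Serper response normally carries an "organic" list of title/link/snippet entries; the helper below is a hedged illustration of that parsing under this assumption, not the method's actual code.

# Hedged sketch, assuming the usual Serper payload layout.
import json

def parse_serper(serper_response):
    data = json.loads(serper_response) if isinstance(serper_response, str) else serper_response
    lines, urls = [], []
    for hit in data.get("organic", []):
        lines.append(f"{hit.get('title', '')}: {hit.get('snippet', '')}")
        urls.append(hit.get("link", ""))
    return "\n".join(lines), urls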

response_callback(self, query)

Purpose: Performs response callback

Parameters:

  • query: Parameter

Returns: None

get_embedding(self, text)

Purpose: Generate an embedding for the given text using OpenAI's text-embedding-ada-002 model.

Parameters:

  • text: Parameter

Returns: None

extract_for_queries(self, text, queries, max_tokens, api_key)

Purpose: Extract information from text based on queries

Parameters:

  • text: Text to extract from
  • queries: List of queries to guide extraction
  • max_tokens: Maximum tokens in the output
  • api_key: API key for the LLM service

Returns: Extracted text relevant to the queries

parse_handler(self, query)

Purpose: Performs parse handler

Parameters:

  • query: Parameter

Returns: None

reformat_data(self, data, min_document_length, similarity_threshold, use_crossencoder, inclusions)

Purpose: Reformat and filter data to be grouped by ID, excluding too-short documents and documents that are too similar to each other. Optionally applies cross-encoder reranking.

Parameters:

  • data: Original data structure
  • min_document_length: Minimum character length for documents to include (default: 30)
  • similarity_threshold: Threshold for document similarity (default: 0.95; higher means more similar)
  • use_crossencoder: Whether to apply cross-encoder reranking (default: False)
  • inclusions: Number of documents to return after filtering (default: 10)

Returns: List of selected documents (not a dictionary)
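
The optional cross-encoder pass mentioned here follows the same reranking pattern that appears in the source above (CrossEncoder('BAAI/bge-reranker-base')). The standalone sketch below shows only that pattern on hypothetical documents; the grouping-by-ID and similarity-filtering parts of reformat_data are not visible in this extract.

# Hedged sketch of the cross-encoder reranking pattern (documents are hypothetical).
from sentence_transformers import CrossEncoder

def rerank(query, documents, keep=5):
    model = CrossEncoder('BAAI/bge-reranker-base')
    scores = model.predict([(query, doc) for doc in documents])      # one relevance score per pair
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:keep]]

# top_chunks = rerank("alum adjuvant mechanism", candidate_chunks)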

extract_filter_keywords(self, query, n_keywords)

Purpose: Extract distinguishing keywords from a query for filtering search results

Parameters:

  • query: The user's query text
  • n_keywords: Maximum number of keywords to extract

Returns: List of keywords for filtering

collect_data_from_chroma(self, data, query)

Purpose: Collect relevant documents from ChromaDB based on the query, using an optimized workflow:

  1. Combine results from all extended queries
  2. Apply keyword filters across all results
  3. Remove similar documents and apply the cross-encoder if requested
  4. Evaluate against the target and add additional documents as needed

Parameters:

  • data: Configuration data for collection and processing
  • query: The user query

Returns: String with collected document blocks
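
The method body is not part of this extract; the sketch below only illustrates, under stated assumptions, how the keyword filter of step 2 could be expressed with chromadb's where_document operator. The host, collection name and helper function are hypothetical.

# Hedged sketch: query a Chroma collection per extended query with a keyword pre-filter.
import chromadb

client = chromadb.HttpClient(host="chroma-host", port=8000)      # hypothetical host
collection = client.get_collection("literature_chunks")          # hypothetical collection name

def chroma_candidates(extended_queries, keyword, n_results=20):
    hits = []
    for q in extended_queries:
        res = collection.query(
            query_texts=[q],
            n_results=n_results,
            where_document={"$contains": keyword},                # keyword pre-filter (step 2)
        )
        hits.extend(res["documents"][0])
    return hits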

get_embedding(self, text)

Purpose: Generate an embedding for the given text using OpenAI's text-embedding model.

Parameters:

  • text: Parameter

Returns: None

collect_data_from_neo4j(self, data, query, doc_type)

Purpose: Collect relevant documents from Neo4j using keyword pre-filtering and FAISS vector search

Parameters:

  • data: Dictionary containing search configuration
  • query: User's query text
  • doc_type: Type of documents to search (default: "literature data")

Returns: String with formatted text blocks for LLM context

collect_data_from_neo4j_new(self, data, query, doc_type)

Purpose: Collect relevant documents from Neo4j using an optimized workflow:

  1. Combine results from all extended queries
  2. Apply keyword filters across all results
  3. Remove similar documents and apply the cross-encoder if requested
  4. Evaluate against the target and add additional documents as needed

Parameters:

  • data: Dictionary containing search configuration
  • query: User's query text
  • doc_type: Type of documents to search (default: "literature data")

Returns: String with formatted text blocks for LLM context

_fetch_neo4j_nodes(self, base_query, keywords, node_type, content_field, limit)

Purpose: Helper method to fetch nodes from Neo4j with keyword filtering

Parameters:

  • base_query: Base Cypher query for filtering
  • keywords: Keywords for filtering results
  • node_type: Type of node to fetch
  • content_field: Field containing the content
  • limit: Maximum number of nodes to retrieve

Returns: List of node dictionaries
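
Given the source above, a call might look like the hedged example below; rag stands for an OneCo_hybrid_RAG instance, and the base query, keywords and content field name are made up for illustration.

# Hypothetical call -- the keyword list becomes a WHERE clause on the MATCH, and every
# returned dict carries uid, content, parent_name, bibtex and node_type.
nodes = rag._fetch_neo4j_nodes(
    base_query="MATCH (x:Text_chunk)",
    keywords=["adjuvant", "aluminium hydroxide"],
    node_type="Text_chunk",
    content_field="Content",
    limit=50,
)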

_process_node_type(self, base_query, query_embedding, query_text, node_type, content_field, k, embedding_function, doc_type)

Purpose: Helper method to process a specific node type with FAISS vector search

Parameters:

  • base_query: Base Cypher query for pre-filtering
  • query_embedding: Embedding vector for the query
  • query_text: Text of the query for cross-encoder ranking
  • node_type: Type of node to process (Text_chunk or Table_chunk)
  • content_field: Field containing the node content
  • k: Number of results to retrieve
  • embedding_function: Function to generate embeddings
  • doc_type: Type of document

Returns: Dictionary with the formatted blocks and a count of blocks added
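
Stripped of the Neo4j fetching, caching and reranking, the vector-search core of this method is the plain FAISS flat-L2 pattern shown below; the random vectors stand in for the OpenAI chunk and query embeddings.

# Minimal FAISS pattern used inside _process_node_type: exact L2 index, then k-NN search.
import numpy as np
import faiss

dimension = 1536                                                      # text-embedding-3-small dimension
chunk_embeddings = np.random.rand(200, dimension).astype(np.float32)  # stand-ins for chunk embeddings
query_embedding = np.random.rand(1, dimension).astype(np.float32)     # stand-in for the query embedding

index = faiss.IndexFlatL2(dimension)    # exact L2 index
index.add(chunk_embeddings)             # built once, then cached on the instance in the real method
D, I = index.search(query_embedding, 10)
print(I[0])                             # positions of the 10 nearest chunks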

collect_text_blocks(self, data, query)

Purpose: Collects candidate text blocks from a vector-store retriever for each (optionally extended) query, filters near-duplicate chunks, optionally reranks them with a cross-encoder, and numbers the surviving blocks

Parameters:

  • data: Parameter
  • query: Parameter

Returns: A newline-joined string of the collected blocks

generate_prompt(self, template, data_sections, query)

Purpose: Assembles the full LLM prompt from the selected template, per-section instructions, the collected data sections, optional web-search results and chat memory, and the user query

Parameters:

  • template: Parameter
  • data_sections: Parameter
  • query: Parameter

Returns: The assembled prompt string

extend_query(self, query)

Purpose: Uses gpt-4o-mini to generate five expanded variants of the user query and stores them in self.extended_query

Parameters:

  • query: Parameter

Returns: None

normalize(self, vector)

Purpose: Scales a vector to unit length

Parameters:

  • vector: Parameter

Returns: The normalized vector

mmr(self, query_embedding, candidate_embeddings, candidate_docs, lambda_param, top_k)

Purpose: Selects up to top_k documents with Maximal Marginal Relevance, trading off query relevance against similarity to already selected documents

Parameters:

  • query_embedding: Parameter
  • candidate_embeddings: Parameter
  • candidate_docs: Parameter
  • lambda_param: Parameter
  • top_k: Parameter

Returns: List of selected documents

similarity_threshold_filter(self, query_embedding, candidate_embeddings, candidate_docs, similarity_threshold, top_k)

Purpose: Greedily selects up to top_k documents by query similarity, skipping candidates that are too similar to already selected ones

Parameters:

  • query_embedding: Parameter
  • candidate_embeddings: Parameter
  • candidate_docs: Parameter
  • similarity_threshold: Parameter
  • top_k: Parameter

Returns: List of selected documents

Required Imports

from typing import List
from typing import Any
from typing import Dict
import os
import numpy as np
import faiss
import panel as pn
from openai import OpenAI
from sentence_transformers import CrossEncoder
Usage Example

# Example usage:
# rag = OneCo_hybrid_RAG()
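
A slightly fuller, hedged sketch: the flow_control keys below are the ones the prompt builder checks; how the Panel chat interface is served (the commented line) is an assumption about the intended deployment, not something shown in this extract.

# Hedged usage sketch
rag = OneCo_hybrid_RAG()
rag.flow_control["enable_search"] = True     # include the [web search results] section
rag.flow_control["enable_memory"] = True     # include the [previous chats] section

# The Panel chat interface wired to response_callback could then be served, e.g.:
# rag.chat_interface.servable()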

Similar Components

AI-powered semantic similarity - components with related functionality:

  • class OneCo_hybrid_RAG_v2 99.0% similar

    A class named OneCo_hybrid_RAG

    From: /tf/active/vicechatdev/OneCo_hybrid_RAG.py
  • class OneCo_hybrid_RAG_v3 98.9% similar

    A class named OneCo_hybrid_RAG

    From: /tf/active/vicechatdev/vice_ai/hybrid_rag_engine.py
  • class OneCo_hybrid_RAG_v5 98.8% similar

    A class named OneCo_hybrid_RAG

    From: /tf/active/vicechatdev/data_capture_backup_18072025/OneCo_hybrid_RAG.py
  • class OneCo_hybrid_RAG_v1 98.6% similar

    A class named OneCo_hybrid_RAG

    From: /tf/active/vicechatdev/OneCo_hybrid_RAG_old.py
  • class OneCo_hybrid_RAG 98.2% similar

    A class named OneCo_hybrid_RAG

    From: /tf/active/vicechatdev/OneCo_hybrid_RAG copy.py