class DocChatEmbeddingFunction
A custom ChromaDB embedding function that generates OpenAI embeddings with automatic text summarization for documents exceeding token limits.
Source file: /tf/active/vicechatdev/docchat/document_indexer.py
Source lines: 128-227
Complexity: moderate
Purpose
This class implements ChromaDB's EmbeddingFunction interface to provide OpenAI-based embeddings with intelligent handling of long documents. It automatically summarizes texts that exceed token limits, sanitizes text encoding issues, and uses lazy initialization to avoid serialization problems with ChromaDB. The class is designed to be used as a custom embedding function when creating or querying ChromaDB collections, ensuring all documents can be embedded regardless of length.
Source Code
```python
class DocChatEmbeddingFunction(EmbeddingFunction):
    """Custom embedding function for ChromaDB with automatic text summarization"""

    def __init__(self, api_key: str, embed_model_name: str = "text-embedding-3-small",
                 llm_model_name: str = "gpt-4o-mini"):
        """
        Initialize embedding function

        Args:
            api_key: OpenAI API key
            embed_model_name: Embedding model to use
            llm_model_name: LLM model for summarization
        """
        self.api_key = api_key
        self.embed_model_name = embed_model_name
        self.llm_model_name = llm_model_name
        # Store tokenizer encoding name instead of the tokenizer object
        self.tokenizer_encoding = "cl100k_base"
        # Don't initialize OpenAI client here - do it lazily to avoid serialization issues
        self._openai_client = None

    def _get_openai_client(self):
        """Lazy-load OpenAI client to avoid serialization issues with ChromaDB"""
        if self._openai_client is None:
            from openai import OpenAI
            self._openai_client = OpenAI(api_key=self.api_key)
        return self._openai_client

    def _get_tokenizer(self):
        """Lazy-load tokenizer to avoid serialization issues with ChromaDB"""
        return tiktoken.get_encoding(self.tokenizer_encoding)

    def _get_llm(self):
        """Lazy-load LLM to avoid serialization issues with ChromaDB"""
        return get_llm_instance(model_name=self.llm_model_name, temperature=0)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        tokenizer = self._get_tokenizer()
        return len(tokenizer.encode(text))

    def sanitize_text(self, text: str) -> str:
        """Sanitize text to handle encoding issues"""
        return text.encode("utf-8", errors="replace").decode("utf-8")

    def summarize_text(self, text: str, max_tokens: int = 8192) -> str:
        """
        Summarize text if it exceeds token limit

        Args:
            text: Text to summarize
            max_tokens: Maximum tokens for summary

        Returns:
            Summarized text
        """
        text = self.sanitize_text(text)
        prompt = f"Please summarize the following text such that the summary is under {max_tokens} tokens:\n\n{text}"
        llm = self._get_llm()
        response = llm.invoke(prompt)
        return response.content.strip()

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for documents

        Args:
            input: List of document texts

        Returns:
            List of embeddings as numpy arrays
        """
        embeddings = []
        for content in input:
            # Truncate very long content
            if len(content) > 1000000:
                content = content[:1000000]
                logger.warning("Content truncated due to length")
            # Ensure content is within token limits
            while self.count_tokens(content) > 110000:
                content = content[:-1000]
            # Summarize if needed for embedding
            if self.count_tokens(content) > 8192:
                logger.debug("Summarizing text for embedding due to token limit")
                content = self.summarize_text(content, 8000)
            # Generate embedding using OpenAI client
            client = self._get_openai_client()
            response = client.embeddings.create(
                model=self.embed_model_name,
                input=content
            )
            embedding = response.data[0].embedding
            # Convert to numpy array for ChromaDB 0.6.x
            embeddings.append(np.array(embedding, dtype=np.float32))
        return embeddings
```
Parameters
| Name | Type | Default | Kind |
|---|---|---|---|
| bases | EmbeddingFunction | - | - |
Parameter Details
api_key: OpenAI API key required for authentication with OpenAI's embedding and LLM services. Must be a valid API key string.
embed_model_name: Name of the OpenAI embedding model to use. Defaults to 'text-embedding-3-small'. Can be any valid OpenAI embedding model like 'text-embedding-3-large' or 'text-embedding-ada-002'.
llm_model_name: Name of the OpenAI LLM model used for text summarization when documents exceed token limits. Defaults to 'gpt-4o-mini'. Should be a model capable of summarization tasks.
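As a minimal sketch of typical instantiation (assuming the key is exported under the conventional OPENAI_API_KEY environment variable, which is an assumption, not something this class requires):

```python
import os

# Hypothetical setup: read the key from the environment rather than hard-coding it
embedding_fn = DocChatEmbeddingFunction(
    api_key=os.environ["OPENAI_API_KEY"],
    embed_model_name="text-embedding-3-small",  # default; e.g. 'text-embedding-3-large' also works
    llm_model_name="gpt-4o-mini",               # default summarization model
)
```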
Return Value
Instantiation returns a DocChatEmbeddingFunction object that can be passed to ChromaDB collection creation. The __call__ method returns a list of numpy arrays (Embeddings type), where each array represents the embedding vector for a corresponding input document. Each embedding is a float32 numpy array compatible with ChromaDB 0.6.x.
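A quick sanity check of the return type, assuming an `embedding_fn` configured as above (the 1536-dimensional shape is specific to 'text-embedding-3-small'):

```python
import numpy as np

vectors = embedding_fn(["hello world"])
assert isinstance(vectors[0], np.ndarray)   # each embedding is a numpy array
assert vectors[0].dtype == np.float32       # float32 for ChromaDB 0.6.x compatibility
print(vectors[0].shape)                     # (1536,) for text-embedding-3-small
```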
Class Interface
Methods
__init__(self, api_key: str, embed_model_name: str = 'text-embedding-3-small', llm_model_name: str = 'gpt-4o-mini')
Purpose: Initialize the embedding function with API credentials and model configurations
Parameters:
- api_key: OpenAI API key for authentication
- embed_model_name: Name of OpenAI embedding model (default: 'text-embedding-3-small')
- llm_model_name: Name of LLM model for summarization (default: 'gpt-4o-mini')
Returns: None - constructor initializes instance
_get_openai_client(self)
Purpose: Lazy-load OpenAI client to avoid serialization issues with ChromaDB
Returns: OpenAI client instance for making API calls
_get_tokenizer(self)
Purpose: Lazy-load tiktoken tokenizer to avoid serialization issues with ChromaDB
Returns: tiktoken Encoding object for cl100k_base encoding
_get_llm(self)
Purpose: Lazy-load LLM instance for text summarization to avoid serialization issues
Returns: LLM instance from llm_factory with temperature=0 for deterministic summarization
count_tokens(self, text: str) -> int
Purpose: Count the number of tokens in a text string using tiktoken
Parameters:
text: Input text string to count tokens for
Returns: Integer count of tokens in the text
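For reference, the same count can be reproduced with tiktoken directly, since the method simply delegates to the cl100k_base encoding:

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
print(len(enc.encode("ChromaDB embeds documents")))  # token count for this string
```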
sanitize_text(self, text: str) -> str
Purpose: Sanitize text to handle encoding issues by replacing problematic characters
Parameters:
text: Input text that may contain encoding issues
Returns: Sanitized text string with encoding errors replaced
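A short illustration of what the sanitization does: a lone surrogate cannot be encoded as UTF-8, so errors='replace' substitutes a '?' instead of raising:

```python
bad = "ok \ud800 text"   # lone surrogate: invalid in UTF-8
clean = bad.encode("utf-8", errors="replace").decode("utf-8")
print(clean)              # "ok ? text"
```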
summarize_text(self, text: str, max_tokens: int = 8192) -> str
Purpose: Summarize text using LLM if it exceeds token limit for embedding
Parameters:
- text: Text to summarize
- max_tokens: Maximum tokens for the summary (default: 8192)
Returns: Summarized text string that fits within token limit
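This mirrors the guard used inside __call__; a hedged sketch, with `long_text` standing in for any oversized document:

```python
# Only summarize when the text exceeds the embedding model's token limit
if embedding_fn.count_tokens(long_text) > 8192:
    long_text = embedding_fn.summarize_text(long_text, max_tokens=8000)
```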
__call__(self, input: Documents) -> Embeddings
Purpose: Generate embeddings for a list of documents, automatically handling long texts through summarization
Parameters:
input: List of document text strings (Documents type from ChromaDB)
Returns: List of numpy float32 arrays representing embeddings for each document (Embeddings type)
Attributes
| Name | Type | Description | Scope |
|---|---|---|---|
| api_key | str | OpenAI API key for authentication | instance |
| embed_model_name | str | Name of the OpenAI embedding model to use | instance |
| llm_model_name | str | Name of the LLM model for text summarization | instance |
| tokenizer_encoding | str | Name of the tiktoken encoding ('cl100k_base'), stored as a string to avoid serialization issues | instance |
| _openai_client | Optional[OpenAI] | Lazily initialized OpenAI client instance; None until first use | instance |
Dependencies
tiktoken, openai, numpy, chromadb, llm_factory
Required Imports
```python
import tiktoken
import numpy as np
from chromadb import Documents, EmbeddingFunction, Embeddings
from llm_factory import get_llm_instance
```
Conditional/Optional Imports
These imports are only needed under specific conditions:
```python
from openai import OpenAI
```
Condition: Lazily imported when _get_openai_client() is first called to avoid serialization issues with ChromaDB
Required (conditional)
Usage Example
```python
import chromadb

# Instantiate the embedding function
embedding_fn = DocChatEmbeddingFunction(
    api_key='your-openai-api-key',
    embed_model_name='text-embedding-3-small',
    llm_model_name='gpt-4o-mini'
)

# Use with a ChromaDB collection
client = chromadb.Client()
collection = client.create_collection(
    name='my_documents',
    embedding_function=embedding_fn
)

# Add documents (embedding happens automatically)
collection.add(
    documents=['This is a document', 'Another document'],
    ids=['doc1', 'doc2']
)

# Or generate embeddings directly
documents = ['Short text', 'A very long document that might need summarization...']
embeddings = embedding_fn(documents)
print(f'Generated {len(embeddings)} embeddings')
print(f'First embedding shape: {embeddings[0].shape}')
```
Best Practices
- Always provide a valid OpenAI API key during instantiation to avoid runtime errors
- The class uses lazy initialization for OpenAI client, tokenizer, and LLM to avoid ChromaDB serialization issues - do not try to access these directly
- Documents are automatically truncated at 1,000,000 characters and summarized if they exceed 8,192 tokens for embedding
- The class handles encoding issues automatically through sanitize_text method
- Token counting uses tiktoken's cl100k_base encoding which matches OpenAI's models
- Embeddings are returned as numpy float32 arrays for ChromaDB 0.6.x compatibility
- The __call__ method is invoked automatically by ChromaDB when adding documents to a collection
- For very large documents, expect automatic summarization which may lose some detail
- Lazy initialization avoids serialization issues, but the unsynchronized `is None` check is not strictly thread-safe: concurrent first calls may each construct a client (redundant but harmless); add a lock if single initialization matters (see the sketch after this list)
- Monitor logs for warnings about content truncation or summarization
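If strict thread safety is required, a lock-guarded variant of the lazy initialization is straightforward. This is a sketch under the assumptions above, not part of the class as shipped:

```python
import threading


class ThreadSafeLazyClient:
    """Sketch: double-checked locking around lazy OpenAI client creation."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self._openai_client = None
        self._client_lock = threading.Lock()

    def _get_openai_client(self):
        # Fast path: already initialized, no lock needed
        if self._openai_client is None:
            with self._client_lock:
                # Re-check inside the lock: another thread may have won the race
                if self._openai_client is None:
                    from openai import OpenAI
                    self._openai_client = OpenAI(api_key=self.api_key)
        return self._openai_client
```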
Similar Components
AI-powered semantic similarity - components with related functionality:
- class MyEmbeddingFunction_v1 (93.1% similar)
- class MyEmbeddingFunction_v2 (88.3% similar)
- class MyEmbeddingFunction_v3 (86.6% similar)
- class MyEmbeddingFunction (82.8% similar)
- class DocumentIndexer (59.7% similar)