class MyEmbeddingFunction_v2
A custom embedding function class that generates embeddings for text documents using OpenAI's embedding models, with automatic text summarization and token management for large documents.
/tf/active/vicechatdev/offline_docstore_multi_vice.py
135 - 195
moderate
Purpose
This class extends the EmbeddingFunction interface to provide a robust embedding generation system that handles large documents by automatically summarizing content that exceeds token limits. It integrates with OpenAI's ChatGPT models for summarization and OpenAI's embedding models for vector generation. The class is designed to work with ChromaDB and handles token counting, text sanitization, and content truncation to ensure documents fit within model constraints.
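The length controls described above can be separated from the network calls. The following is a minimal sketch and is not part of the original module. `prepare_for_embedding` is a hypothetical helper that applies the same thresholds as `__call__` (1,000,000 characters, then 110,000 tokens, then 8,192 tokens). It takes the token counter and the summarizer as injected callables, so the flow can be unit-tested with stubs instead of OpenAI calls.

```python
from typing import Callable

# Thresholds copied from MyEmbeddingFunction.__call__
MAX_CHARS = 1_000_000
MAX_CONTEXT_TOKENS = 110_000
MAX_EMBED_TOKENS = 8_192
TRIM_STEP_CHARS = 1_000


def prepare_for_embedding(
    text: str,
    count_tokens: Callable[[str], int],
    summarize: Callable[[str], str],
) -> str:
    """Apply the same length controls as __call__, without any network calls.

    count_tokens and summarize are hypothetical injected callables; in the
    class they would be self.count_tokens and self.summarize_text.
    """
    # 1. Hard character cap
    if len(text) > MAX_CHARS:
        text = text[:MAX_CHARS]
    # 2. Trim from the end until the text fits the summarization budget
    while count_tokens(text) > MAX_CONTEXT_TOKENS:
        text = text[:-TRIM_STEP_CHARS]
    # 3. Summarize anything still too long for the embedding model
    if count_tokens(text) > MAX_EMBED_TOKENS:
        text = summarize(text)
    return text
```

Short texts pass through untouched. Only texts that remain above 8,192 tokens after trimming reach the summarizer.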
Source Code
```python
import logging

import openai
import tiktoken
from chromadb import Documents, EmbeddingFunction, Embeddings
from langchain_openai import ChatOpenAI

logger = logging.getLogger(__name__)


class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model_name: str, embed_model_name: str, api_key: str):
        self.model_name = model_name
        self.api_key = api_key
        # Chat model used only to summarize documents that are too long to embed
        self.llm = ChatOpenAI(model_name=model_name, temperature=0, api_key=api_key)
        self.embed_model_name = embed_model_name

    def summarize_text(self, text, max_tokens_summary=8192):
        """
        Summarize the input text with the configured chat model (e.g. gpt-4o-mini).
        The prompt asks for a summary under max_tokens_summary tokens; the model
        is not forced to respect that limit.
        """
        text = self.sanitize_text(text)
        prompt = (
            "Please summarize the following text such that the summary is under "
            f"{max_tokens_summary} tokens:\n\n{text}"
        )
        response = self.llm.invoke(prompt)
        return response.content.strip()

    def sanitize_text(self, text):
        """
        Round-trip the text through UTF-8 with errors="replace".
        Characters that cannot be encoded as UTF-8 (lone surrogates) become '?'.
        """
        return text.encode("utf-8", errors="replace").decode("utf-8")

    def count_tokens(self, text):
        # cl100k_base is the tokenizer used by OpenAI's embedding models;
        # disallowed_special=() counts strings such as "<|endoftext|>" as plain
        # text instead of raising ValueError
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text, disallowed_special=()))

    def __call__(self, input: Documents) -> Embeddings:
        # ChromaDB passes a list of strings and expects one embedding per string
        embeddings = []
        for content in input:
            # Cheap character cap before any token counting
            if len(content) > 1000000:
                content = content[:1000000]
                logger.warning("Shrinking content due to token limit")
            # Trim 1,000 characters at a time until the text is at most
            # 110,000 tokens, so the summarization prompt fits the chat model
            while self.count_tokens(content) > 110000:
                content = content[:-1000]
            # Summarize anything still too long for the embedding model,
            # using the default summary budget (max_tokens_summary=8192)
            if self.count_tokens(content) > 8192:
                logger.warning("Summarizing text due to token limit")
                content = self.summarize_text(content)
            # Uses the module-level OpenAI client: openai.api_key (or the
            # OPENAI_API_KEY environment variable) must be set
            response = openai.embeddings.create(
                model=self.embed_model_name,
                input=content,
            )
            embeddings.append(response.data[0].embedding)
        return embeddings
```
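The `while` loop in `__call__` re-tokenizes the whole document after every 1,000-character cut, which can mean hundreds of full tokenizations for a 1,000,000-character input. A swapped-in alternative, sketched here under the assumption that you control the tokenizer, is to encode once and cut at the token level. `truncate_to_token_limit` is a hypothetical helper, not part of the original code. `encode` and `decode` are injected so the sketch does not itself depend on tiktoken.

```python
from typing import Callable, Sequence


def truncate_to_token_limit(
    text: str,
    max_tokens: int,
    encode: Callable[[str], Sequence[int]],
    decode: Callable[[Sequence[int]], str],
) -> str:
    """Cut text to at most max_tokens tokens with one encode/decode pass."""
    tokens = encode(text)
    if len(tokens) <= max_tokens:
        return text  # already short enough; return the original unchanged
    # Keep the first max_tokens tokens and convert them back to text
    return decode(tokens[:max_tokens])
```

With tiktoken, this could be called as `truncate_to_token_limit(content, 110_000, lambda s: enc.encode(s, disallowed_special=()), enc.decode)`, where `enc = tiktoken.get_encoding("cl100k_base")`. A cut can land inside a multi-byte character (tiktoken's `decode` replaces such bytes by default), and re-encoding the result is not guaranteed to reproduce exactly the same tokens, so it is safer to leave a small margin below hard limits. The same helper could also cap a summary below the 8,191-token input limit OpenAI documents for its embedding models, since `summarize_text` does not enforce its limit.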
Parameters
| Name | Type | Default | Kind |
|---|---|---|---|
| bases | EmbeddingFunction | - | |
Parameter Details
model_name: The name of the OpenAI chat model to use for text summarization (e.g., 'gpt-4o-mini', 'gpt-4'). This model is used when documents exceed the embedding model's token limit.
embed_model_name: The name of the OpenAI embedding model to use for generating vector embeddings (e.g., 'text-embedding-ada-002', 'text-embedding-3-small').
api_key: The OpenAI API key required for authentication with OpenAI services. Must be a valid API key with access to both chat and embedding endpoints.
Return Value
The class instantiation returns a MyEmbeddingFunction object. The __call__ method returns a list of embeddings (Embeddings type), where each embedding is a list of floating-point numbers representing the vector embedding of the corresponding input document. The summarize_text method returns a string containing the summarized text. The count_tokens method returns an integer representing the token count. The sanitize_text method returns a sanitized string.
Class Interface
Methods
__init__(self, model_name: str, embed_model_name: str, api_key: str)
Purpose: Initializes the embedding function with OpenAI model configurations and API credentials
Parameters:
model_name: Name of the ChatGPT model for summarization
embed_model_name: Name of the OpenAI embedding model
api_key: OpenAI API key for authentication
Returns: None (constructor)
summarize_text(self, text, max_tokens_summary=8192)
Purpose: Summarizes input text with the configured chat model to bring its token count under the specified limit
Parameters:
text: The text content to be summarized
max_tokens_summary: Maximum number of tokens requested for the summary (default: 8192)
Returns: A string containing the summary; the limit is stated in the prompt but not enforced, so the summary can still exceed it
sanitize_text(self, text)
Purpose: Round-trips text through UTF-8 with errors="replace", so characters that cannot be encoded (in practice, lone surrogates) are replaced with '?'
Parameters:
text: The text string to sanitize
Returns: The sanitized string
count_tokens(self, text)
Purpose: Counts the number of tokens in the given text using the cl100k_base encoding
Parameters:
text: The text string to count tokens for
Returns: An integer representing the number of tokens in the text
__call__(self, input: Documents) -> Embeddings
Purpose: Generates embeddings for a list of documents, automatically handling token limits through truncation and summarization
Parameters:
input: A list of document strings (Documents type from ChromaDB) to generate embeddings for
Returns: A list of embeddings (Embeddings type), where each embedding is a list of floats representing the vector embedding
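The UTF-8 round trip in `sanitize_text` leaves ordinary Unicode text unchanged. The only characters a Python string can hold that cannot be encoded as UTF-8 are lone surrogates, which can appear, for example, in JSON with unpaired `\ud800` escapes or in strings decoded with `surrogateescape`. The following minimal standalone check copies the method body into a plain function for illustration.

```python
def sanitize_text(text: str) -> str:
    # Same round trip as MyEmbeddingFunction.sanitize_text
    return text.encode("utf-8", errors="replace").decode("utf-8")


# Accented and CJK characters pass through unchanged ("café 東京")
clean = sanitize_text("caf\u00e9 \u6771\u4eac")

# A lone surrogate cannot be encoded as UTF-8 and is replaced with "?"
broken = sanitize_text("bad\ud800char")
```

In the class, `sanitize_text` is only applied inside `summarize_text`. Text that goes straight to the embeddings call is not sanitized.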
Attributes
| Name | Type | Description | Scope |
|---|---|---|---|
| model_name | str | The name of the ChatGPT model used for text summarization | instance |
| api_key | str | The OpenAI API key for authentication | instance |
| llm | ChatOpenAI | The LangChain ChatOpenAI instance configured with the specified model and API key, used for summarization | instance |
| embed_model_name | str | The name of the OpenAI embedding model used for generating vector embeddings | instance |
Dependencies
langchain_openai, tiktoken, openai, chromadb, logging
Required Imports
from langchain_openai import ChatOpenAI
import tiktoken
import openai
from chromadb import Documents
from chromadb import EmbeddingFunction
from chromadb import Embeddings
import logging
logger = logging.getLogger(__name__)
Usage Example
```python
import os

import chromadb
import openai

# Assumes MyEmbeddingFunction has been imported from the module that defines it
# Read the key from the environment instead of hardcoding it
api_key = os.environ['OPENAI_API_KEY']

# Initialize the embedding function
embedding_fn = MyEmbeddingFunction(
    model_name='gpt-4o-mini',
    embed_model_name='text-embedding-ada-002',
    api_key=api_key
)
# Configure the module-level OpenAI client that __call__ uses for embeddings
openai.api_key = api_key
# Generate embeddings for documents
documents = ['This is the first document.', 'This is the second document with more content.']
embeddings = embedding_fn(documents)
# Use with ChromaDB, which calls embedding_fn when documents are added
client = chromadb.Client()
collection = client.create_collection(
    name='my_collection',
    embedding_function=embedding_fn
)
collection.add(
    documents=documents,
    ids=['doc1', 'doc2']
)
```
Best Practices
- Ensure the OpenAI API key is kept secure and not hardcoded in production code
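- The class automatically handles large documents in three steps: it truncates content to 1,000,000 characters, trims 1,000 characters at a time until the text is at most 110,000 tokens, and then summarizes it if it still exceeds 8,192 tokens
- Token counting uses the 'cl100k_base' encoding, which matches OpenAI's embedding models (text-embedding-ada-002, text-embedding-3-*) as well as GPT-4 and GPT-3.5-turbo; GPT-4o-family chat models use 'o200k_base', so counts for the summarization prompt are approximate
- The __call__ method sends one embeddings request per document, which may be slow for large batches; the OpenAI embeddings endpoint also accepts a list of inputs, so consider batching requests in production
- __call__ logs warnings through a module-level 'logger'; define it in the same module (e.g. logger = logging.getLogger(__name__)), otherwise the first truncation or summarization warning raises a NameError
- The input list is not modified, and ChromaDB stores the original document text, but the embedding of a truncated or summarized document represents only the shortened text, so content beyond the limits does not influence similarity search
- Set openai.api_key globally (or the OPENAI_API_KEY environment variable) before calling the embedding function: the api_key constructor argument is only passed to ChatOpenAI, not to the embeddings call
- Be aware of rate limits when processing many documents: each document costs one embeddings request, plus one chat completion when summarization is triggered
- Summarization uses temperature=0 to make output as repeatable as possible, although OpenAI does not guarantee identical responses even at temperature 0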
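Following the batching suggestion above, one option is to send several documents per embeddings request. This is a minimal sketch, not the author's implementation. `embed_in_batches` is a hypothetical helper, and `client` is assumed to behave like `openai.OpenAI()`, exposing `client.embeddings.create(model=..., input=[...])`. Each text is assumed to be length-controlled already, and `batch_size=100` is an arbitrary illustrative value. The API also caps how many inputs and tokens a single request may carry.

```python
from typing import Any, List, Sequence


def embed_in_batches(
    client: Any, model: str, texts: Sequence[str], batch_size: int = 100
) -> List[List[float]]:
    """Embed texts with one embeddings request per batch instead of per text.

    client is assumed to expose client.embeddings.create(model=..., input=[...])
    like openai.OpenAI(); batch_size is an illustrative default.
    """
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        response = client.embeddings.create(model=model, input=batch)
        # Each result carries the index of its input; sort so the output
        # order always matches the input order
        ordered = sorted(response.data, key=lambda item: item.index)
        embeddings.extend(item.embedding for item in ordered)
    return embeddings
```

A call such as `embed_in_batches(openai.OpenAI(api_key=api_key), "text-embedding-3-small", documents)` would replace the per-document loop. To handle rate limits, `openai.OpenAI` also accepts a `max_retries` argument that retries failed requests automatically.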
Similar Components
AI-powered semantic similarity - components with related functionality:
- class MyEmbeddingFunction_v3 (97.2% similar)
- class MyEmbeddingFunction_v1 (93.9% similar)
- class DocChatEmbeddingFunction (88.3% similar)
- class MyEmbeddingFunction (80.5% similar)
- class DocumentIndexer (52.4% similar)