
Text Embeddings with TEI

This guide demonstrates how to build powerful text embedding services using Text Embeddings Inference (TEI), enabling semantic search, similarity analysis, and retrieval-augmented generation (RAG) applications.

Overview

Text Embeddings Inference (TEI) is a high-performance embedding server that provides:

  • Fast Inference: Optimized for batch processing and low latency
  • Multiple Models: Support for various embedding architectures
  • Similarity Search: Built-in similarity and ranking capabilities
  • Pooling Strategies: Multiple pooling methods for optimal embeddings
  • Batch Processing: Efficient handling of multiple texts
  • Production Ready: Auto-scaling and error handling
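
For orientation, here is a minimal sketch of calling a running TEI server directly over HTTP before wrapping it in a service. This assumes a TEI instance is already listening on localhost:8080 (the port used later in this guide); the rest of the guide builds a production chute around this same request pattern.

import requests

# Minimal sketch: call a TEI server's /embed endpoint directly.
# Assumes a TEI instance is already running on localhost:8080.
TEI_URL = "http://localhost:8080"

response = requests.post(
    f"{TEI_URL}/embed",
    json={"inputs": ["What is semantic search?", "TEI serves embeddings over HTTP."]},
    timeout=30,
)
response.raise_for_status()
embeddings = response.json()  # list of float vectors, one per input text
print(len(embeddings), len(embeddings[0]))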

Complete Implementation

Input Schema Design

Define comprehensive input validation for embedding operations:

from pydantic import BaseModel, Field
from typing import List, Optional, Union
from enum import Enum

class PoolingStrategy(str, Enum):
    CLS = "cls"                    # Use [CLS] token
    MEAN = "mean"                  # Mean pooling
    MAX = "max"                    # Max pooling
    MEAN_SQRT_LEN = "mean_sqrt_len" # Mean pooling with sqrt normalization

class EmbeddingInput(BaseModel):
    inputs: Union[str, List[str]]  # Single text or batch
    normalize: bool = Field(default=True)
    truncate: bool = Field(default=True)
    pooling: Optional[PoolingStrategy] = PoolingStrategy.MEAN

class SimilarityInput(BaseModel):
    source_text: str
    target_texts: List[str] = Field(max_items=100)
    normalize: bool = Field(default=True)

class RerankInput(BaseModel):
    query: str
    texts: List[str] = Field(max_items=50)
    top_k: Optional[int] = Field(default=None, ge=1, le=50)

class SearchInput(BaseModel):
    query: str
    corpus: List[str] = Field(max_items=1000)
    top_k: int = Field(default=10, ge=1, le=100)
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
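
These schemas give you request validation for free. As a quick illustration (the values below are hypothetical), Pydantic rejects out-of-range parameters before they ever reach the embedding code:

# Quick illustration of the validation these schemas provide (hypothetical values).
from pydantic import ValidationError

ok = SearchInput(query="vector databases", corpus=["doc one", "doc two"], top_k=5)
print(ok.top_k, ok.threshold)  # 5 None

try:
    SearchInput(query="vector databases", corpus=["doc one"], top_k=500)  # top_k must be <= 100
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('top_k',)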

Custom Image with TEI

Build a custom image with Text Embeddings Inference:

from chutes.image import Image
from chutes.chute import Chute, NodeSelector

image = (
    Image(
        username="myuser",
        name="text-embeddings",
        tag="0.0.1",
        readme="High-performance text embeddings with TEI")
    .from_base("parachutes/base-python:3.11")
    .run_command("pip install --upgrade pip")
    .run_command("pip install text-embeddings-inference-client")
    .run_command("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
    .run_command("pip install transformers sentence-transformers")
    .run_command("pip install numpy scikit-learn faiss-cpu")
    .run_command("pip install loguru pydantic fastapi")
    # Install TEI server
    .run_command(
        "wget https://github.com/huggingface/text-embeddings-inference/releases/download/v1.2.3/text-embeddings-inference-1.2.3-x86_64-unknown-linux-gnu.tar.gz && "
        "tar -xzf text-embeddings-inference-1.2.3-x86_64-unknown-linux-gnu.tar.gz && "
        "chmod +x text-embeddings-inference && "
        "mv text-embeddings-inference /usr/local/bin/"
    )
)

Chute Configuration

Configure the service with appropriate GPU and memory requirements:

chute = Chute(
    username="myuser",
    name="text-embeddings-service",
    tagline="High-performance text embeddings and similarity search",
    readme="Production-ready text embedding service with similarity search, reranking, and semantic analysis capabilities",
    image=image,
    node_selector=NodeSelector(
        gpu_count=1,
        min_vram_gb_per_gpu=16,  # Sufficient for most embedding models
    ),
    concurrency=8,  # Handle multiple concurrent requests
)

Model Initialization

Initialize the embedding model and TEI server:

import subprocess
import time
import requests
from loguru import logger

@chute.on_startup()
async def initialize_embeddings(self):
    """
    Initialize TEI server and embedding capabilities.
    """
    import torch
    import numpy as np
    from sentence_transformers import SentenceTransformer

    # Model configuration
    self.model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Default model
    self.tei_port = 8080
    self.tei_url = f"http://localhost:{self.tei_port}"

    # Start TEI server in background
    logger.info("Starting TEI server...")
    self.tei_process = subprocess.Popen([
        "text-embeddings-inference",
        "--model-id", self.model_name,
        "--port", str(self.tei_port),
        "--max-concurrent-requests", "32",
        "--max-batch-tokens", "16384",
        "--max-batch-requests", "16"
    ])

    # Wait for server to start
    max_wait = 120
    for i in range(max_wait):
        try:
            response = requests.get(f"{self.tei_url}/health", timeout=5)
            if response.status_code == 200:
                logger.success("TEI server started successfully")
                break
        except requests.exceptions.RequestException:
            pass
        if i == max_wait - 1:
            raise Exception("TEI server failed to start")
        time.sleep(1)

    # Initialize fallback model for local processing
    logger.info("Loading fallback sentence transformer...")
    self.sentence_transformer = SentenceTransformer(self.model_name)

    # Store utilities
    self.torch = torch
    self.numpy = np

    # Initialize vector storage (in-memory for this example)
    self.vector_store = {}
    self.text_store = {}

    # Warmup
    await self._warmup_model()

async def _warmup_model(self):
    """Perform warmup embedding generation."""
    warmup_texts = [
        "This is a warmup sentence to initialize the embedding model.",
        "Another test sentence for model warming.",
        "Final warmup text to ensure optimal performance."
    ]

    try:
        # Warmup TEI server
        response = requests.post(
            f"{self.tei_url}/embed",
            json={"inputs": warmup_texts},
            timeout=30
        )
        if response.status_code == 200:
            logger.info("TEI server warmed up successfully")
        else:
            logger.warning("TEI warmup failed, using fallback model")
            # Warmup fallback model
            _ = self.sentence_transformer.encode(warmup_texts)

    except Exception as e:
        logger.warning(f"Warmup failed: {e}, using fallback model")
        _ = self.sentence_transformer.encode(warmup_texts)

Core Embedding Functions

Implement core embedding functionality:

import hashlib
import numpy as np
from typing import Dict, List, Tuple, Union

async def get_embeddings(self, texts: Union[str, List[str]], normalize: bool = True) -> np.ndarray:
    """
    Get embeddings for text(s) using TEI server or fallback.
    """
    if isinstance(texts, str):
        texts = [texts]

    try:
        # Try TEI server first
        response = requests.post(
            f"{self.tei_url}/embed",
            json={
                "inputs": texts,
                "normalize": normalize,
                "truncate": True
            },
            timeout=30
        )

        if response.status_code == 200:
            embeddings = self.numpy.array(response.json())
            return embeddings
        else:
            logger.warning(f"TEI server error: {response.status_code}, using fallback")

    except Exception as e:
        logger.warning(f"TEI server failed: {e}, using fallback")

    # Fallback to local model
    embeddings = self.sentence_transformer.encode(
        texts,
        normalize_embeddings=normalize,
        convert_to_numpy=True
    )
    return embeddings

def compute_similarity(self, embeddings1: np.ndarray, embeddings2: np.ndarray) -> np.ndarray:
    """Compute cosine similarity between embeddings."""

    # Ensure both inputs are 2D arrays for the matrix operations below
    if embeddings1.ndim == 1:
        embeddings1 = embeddings1.reshape(1, -1)
    if embeddings2.ndim == 1:
        embeddings2 = embeddings2.reshape(1, -1)

    # Compute cosine similarity
    dot_product = self.numpy.dot(embeddings1, embeddings2.T)
    norms1 = self.numpy.linalg.norm(embeddings1, axis=1, keepdims=True)
    norms2 = self.numpy.linalg.norm(embeddings2, axis=1, keepdims=True)

    similarities = dot_product / (norms1 * norms2.T)
    return similarities

def add_to_vector_store(self, texts: List[str], embeddings: np.ndarray, collection: str = "default"):
    """Add texts and embeddings to vector store."""

    if collection not in self.vector_store:
        self.vector_store[collection] = []
        self.text_store[collection] = []

    for text, embedding in zip(texts, embeddings):
        text_id = hashlib.md5(text.encode()).hexdigest()

        self.vector_store[collection].append({
            "id": text_id,
            "embedding": embedding,
            "text": text
        })
        self.text_store[collection].append(text)

Embedding Generation Endpoints

Create endpoints for different embedding operations:

from fastapi import HTTPException

@chute.cord(
    public_api_path="/embed",
    public_api_method="POST",
    stream=False)
async def generate_embeddings(self, args: EmbeddingInput) -> Dict:
    """
    Generate embeddings for input text(s).
    """
    try:
        embeddings = await get_embeddings(self, args.inputs, args.normalize)

        # Convert to list for JSON serialization
        embeddings_list = embeddings.tolist()

        if isinstance(args.inputs, str):
            return {
                "embeddings": embeddings_list[0],
                "model": self.model_name,
                "dimension": len(embeddings_list[0])
            }
        else:
            return {
                "embeddings": embeddings_list,
                "model": self.model_name,
                "dimension": len(embeddings_list[0]),
                "count": len(embeddings_list)
            }

    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}")

@chute.cord(
    public_api_path="/similarity",
    public_api_method="POST",
    stream=False)
async def compute_text_similarity(self, args: SimilarityInput) -> Dict:
    """
    Compute similarity between source text and target texts.
    """
    try:
        # Get embeddings for all texts
        all_texts = [args.source_text] + args.target_texts
        embeddings = await get_embeddings(self, all_texts, args.normalize)

        # Separate source and target embeddings
        source_embedding = embeddings[0:1]
        target_embeddings = embeddings[1:]

        # Compute similarities
        similarities = compute_similarity(self, source_embedding, target_embeddings)
        similarity_scores = similarities[0].tolist()

        # Create results with metadata
        results = []
        for i, (text, score) in enumerate(zip(args.target_texts, similarity_scores)):
            results.append({
                "text": text,
                "similarity": float(score),
                "rank": i + 1
            })

        # Sort by similarity (descending)
        results.sort(key=lambda x: x["similarity"], reverse=True)

        # Update ranks
        for i, result in enumerate(results):
            result["rank"] = i + 1

        return {
            "source_text": args.source_text,
            "results": results,
            "model": self.model_name
        }

    except Exception as e:
        logger.error(f"Similarity computation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Similarity computation failed: {str(e)}")

@chute.cord(
    public_api_path="/rerank",
    public_api_method="POST",
    stream=False)
async def rerank_texts(self, args: RerankInput) -> Dict:
    """
    Rerank texts based on relevance to query.
    """
    try:
        # Get embeddings
        query_embedding = await get_embeddings(self, args.query, normalize=True)
        text_embeddings = await get_embeddings(self, args.texts, normalize=True)

        # Compute similarities
        similarities = compute_similarity(self, query_embedding, text_embeddings)
        scores = similarities[0].tolist()

        # Create scored results
        scored_texts = [
            {
                "text": text,
                "score": float(score),
                "index": i
            }
            for i, (text, score) in enumerate(zip(args.texts, scores))
        ]

        # Sort by score (descending)
        scored_texts.sort(key=lambda x: x["score"], reverse=True)

        # Apply top_k limit if specified
        if args.top_k:
            scored_texts = scored_texts[:args.top_k]

        # Add ranks
        for rank, item in enumerate(scored_texts):
            item["rank"] = rank + 1

        return {
            "query": args.query,
            "results": scored_texts,
            "total_results": len(scored_texts),
            "model": self.model_name
        }

    except Exception as e:
        logger.error(f"Reranking failed: {e}")
        raise HTTPException(status_code=500, detail=f"Reranking failed: {str(e)}")

Semantic Search Implementation

Build a complete semantic search system:

@chute.cord(
    public_api_path="/search",
    public_api_method="POST",
    stream=False)
async def semantic_search(self, args: SearchInput) -> Dict:
    """
    Perform semantic search over a corpus of texts.
    """
    try:
        # Get query embedding
        query_embedding = await get_embeddings(self, args.query, normalize=True)

        # Get corpus embeddings (batch processing for efficiency)
        corpus_embeddings = await get_embeddings(self, args.corpus, normalize=True)

        # Compute similarities
        similarities = compute_similarity(self, query_embedding, corpus_embeddings)
        scores = similarities[0]

        # Create results with scores
        results = []
        for i, (text, score) in enumerate(zip(args.corpus, scores)):
            if args.threshold is None or score >= args.threshold:
                results.append({
                    "text": text,
                    "score": float(score),
                    "corpus_index": i
                })

        # Sort by score (descending) and take top_k
        results.sort(key=lambda x: x["score"], reverse=True)
        results = results[:args.top_k]

        # Add ranks
        for rank, result in enumerate(results):
            result["rank"] = rank + 1

        return {
            "query": args.query,
            "results": results,
            "total_corpus_size": len(args.corpus),
            "results_returned": len(results),
            "model": self.model_name,
            "threshold": args.threshold
        }

    except Exception as e:
        logger.error(f"Semantic search failed: {e}")
        raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")

Advanced Features

Vector Store Management

Implement persistent vector storage:

class VectorStoreInput(BaseModel):
    collection: str = "default"
    texts: List[str]
    metadata: Optional[Dict] = None

class SearchStoreInput(BaseModel):
    collection: str = "default"
    query: str
    top_k: int = Field(default=10, ge=1, le=100)
    filter_metadata: Optional[Dict] = None

@chute.cord(public_api_path="/store/add", method="POST")
async def add_to_store(self, args: VectorStoreInput) -> Dict:
    """Add texts to persistent vector store."""

    try:
        # Generate embeddings
        embeddings = await get_embeddings(self, args.texts, normalize=True)

        # Add to store
        add_to_vector_store(self, args.texts, embeddings, args.collection)

        return {
            "collection": args.collection,
            "added_count": len(args.texts),
            "total_in_collection": len(self.text_store.get(args.collection, []))
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to add to store: {str(e)}")

@chute.cord(public_api_path="/store/search", method="POST")
async def search_store(self, args: SearchStoreInput) -> Dict:
    """Search within a specific collection."""

    if args.collection not in self.vector_store:
        raise HTTPException(status_code=404, detail=f"Collection '{args.collection}' not found")

    try:
        # Get query embedding
        query_embedding = await get_embeddings(self, args.query, normalize=True)

        # Get stored embeddings
        stored_items = self.vector_store[args.collection]
        stored_embeddings = self.numpy.array([item["embedding"] for item in stored_items])

        # Compute similarities
        similarities = compute_similarity(self, query_embedding, stored_embeddings)
        scores = similarities[0]

        # Create results
        results = []
        for i, (item, score) in enumerate(zip(stored_items, scores)):
            results.append({
                "text": item["text"],
                "score": float(score),
                "id": item["id"]
            })

        # Sort and limit
        results.sort(key=lambda x: x["score"], reverse=True)
        results = results[:args.top_k]

        # Add ranks
        for rank, result in enumerate(results):
            result["rank"] = rank + 1

        return {
            "collection": args.collection,
            "query": args.query,
            "results": results,
            "total_in_collection": len(stored_items)
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Store search failed: {str(e)}")

@chute.cord(public_api_path="/store/collections", method="GET")
async def list_collections(self) -> Dict:
    """List all available collections."""

    collections = []
    for name, texts in self.text_store.items():
        collections.append({
            "name": name,
            "size": len(texts),
            "sample_texts": texts[:3] if texts else []
        })

    return {"collections": collections}

Batch Processing Optimization

Optimize for large-scale batch operations:

class BatchEmbeddingInput(BaseModel):
    texts: List[str] = Field(max_items=1000)
    batch_size: int = Field(default=32, ge=1, le=128)
    normalize: bool = True

@chute.cord(public_api_path="/embed/batch", method="POST")
async def batch_embeddings(self, args: BatchEmbeddingInput) -> Dict:
    """Process large batches of texts efficiently."""

    try:
        all_embeddings = []
        processed_count = 0

        # Process in batches
        for i in range(0, len(args.texts), args.batch_size):
            batch_texts = args.texts[i:i + args.batch_size]
            batch_embeddings = await get_embeddings(self, batch_texts, args.normalize)
            all_embeddings.extend(batch_embeddings.tolist())
            processed_count += len(batch_texts)

            # Optional: yield progress for very large batches
            if processed_count % 100 == 0:
                logger.info(f"Processed {processed_count}/{len(args.texts)} texts")

        return {
            "embeddings": all_embeddings,
            "processed_count": processed_count,
            "batch_size": args.batch_size,
            "model": self.model_name,
            "dimension": len(all_embeddings[0]) if all_embeddings else 0
        }

    except Exception as e:
        logger.error(f"Batch embedding failed: {e}")
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}")

Clustering and Analysis

Add text clustering capabilities:

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

class ClusterInput(BaseModel):
    texts: List[str] = Field(min_items=2, max_items=500)
    n_clusters: int = Field(default=5, ge=2, le=20)
    method: str = Field(default="kmeans")

@chute.cord(public_api_path="/cluster", method="POST")
async def cluster_texts(self, args: ClusterInput) -> Dict:
    """Cluster texts based on semantic similarity."""

    try:
        # Get embeddings
        embeddings = await get_embeddings(self, args.texts, normalize=True)

        # Perform clustering
        if args.method == "kmeans":
            # Adjust number of clusters if needed
            n_clusters = min(args.n_clusters, len(args.texts))
            kmeans = KMeans(n_clusters=n_clusters, random_state=42)
            cluster_labels = kmeans.fit_predict(embeddings)

            # Get cluster centers
            cluster_centers = kmeans.cluster_centers_

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported clustering method: {args.method}")

        # Organize results by cluster
        clusters = {}
        for i, (text, label) in enumerate(zip(args.texts, cluster_labels)):
            label = int(label)
            if label not in clusters:
                clusters[label] = []
            clusters[label].append({
                "text": text,
                "index": i
            })

        # Calculate cluster statistics
        cluster_stats = []
        for label, items in clusters.items():
            # Find centroid text (closest to cluster center)
            cluster_embeddings = embeddings[[item["index"] for item in items]]
            center = cluster_centers[label]

            # Compute distances to center
            distances = self.numpy.linalg.norm(cluster_embeddings - center, axis=1)
            centroid_idx = self.numpy.argmin(distances)

            cluster_stats.append({
                "cluster_id": label,
                "size": len(items),
                "centroid_text": items[centroid_idx]["text"],
                "texts": [item["text"] for item in items]
            })

        return {
            "clusters": cluster_stats,
            "n_clusters": len(clusters),
            "method": args.method,
            "total_texts": len(args.texts)
        }

    except Exception as e:
        logger.error(f"Clustering failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clustering failed: {str(e)}")

Deployment and Usage

Deploy the Service

# Build and deploy the embeddings service
chutes deploy my_embeddings:chute

# Monitor the deployment
chutes chutes get text-embeddings-service

Using the API

Basic Embedding Generation

curl -X POST "https://myuser-my-embeddings.chutes.ai/embed" \
  -H "Content-Type: application/json" \
  -d '{
    "inputs": "This is a sample text for embedding generation",
    "normalize": true
  }'
curl -X POST "https://myuser-my-embeddings.chutes.ai/similarity" \
  -H "Content-Type: application/json" \
  -d '{
    "source_text": "machine learning algorithms",
    "target_texts": [
      "artificial intelligence techniques",
      "cooking recipes",
      "neural network models",
      "gardening tips",
      "deep learning frameworks"
    ],
    "normalize": true
  }'

Python Client Example

import requests
from typing import Dict, List, Optional, Union

class EmbeddingsClient:
    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip('/')

    def embed(self, texts: Union[str, List[str]], normalize: bool = True) -> Dict:
        """Generate embeddings for text(s)."""
        response = requests.post(
            f"{self.base_url}/embed",
            json={
                "inputs": texts,
                "normalize": normalize
            }
        )

        if response.status_code == 200:
            return response.json()
        else:
            raise Exception(f"Embedding failed: {response.status_code} - {response.text}")

    def similarity(self, source_text: str, target_texts: List[str]) -> Dict:
        """Compute similarity between source and target texts."""
        response = requests.post(
            f"{self.base_url}/similarity",
            json={
                "source_text": source_text,
                "target_texts": target_texts,
                "normalize": True
            }
        )
        return response.json()

    def search(self, query: str, corpus: List[str], top_k: int = 10) -> Dict:
        """Perform semantic search over corpus."""
        response = requests.post(
            f"{self.base_url}/search",
            json={
                "query": query,
                "corpus": corpus,
                "top_k": top_k
            }
        )
        return response.json()

    def rerank(self, query: str, texts: List[str], top_k: Optional[int] = None) -> Dict:
        """Rerank texts by relevance to query."""
        payload = {
            "query": query,
            "texts": texts
        }
        if top_k:
            payload["top_k"] = top_k

        response = requests.post(
            f"{self.base_url}/rerank",
            json=payload
        )
        return response.json()

    def add_to_store(self, texts: List[str], collection: str = "default") -> Dict:
        """Add texts to vector store."""
        response = requests.post(
            f"{self.base_url}/store/add",
            json={
                "texts": texts,
                "collection": collection
            }
        )
        return response.json()

    def search_store(self, query: str, collection: str = "default", top_k: int = 10) -> Dict:
        """Search within stored collection."""
        response = requests.post(
            f"{self.base_url}/store/search",
            json={
                "query": query,
                "collection": collection,
                "top_k": top_k
            }
        )
        return response.json()

    def cluster(self, texts: List[str], n_clusters: int = 5) -> Dict:
        """Cluster texts by semantic similarity."""
        response = requests.post(
            f"{self.base_url}/cluster",
            json={
                "texts": texts,
                "n_clusters": n_clusters,
                "method": "kmeans"
            }
        )
        return response.json()

# Usage examples
client = EmbeddingsClient("https://myuser-text-embeddings-service.chutes.ai")

# Generate embeddings
result = client.embed("This is a test sentence")
embedding = result["embeddings"]
print(f"Embedding dimension: {result['dimension']}")

# Batch embeddings
batch_result = client.embed([
    "First document about machine learning",
    "Second document about cooking",
    "Third document about artificial intelligence"
])

# Find similar texts
similarity_result = client.similarity(
    source_text="artificial intelligence research",
    target_texts=[
        "machine learning algorithms",
        "cooking recipes",
        "neural networks",
        "gardening techniques"
    ]
)

print("Most similar texts:")
for result in similarity_result["results"][:3]:
    print(f"- {result['text']} (similarity: {result['similarity']:.3f})")

# Build a knowledge base
documents = [
    "Python is a programming language",
    "Machine learning uses algorithms to learn patterns",
    "Deep learning is a subset of machine learning",
    "Natural language processing analyzes text",
    "Computer vision processes images",
    "Reinforcement learning learns through trial and error"
]

# Add to vector store
client.add_to_store(documents, collection="ai_knowledge")

# Search the knowledge base
search_result = client.search_store(
    query="algorithms for learning",
    collection="ai_knowledge",
    top_k=3
)

print("Knowledge base search results:")
for result in search_result["results"]:
    print(f"- {result['text']} (score: {result['score']:.3f})")

# Cluster documents
cluster_result = client.cluster(documents, n_clusters=3)
print(f"Clustered into {cluster_result['n_clusters']} groups:")
for cluster in cluster_result["clusters"]:
    print(f"Cluster {cluster['cluster_id']} ({cluster['size']} items):")
    print(f"  Centroid: {cluster['centroid_text']}")

Best Practices

1. Model Selection

# Different models for different use cases
model_recommendations = {
    "general_purpose": "sentence-transformers/all-MiniLM-L6-v2",  # Fast, good quality
    "multilingual": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "high_quality": "sentence-transformers/all-mpnet-base-v2",  # Best quality
    "domain_specific": "sentence-transformers/allenai-specter",  # Scientific papers
    "code": "microsoft/codebert-base",  # Code similarity
}

def select_model_for_use_case(use_case: str) -> str:
    """Select optimal model based on use case."""
    return model_recommendations.get(use_case, model_recommendations["general_purpose"])
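
Whichever model you pick has to be wired into both the TEI launch command and the fallback SentenceTransformer. A minimal sketch, reusing the launch flags from the startup hook earlier in this guide:

# Sketch: launch TEI with the model chosen for the use case
# (same binary and flags as in the startup hook earlier in this guide).
import subprocess

model_id = select_model_for_use_case("high_quality")
tei_process = subprocess.Popen([
    "text-embeddings-inference",
    "--model-id", model_id,
    "--port", "8080",
    "--max-concurrent-requests", "32",
    "--max-batch-tokens", "16384",
])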

2. Text Preprocessing

import re
from typing import List

def preprocess_text(text: str) -> str:
    """Clean and prepare text for embedding."""
    # Remove excessive whitespace
    text = re.sub(r'\s+', ' ', text)

    # Remove special characters if needed
    text = re.sub(r'[^\w\s\-\.]', '', text)

    # Normalize case (optional, depends on model)
    # text = text.lower()

    # Remove very short texts
    if len(text.strip()) < 3:
        return ""

    return text.strip()

def batch_preprocess(texts: List[str]) -> List[str]:
    """Preprocess batch of texts."""
    processed = []
    for text in texts:
        cleaned = preprocess_text(text)
        if cleaned:  # Only add non-empty texts
            processed.append(cleaned)

    return processed
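
A quick before/after showing what the cleaning step actually does:

# Example of the cleaning step in action.
raw = "  Machine   learning, at scale!!   "
print(repr(preprocess_text(raw)))   # 'Machine learning at scale'
print(batch_preprocess([raw, "ab", "Deep learning 101."]))
# ['Machine learning at scale', 'Deep learning 101.']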

3. Caching and Performance

import hashlib
import numpy as np
from typing import Dict, Optional, Union

class EmbeddingCache:
    """Simple LRU cache for embeddings."""

    def __init__(self, max_size: int = 1000):
        self.cache: Dict[str, np.ndarray] = {}
        self.access_order = []
        self.max_size = max_size

    def get_key(self, text: str, model: str) -> str:
        """Generate cache key."""
        content = f"{text}_{model}"
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[np.ndarray]:
        """Get cached embedding."""
        key = self.get_key(text, model)
        if key in self.cache:
            # Update access order
            self.access_order.remove(key)
            self.access_order.append(key)
            return self.cache[key]
        return None

    def set(self, text: str, model: str, embedding: np.ndarray):
        """Cache embedding."""
        key = self.get_key(text, model)

        # Remove oldest if at capacity
        if len(self.cache) >= self.max_size and key not in self.cache:
            oldest_key = self.access_order.pop(0)
            del self.cache[oldest_key]

        self.cache[key] = embedding
        if key not in self.access_order:
            self.access_order.append(key)

# Add to chute initialization
@chute.on_startup()
async def initialize_with_cache(self):
    # ... existing initialization ...
    self.embedding_cache = EmbeddingCache(max_size=2000)

async def get_embeddings_cached(self, texts: Union[str, List[str]], normalize: bool = True) -> np.ndarray:
    """Get embeddings with caching."""
    if isinstance(texts, str):
        texts = [texts]

    cached_embeddings = []
    uncached_texts = []
    uncached_indices = []

    # Check cache
    for i, text in enumerate(texts):
        cached = self.embedding_cache.get(text, self.model_name)
        if cached is not None:
            cached_embeddings.append((i, cached))
        else:
            uncached_texts.append(text)
            uncached_indices.append(i)

    # Generate uncached embeddings
    if uncached_texts:
        new_embeddings = await get_embeddings(self, uncached_texts, normalize)

        # Cache new embeddings
        for text, embedding in zip(uncached_texts, new_embeddings):
            self.embedding_cache.set(text, self.model_name, embedding)

        # Combine cached and new embeddings
        all_embeddings = [None] * len(texts)

        # Place cached embeddings
        for orig_idx, embedding in cached_embeddings:
            all_embeddings[orig_idx] = embedding

        # Place new embeddings
        for new_idx, orig_idx in enumerate(uncached_indices):
            all_embeddings[orig_idx] = new_embeddings[new_idx]

        return self.numpy.array(all_embeddings)

    else:
        # All inputs were cached; restore the original input order
        return self.numpy.array(
            [emb for _, emb in sorted(cached_embeddings, key=lambda item: item[0])]
        )

4. Error Handling and Monitoring

import time
from loguru import logger

@chute.cord(public_api_path="/robust_embed", method="POST")
async def robust_embeddings(self, args: EmbeddingInput) -> Dict:
    """Embeddings with comprehensive error handling."""

    start_time = time.time()

    try:
        # Validate input
        if isinstance(args.inputs, list) and len(args.inputs) > 1000:
            raise HTTPException(
                status_code=400,
                detail="Batch size too large. Maximum 1000 texts allowed."
            )

        # Preprocess texts
        if isinstance(args.inputs, str):
            processed_texts = preprocess_text(args.inputs)
            if not processed_texts:
                raise HTTPException(status_code=400, detail="Text too short after preprocessing")
        else:
            processed_texts = batch_preprocess(args.inputs)
            if not processed_texts:
                raise HTTPException(status_code=400, detail="No valid texts after preprocessing")

        # Generate embeddings with retry logic
        max_retries = 3
        for attempt in range(max_retries):
            try:
                embeddings = await get_embeddings_cached(self, processed_texts, args.normalize)
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise e
                logger.warning(f"Embedding attempt {attempt + 1} failed: {e}")
                time.sleep(1)

        generation_time = time.time() - start_time
        logger.info(f"Embedding generation completed in {generation_time:.2f}s")

        # Return results
        embeddings_list = embeddings.tolist()
        return {
            "embeddings": embeddings_list if isinstance(args.inputs, list) else embeddings_list[0],
            "model": self.model_name,
            "dimension": len(embeddings_list[0]),
            "generation_time": generation_time,
            "processed_count": len(processed_texts)
        }

    except HTTPException:
        raise
    except Exception as e:
        error_time = time.time() - start_time
        logger.error(f"Embedding generation failed after {error_time:.2f}s: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}"
        )

Performance Optimization

Batch Size Tuning

def get_optimal_batch_size(text_lengths: List[int], max_tokens: int = 16384) -> int:
    """Calculate optimal batch size based on text lengths."""

    # Estimate tokens (rough approximation: 1 token ≈ 4 characters)
    estimated_tokens = [length // 4 for length in text_lengths]

    # Calculate how many texts can fit in max_tokens
    cumulative_tokens = 0
    optimal_batch = 0

    for tokens in estimated_tokens:
        if cumulative_tokens + tokens <= max_tokens:
            cumulative_tokens += tokens
            optimal_batch += 1
        else:
            break

    return max(1, optimal_batch)
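
For example, with the default 16,384-token budget the heuristic packs texts greedily from the front of the queue and stops at the first text that would overflow:

# Example: greedy packing under the default 16,384-token budget.
lengths = [20000, 30000, 40000]        # character lengths of queued texts
print(get_optimal_batch_size(lengths)) # 2 -> ~5000 + ~7500 estimated tokens fit; adding ~10000 would overflow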

Memory Management

async def memory_efficient_embeddings(self, texts: List[str], max_batch_size: int = 32) -> np.ndarray:
    """Generate embeddings with memory management."""

    all_embeddings = []

    for i in range(0, len(texts), max_batch_size):
        batch = texts[i:i + max_batch_size]

        # Clear cache before each batch
        if hasattr(self, 'torch'):
            self.torch.cuda.empty_cache()

        batch_embeddings = await get_embeddings(self, batch, normalize=True)
        all_embeddings.extend(batch_embeddings)

        # Optional: yield progress
        if (i + max_batch_size) % 100 == 0:
            logger.info(f"Processed {min(i + max_batch_size, len(texts))}/{len(texts)} texts")

    return self.numpy.array(all_embeddings)

Next Steps

  • Fine-tuning: Train custom embedding models on domain-specific data
  • Advanced Search: Implement hybrid search (dense + sparse)
  • Real-time Updates: Build dynamic vector databases
  • Multimodal: Extend to image and audio embeddings

For more advanced examples, see: