This example shows how to efficiently process multiple inputs in a single request, optimizing GPU utilization and reducing API overhead for high-throughput scenarios.
What We'll Build
A batch text processing service that:
📊 Processes multiple texts in a single request
⚡ Optimizes GPU utilization with efficient batching
🔄 Handles variable input sizes with dynamic padding
📈 Provides performance metrics and timing information
🛡️ Validates batch constraints for stability
Complete Example
batch_processor.py
import torch
import time
from typing import List, Optional

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pydantic import BaseModel, Field, validator
from fastapi import HTTPException
from chutes.chute import Chute, NodeSelector
from chutes.image import Image
# === INPUT/OUTPUT SCHEMAS ===
class BatchTextInput(BaseModel):
texts: List[str] = Field(..., min_items=1, max_items=100, description="List of texts to process")
max_length: int = Field(512, ge=50, le=1024, description="Maximum token length")
batch_size: int = Field(16, ge=1, le=32, description="Processing batch size")
    @validator('texts')
    def validate_texts(cls, v):
        for i, text in enumerate(v):
            if not text.strip():
                raise ValueError(f'Text at index {i} cannot be empty')
            if len(text) > 10000:
                raise ValueError(f'Text at index {i} is too long (max 10000 chars)')
return [text.strip() for text in v]
class TextResult(BaseModel):
text: str
sentiment: str
confidence: float
token_count: int
    processing_order: int


class BatchResult(BaseModel):
results: List[TextResult]
total_texts: int
processing_time: float
average_time_per_text: float
batch_info: dict
    performance_metrics: dict


# === CUSTOM IMAGE ===
image = (
Image(username="myuser", name="batch-processor", tag="1.0")
.from_base("nvidia/cuda:12.2-runtime-ubuntu22.04")
.with_python("3.11")
.run_command("pip install torch==2.1.0 transformers==4.30.0 accelerate==0.20.0 numpy>=1.24.0")
.with_env("TRANSFORMERS_CACHE", "/app/models")
.with_env("TOKENIZERS_PARALLELISM", "false") # Avoid warnings
)
# === CHUTE DEFINITION ===
chute = Chute(
username="myuser",
name="batch-processor",
image=image,
tagline="High-throughput batch text processing",
readme="""
# Batch Text Processor
Efficiently process multiple texts in a single request with optimized GPU utilization.
## Usage
```bash
curl -X POST https://myuser-batch-processor.chutes.ai/process-batch \\
-H "Content-Type: application/json" \\
-d '{
"texts": [
"I love this product!",
"This is terrible quality.",
"Amazing service and support!"
],
"batch_size": 8
}'
```
## Features
- Process up to 100 texts per request
- Automatic batching for GPU optimization
- Dynamic padding for efficient processing
- Comprehensive performance metrics
""",
node_selector=NodeSelector(
gpu_count=1,
min_vram_gb_per_gpu=12
),
    concurrency=4  # Allow multiple concurrent requests
)
# === MODEL LOADING ===
@chute.on_startup()
async def load_model(self):
    """Load sentiment analysis model optimized for batch processing."""
    print("Loading model for batch processing...")
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"

    # Load tokenizer and model
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Optimize for batch processing
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model.to(self.device)
self.model.eval()
    # Enable GPU optimizations (mixed precision is applied via autocast in _process_chunk)
    if torch.cuda.is_available():
        # Enable TensorCore / cuDNN autotuning where available
        torch.backends.cudnn.benchmark = True

    # Cache for performance tracking
    self.batch_stats = {
"total_requests": 0,
"total_texts_processed": 0,
"average_batch_time": 0.0,
"peak_batch_size": 0
}
print(f"Model loaded on {self.device} with batch optimizations enabled")
# === BATCH PROCESSING ENDPOINTS ===
@chute.cord(
public_api_path="/process-batch",
method="POST",
input_schema=BatchTextInput,
output_content_type="application/json")asyncdefprocess_batch(self, data: BatchTextInput) -> BatchResult:
"""Process multiple texts efficiently with batching."""
start_time = time.time()
    # Update statistics
    self.batch_stats["total_requests"] += 1
    self.batch_stats["total_texts_processed"] += len(data.texts)
try:
# Process in chunks if batch is too large
all_results = []
        total_batches = 0
        for chunk_start in range(0, len(data.texts), data.batch_size):
chunk_end = min(chunk_start + data.batch_size, len(data.texts))
text_chunk = data.texts[chunk_start:chunk_end]
# Process this chunk
            chunk_results = await self._process_chunk(
text_chunk,
data.max_length,
chunk_start
)
all_results.extend(chunk_results)
            total_batches += 1

        # Calculate performance metrics
processing_time = time.time() - start_time
avg_time_per_text = processing_time / len(data.texts)
        # Update global stats
        self.batch_stats["average_batch_time"] = (
(self.batch_stats["average_batch_time"] * (self.batch_stats["total_requests"] - 1) +
processing_time) / self.batch_stats["total_requests"]
)
self.batch_stats["peak_batch_size"] = max(
self.batch_stats["peak_batch_size"],
len(data.texts)
)
return BatchResult(
results=all_results,
total_texts=len(data.texts),
processing_time=processing_time,
average_time_per_text=avg_time_per_text,
batch_info={
"requested_batch_size": data.batch_size,
"actual_batches_used": total_batches,
"max_length": data.max_length,
"device": self.device
},
performance_metrics={
"texts_per_second": len(data.texts) / processing_time,
"gpu_memory_used": self._get_gpu_memory_usage(),
"total_tokens_processed": sum(r.token_count for r in all_results)
}
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}")
async def _process_chunk(self, texts: List[str], max_length: int, start_index: int) -> List[TextResult]:
    """Process a chunk of texts efficiently."""
    # Tokenize all texts in the chunk
encoded = self.tokenizer(
texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt"
)
# Move to device
input_ids = encoded['input_ids'].to(self.device)
attention_mask = encoded['attention_mask'].to(self.device)
    # Process with mixed precision if available
    with torch.no_grad():
if torch.cuda.is_available():
with torch.cuda.amp.autocast():
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
else:
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
# Get predictions
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_classes = predictions.argmax(dim=-1)
confidences = predictions.max(dim=-1).values
# Convert to results
labels = ["NEGATIVE", "NEUTRAL", "POSITIVE"]
results = []
    for i, (text, pred_class, confidence, tokens) in enumerate(
zip(texts, predicted_classes, confidences, input_ids)
):
results.append(TextResult(
text=text,
sentiment=labels[pred_class.item()],
confidence=confidence.item(),
token_count=tokens.ne(self.tokenizer.pad_token_id).sum().item(),
processing_order=start_index + i
))
return results
def _get_gpu_memory_usage(self) -> Optional[float]:
    """Get current GPU memory usage in GB."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024**3
    return None


@chute.cord(
public_api_path="/batch-stats",
method="GET",
output_content_type="application/json")asyncdefget_batch_stats(self) -> dict:
"""Get performance statistics for batch processing."""
stats = self.batch_stats.copy()
# Add current system info
stats.update({
"device": self.device,
"model_loaded": hasattr(self, 'model'),
"current_gpu_memory": self._get_gpu_memory_usage(),
"max_gpu_memory": torch.cuda.max_memory_allocated() / 1024**3if torch.cuda.is_available() elseNone
})
return stats
# === STREAMING BATCH PROCESSING ===
@chute.cord(
public_api_path="/process-batch-stream",
method="POST",
input_schema=BatchTextInput,
stream=True,
output_content_type="application/json")asyncdefprocess_batch_stream(self, data: BatchTextInput):
"""Process batch with streaming progress updates."""
start_time = time.time()
yield {
"status": "started",
"total_texts": len(data.texts),
"batch_size": data.batch_size,
"estimated_batches": (len(data.texts) + data.batch_size - 1) // data.batch_size
}
all_results = []
    for batch_idx, chunk_start in enumerate(range(0, len(data.texts), data.batch_size)):
chunk_end = min(chunk_start + data.batch_size, len(data.texts))
text_chunk = data.texts[chunk_start:chunk_end]
yield {
"status": "processing_batch",
"batch_number": batch_idx + 1,
"batch_size": len(text_chunk),
"progress": chunk_end / len(data.texts)
}
# Process chunk
batch_start = time.time()
        chunk_results = await self._process_chunk(text_chunk, data.max_length, chunk_start)
batch_time = time.time() - batch_start
all_results.extend(chunk_results)
yield {
"status": "batch_complete",
"batch_number": batch_idx + 1,
"batch_time": batch_time,
"texts_per_second": len(text_chunk) / batch_time,
"partial_results": chunk_results
}
# Final results
total_time = time.time() - start_time
yield {
"status": "completed",
"total_time": total_time,
"average_time_per_text": total_time / len(data.texts),
"final_results": all_results
}
# Test locally
if __name__ == "__main__":
import asyncio
    async def test_batch_processing():
        # Simulate startup
        await load_model(chute)
# Test batch
test_texts = [
"I love this product!",
"Terrible quality, very disappointed.",
"Pretty good, would recommend.",
"Outstanding service and delivery!",
"Not worth the money spent.",
"Amazing features and great design!"
]
test_input = BatchTextInput(
texts=test_texts,
batch_size=3
)
result = await process_batch(chute, test_input)
print(f"Processed {result.total_texts} texts in {result.processing_time:.2f}s")
print(f"Average time per text: {result.average_time_per_text:.3f}s")
for r in result.results:
print(f"'{r.text[:30]}...' -> {r.sentiment} ({r.confidence:.2f})")
asyncio.run(test_batch_processing())
Performance Optimization Techniques
1. Dynamic Batching
# Automatically adjust batch size based on text lengths
def optimize_batch_size(texts: List[str], max_tokens: int = 8192) -> int:
avg_length = sum(len(text.split()) for text in texts) / len(texts)
    estimated_tokens_per_text = avg_length * 1.3  # Account for subword tokenization
optimal_batch_size = max(1, int(max_tokens / estimated_tokens_per_text))
    return min(optimal_batch_size, 32)  # Cap at 32 for memory safety
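A possible way to wire this into the service (a sketch, not part of the example above) is to clamp the client-supplied batch size to whatever the text lengths suggest before chunking inside process_batch:

# Hypothetical wiring inside process_batch: clamp the requested batch size
# to the value suggested by the average text length before chunking.
effective_batch_size = min(data.batch_size, optimize_batch_size(data.texts))
for chunk_start in range(0, len(data.texts), effective_batch_size):
    chunk_end = min(chunk_start + effective_batch_size, len(data.texts))
    chunk_results = await self._process_chunk(
        data.texts[chunk_start:chunk_end], data.max_length, chunk_start
    )
    all_results.extend(chunk_results)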
2. Memory-Efficient Processing
# Process very large batches in chunks
async def process_large_batch(self, texts: List[str], chunk_size: int = 50):
results = []
    for i in range(0, len(texts), chunk_size):
chunk = texts[i:i + chunk_size]
        chunk_results = await self._process_chunk(chunk, 512, i)
results.extend(chunk_results)
        # Clear GPU cache between chunks
        if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
3. Mixed Precision Inference
# Use automatic mixed precision for faster processing
with torch.cuda.amp.autocast():
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
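On recent PyTorch versions (including the torch 2.1.0 pinned in the image above), the device-agnostic torch.autocast context can be used instead, which also covers the CPU fallback path; a minimal sketch:

# Device-agnostic autocast: float16 on CUDA, bfloat16 on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.no_grad(), torch.autocast(device_type=device_type):
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)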
Testing the Batch API
Simple Batch Test
import requests
import time
# Prepare test data
texts = [
"I absolutely love this new product!",
"Worst purchase I've ever made.",
"It's okay, nothing special.",
"Fantastic quality and great service!",
"Complete waste of money.",
"Highly recommend to everyone!",
"Poor customer support experience.",
"Exceeded all my expectations!",
"Not worth the high price.",
"Perfect for my needs!"
]
# Test different batch sizes
for batch_size in [2, 5, 10]:
print(f"\nTesting batch size: {batch_size}")
start_time = time.time()
response = requests.post(
"https://myuser-batch-processor.chutes.ai/process-batch",
json={
"texts": texts,
"batch_size": batch_size,
"max_length": 256
}
)
result = response.json()
print(f"Total time: {result['processing_time']:.2f}s")
print(f"Texts/second: {result['performance_metrics']['texts_per_second']:.1f}")
print(f"Avg time per text: {result['average_time_per_text']:.3f}s")
Performance Comparison
import asyncio
import aiohttp
import time
async def compare_batch_vs_individual():
"""Compare batch processing vs individual requests."""
texts = ["Sample text for testing"] * 20# Test individual requests
start_time = time.time()
individual_results = []
    async with aiohttp.ClientSession() as session:
tasks = []
for text in texts:
task = session.post(
"https://myuser-batch-processor.chutes.ai/analyze-single",
json={"text": text}
)
tasks.append(task)
responses = await asyncio.gather(*tasks)
for resp in responses:
result = await resp.json()
individual_results.append(result)
individual_time = time.time() - start_time
# Test batch processing
start_time = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(
"https://myuser-batch-processor.chutes.ai/process-batch",
json={"texts": texts, "batch_size": 10}
) as resp:
batch_result = await resp.json()
batch_time = time.time() - start_time
print(f"Individual requests: {individual_time:.2f}s")
print(f"Batch processing: {batch_time:.2f}s")
print(f"Speedup: {individual_time / batch_time:.1f}x")
asyncio.run(compare_batch_vs_individual())
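Note that the individual-request path above calls an /analyze-single endpoint that the chute defined earlier does not expose. A minimal sketch of such a cord (hypothetical, reusing the schemas and helpers from batch_processor.py) might look like:

class SingleTextInput(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000)


@chute.cord(
    public_api_path="/analyze-single",
    method="POST",
    input_schema=SingleTextInput,
    output_content_type="application/json"
)
async def analyze_single(self, data: SingleTextInput) -> TextResult:
    """Analyze one text by reusing the batch chunk processor."""
    results = await self._process_chunk([data.text], 512, 0)
    return results[0]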
Streaming Batch Processing
import asyncio
import aiohttp
import json
async def test_streaming_batch():
"""Test streaming batch processing with progress updates."""
texts = [f"Test message number {i} for batch processing"for i inrange(25)]
    async with aiohttp.ClientSession() as session:
        async with session.post(
"https://myuser-batch-processor.chutes.ai/process-batch-stream",
json={"texts": texts, "batch_size": 5}
) as response:
            async for line in response.content:
if line:
try:
data = json.loads(line.decode())
if data['status'] == 'processing_batch':
print(f"Processing batch {data['batch_number']} ({data['progress']:.1%} complete)")
elif data['status'] == 'batch_complete':
print(f"Batch {data['batch_number']} completed in {data['batch_time']:.2f}s")
elif data['status'] == 'completed':
print(f"All processing completed in {data['total_time']:.2f}s")
except json.JSONDecodeError:
continue
asyncio.run(test_streaming_batch())
Key Performance Concepts
1. Batch Size Optimization
# Find optimal batch size for your hardware
def find_optimal_batch_size(model, tokenizer, device, max_length=512):
batch_sizes = [1, 2, 4, 8, 16, 32]
test_texts = ["Sample text for testing"] * 32
best_throughput = 0
    best_batch_size = 1
    for batch_size in batch_sizes:
try:
start_time = time.time()
            # Test processing
            for i in range(0, len(test_texts), batch_size):
batch = test_texts[i:i + batch_size]
encoded = tokenizer(batch, padding=True, truncation=True,
max_length=max_length, return_tensors="pt")
with torch.no_grad():
_ = model(**encoded.to(device))
total_time = time.time() - start_time
throughput = len(test_texts) / total_time
if throughput > best_throughput:
best_throughput = throughput
best_batch_size = batch_size
except RuntimeError as e:
if"out of memory"instr(e):
breakreturn best_batch_size, best_throughput
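One way to use this probe (a sketch, not part of the example above) is to run it once in load_model() after the model is on the GPU and keep the result as a default:

# Hypothetical addition at the end of load_model(): probe for a good default
# batch size and fall back to 16 if probing fails (e.g. on CPU-only nodes).
try:
    self.optimal_batch_size, _ = find_optimal_batch_size(
        self.model, self.tokenizer, self.device
    )
except Exception:
    self.optimal_batch_size = 16
print(f"Default batch size: {self.optimal_batch_size}")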
2. Memory Management
# Monitor and manage GPU memory
def manage_gpu_memory():
if torch.cuda.is_available():
# Clear cache between large batches
torch.cuda.empty_cache()
# Get memory usage
allocated = torch.cuda.memory_allocated() / 1024**3
        cached = torch.cuda.memory_reserved() / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")
# Set memory fraction if needed
torch.cuda.set_per_process_memory_fraction(0.8)
3. Padding Optimization
# Minimize padding for better efficiency
def optimize_padding(texts: List[str], batch_size: int):
# Sort by length to minimize padding
text_lengths = [(len(text), i, text) for i, text inenumerate(texts)]
text_lengths.sort()
batches = []
current_batch = []
for length, original_idx, text in text_lengths:
current_batch.append((original_idx, text))
        # Create batch when we have enough similar-length texts
        if len(current_batch) >= batch_size:
batches.append(current_batch)
current_batch = []
if current_batch:
batches.append(current_batch)
return batches
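Each batch entry keeps its original_idx, so results from the length-sorted batches can be restored to the request order afterwards. A sketch, assuming a score_texts callable (hypothetical) that returns one result per input text:

def run_sorted_batches(batches, score_texts):
    # Process each length-sorted batch, then restore the original request order.
    results_by_index = {}
    for batch in batches:
        indices = [idx for idx, _ in batch]
        texts = [text for _, text in batch]
        for idx, result in zip(indices, score_texts(texts)):
            results_by_index[idx] = result
    return [results_by_index[i] for i in sorted(results_by_index)]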