Text-to-Speech with CSM-1B
This guide demonstrates how to build a text-to-speech (TTS) service using CSM-1B (Conversational Speech Model) that generates natural-sounding speech with context awareness and multi-speaker support.
Overview
CSM-1B from Sesame is a state-of-the-art speech generation model that:
- Generates high-quality speech from text input
- Supports multiple speakers (two speaker IDs: 0 and 1)
- Uses context from previous audio/text for continuity
- Employs a Llama backbone with a specialized audio decoder
- Produces Mimi audio codes for natural speech output
- Supports configurable duration limits
Complete Implementation
Input Schema Design
Define comprehensive input validation for TTS generation:
from pydantic import BaseModel, Field
from typing import Optional, List
class Context(BaseModel):
text: str
    speaker: int = Field(0, ge=0, le=1)
audio_b64: str # Base64 encoded reference audio
class InputArgs(BaseModel):
text: str
context: Optional[List[Context]] = []
    speaker: Optional[int] = Field(1, ge=0, le=1)
max_duration_ms: Optional[int] = 10000 # Maximum output duration
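For reference, here is one way to build a request body against this schema (a minimal sketch; reference.wav and the texts are placeholder values):

import pybase64 as base64

# Hypothetical reference clip; the service resamples it server-side
with open("reference.wav", "rb") as f:
    ref_b64 = base64.b64encode(f.read()).decode()

payload = InputArgs(
    text="Nice to meet you!",
    context=[Context(text="Hello there.", speaker=0, audio_b64=ref_b64)],
    speaker=1,
    max_duration_ms=8000,
)
# POST payload.model_dump() (pydantic v2; use .dict() on v1) to the /speak endpoint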
Custom Image with CSM-1B
Build a custom image with all required dependencies:
from chutes.image import Image
from chutes.chute import Chute, NodeSelector
image = (
Image(
username="myuser",
name="csm-1b",
tag="0.0.2",
readme="## Text-to-speech using sesame/csm-1b")
.from_base("parachutes/base-python:3.12.9")
.run_command(
"pip install -r https://huggingface.co/chutesai/csm-1b/resolve/main/requirements.txt"
)
.run_command("pip install pybase64") # For audio encoding/decoding
.run_command(
"wget -O /app/generator.py https://huggingface.co/chutesai/csm-1b/resolve/main/generator.py"
)
.run_command(
"wget -O /app/models.py https://huggingface.co/chutesai/csm-1b/resolve/main/models.py"
)
.run_command(
"wget -O /app/watermarking.py https://huggingface.co/chutesai/csm-1b/resolve/main/watermarking.py"
)
)
Chute Configuration
Configure the service with appropriate GPU requirements:
chute = Chute(
username="myuser",
name="csm-1b-tts",
tagline="High-quality text-to-speech with CSM-1B",
readme="CSM (Conversational Speech Model) generates natural speech from text with context awareness and multiple speaker support.",
image=image,
node_selector=NodeSelector(
gpu_count=1,
min_vram_gb_per_gpu=24 # 24GB required for optimal performance
))
Model Initialization
Load and initialize the CSM-1B model on startup:
@chute.on_startup()
async def initialize(self):
"""
Initialize the CSM-1B model and perform warmup.
"""
from huggingface_hub import snapshot_download
from generator import Generator
from models import Model
import torchaudio
import torch
# Download the model with specific revision
revision = "01e2ed64be01915391ec7881f666d6dda0e1d509"
snapshot_download("chutesai/csm-1b", revision=revision)
# Store torchaudio for later use
self.torchaudio = torchaudio
# Initialize the model
model = Model.from_pretrained("chutesai/csm-1b", revision=revision)
model.to(device="cuda", dtype=torch.bfloat16)
# Create the generator
self.generator = Generator(model)
# Warmup generation to load models into memory
_ = self.generator.generate(
text="Warming up Sesame...",
speaker=0,
context=[],
max_audio_length_ms=10000)
Audio Processing Utilities
Add utilities for handling audio input and output:
import pybase64 as base64
import tempfile
import os
from io import BytesIO
from loguru import logger
from fastapi import HTTPException, status
def load_audio(self, audio_b64):
"""
Convert base64 audio data into audio tensor.
Ensures the output is a 1D tensor [T] for compatibility.
"""
try:
# Decode base64 to audio bytes
audio_bytes = BytesIO(base64.b64decode(audio_b64))
# Save to temporary file for processing
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
temp_file.write(audio_bytes.getvalue())
temp_path = temp_file.name
# Load audio with torchaudio
waveform, sample_rate = self.torchaudio.load(temp_path)
os.unlink(temp_path) # Clean up temp file
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0)
else:
waveform = waveform.squeeze(0)
# Resample to model's expected sample rate
audio_tensor = self.torchaudio.functional.resample(
waveform,
orig_freq=sample_rate,
new_freq=self.generator.sample_rate)
# Ensure 1D tensor
if audio_tensor.dim() > 1:
audio_tensor = audio_tensor.squeeze()
return audio_tensor
except Exception as exc:
logger.error(f"Error loading audio: {exc}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid input audio_b64 provided: {exc}")
Text-to-Speech Endpoint
Create the main TTS generation endpoint:
import uuid
from fastapi import Response
@chute.cord(
public_api_path="/speak",
public_api_method="POST",
stream=False,
output_content_type="audio/wav")
async def speak(self, args: InputArgs) -> Response:
"""
Convert text to speech with optional context.
"""
from generator import Segment
# Process context if provided
segments = []
if args.context:
for ctx in args.context:
audio_tensor = load_audio(self, ctx.audio_b64)
segments.append(
Segment(
text=ctx.text,
speaker=ctx.speaker,
audio=audio_tensor)
)
# Generate speech audio
audio = self.generator.generate(
text=args.text,
speaker=args.speaker,
context=segments,
max_audio_length_ms=args.max_duration_ms)
# Save to temporary file
path = f"/tmp/{uuid.uuid4()}.wav"
self.torchaudio.save(
path,
audio.unsqueeze(0).cpu(),
self.generator.sample_rate
)
try:
# Return audio file
with open(path, "rb") as infile:
return Response(
content=infile.read(),
media_type="audio/wav",
headers={
"Content-Disposition": f"attachment; filename={uuid.uuid4()}.wav",
})
finally:
# Clean up temporary file
if os.path.exists(path):
os.remove(path)
Advanced Features
Multi-Speaker Conversation
Create endpoint for generating conversation between speakers:
class ConversationTurn(BaseModel):
speaker: int = Field(ge=0, le=1)
text: str
pause_ms: Optional[int] = Field(default=500, ge=0, le=2000)
class ConversationInput(BaseModel):
turns: List[ConversationTurn]
max_total_duration_ms: int = Field(default=30000, ge=5000, le=60000)
@chute.cord(public_api_path="/conversation", method="POST")
async def generate_conversation(self, args: ConversationInput) -> Response:
"""Generate a conversation between multiple speakers."""
from generator import Segment
conversation_audio = []
context_segments = []
for turn in args.turns:
# Generate speech for this turn with accumulated context
audio = self.generator.generate(
text=turn.text,
speaker=turn.speaker,
context=context_segments,
max_audio_length_ms=args.max_total_duration_ms // len(args.turns))
conversation_audio.append(audio)
# Add silence between turns
if turn.pause_ms > 0:
silence_samples = int(turn.pause_ms * self.generator.sample_rate / 1000)
silence = torch.zeros(silence_samples)
conversation_audio.append(silence)
# Add this turn to context for future turns
context_segments.append(
Segment(
text=turn.text,
speaker=turn.speaker,
audio=audio)
)
# Concatenate all audio
full_audio = torch.cat(conversation_audio, dim=0)
# Save and return
path = f"/tmp/conversation_{uuid.uuid4()}.wav"
self.torchaudio.save(path, full_audio.unsqueeze(0).cpu(), self.generator.sample_rate)
try:
with open(path, "rb") as infile:
return Response(
content=infile.read(),
media_type="audio/wav",
headers={"Content-Disposition": f"attachment; filename=conversation.wav"})
finally:
if os.path.exists(path):
os.remove(path)
Voice Cloning with Reference Audio
Clone a voice from a reference audio sample:
class VoiceCloningInput(BaseModel):
text: str
reference_audio_b64: str
reference_text: str # Text that was spoken in reference audio
max_duration_ms: int = Field(default=15000, ge=1000, le=30000)
@chute.cord(public_api_path="/clone_voice", method="POST")
async def clone_voice(self, args: VoiceCloningInput) -> Response:
"""Generate speech using a reference voice sample."""
from generator import Segment
# Load reference audio
reference_audio = load_audio(self, args.reference_audio_b64)
# Create context segment from reference
reference_segment = Segment(
text=args.reference_text,
speaker=0, # Use speaker 0 as base
audio=reference_audio)
# Generate new speech with reference voice characteristics
audio = self.generator.generate(
text=args.text,
speaker=0,
context=[reference_segment],
max_audio_length_ms=args.max_duration_ms)
# Save and return
path = f"/tmp/cloned_{uuid.uuid4()}.wav"
self.torchaudio.save(path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
try:
with open(path, "rb") as infile:
return Response(
content=infile.read(),
media_type="audio/wav",
headers={"Content-Disposition": f"attachment; filename=cloned_voice.wav"})
finally:
if os.path.exists(path):
os.remove(path)
Batch Processing
Process multiple texts efficiently:
class BatchTTSInput(BaseModel):
    texts: List[str] = Field(max_items=10)  # Limit batch size (use max_length on pydantic v2)
speaker: int = Field(default=0, ge=0, le=1)
max_duration_per_text_ms: int = Field(default=10000, ge=1000, le=20000)
@chute.cord(public_api_path="/batch_speak", method="POST")
async def batch_speak(self, args: BatchTTSInput) -> List[str]:
"""Generate speech for multiple texts and return as base64 list."""
results = []
for text in args.texts:
# Generate audio for each text
audio = self.generator.generate(
text=text,
speaker=args.speaker,
context=[],
max_audio_length_ms=args.max_duration_per_text_ms)
# Convert to WAV bytes
path = f"/tmp/batch_{uuid.uuid4()}.wav"
self.torchaudio.save(path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
try:
with open(path, "rb") as infile:
audio_b64 = base64.b64encode(infile.read()).decode()
results.append(audio_b64)
finally:
if os.path.exists(path):
os.remove(path)
return results
Deployment and Usage
Deploy the Service
# Build and deploy the TTS service
chutes deploy my_tts:chute
# Monitor the deployment
chutes chutes get csm-1b-tts
Using the API
Basic Text-to-Speech
curl -X POST "https://myuser-my-tts.chutes.ai/speak" \
-H "Content-Type: application/json" \
-d '{
"text": "Hello, this is a demonstration of high-quality text-to-speech synthesis.",
"speaker": 0,
"max_duration_ms": 15000
}' \
--output speech.wav
Voice Cloning
# First, encode your reference audio to base64 without line wrapping so it
# can be inlined into the JSON body (GNU coreutils; on macOS use `base64 -i reference.wav`)
# base64 -w 0 reference.wav > reference.b64
curl -X POST "https://myuser-my-tts.chutes.ai/clone_voice" \
-H "Content-Type: application/json" \
-d '{
"text": "This is new text spoken in the reference voice",
"reference_audio_b64": "'$(cat reference.b64)'",
"reference_text": "Original text that was spoken in the reference audio",
"max_duration_ms": 20000
}' \
--output cloned_speech.wav
Python Client Example
import base64
from typing import List

import requests
class TTSClient:
def __init__(self, base_url: str):
self.base_url = base_url.rstrip('/')
def speak(self, text: str, speaker: int = 0, max_duration_ms: int = 10000) -> bytes:
"""Generate speech from text."""
response = requests.post(
f"{self.base_url}/speak",
json={
"text": text,
"speaker": speaker,
"max_duration_ms": max_duration_ms
}
)
if response.status_code == 200:
return response.content
else:
raise Exception(f"TTS failed: {response.status_code} - {response.text}")
def clone_voice(self, text: str, reference_audio_path: str, reference_text: str) -> bytes:
"""Generate speech using voice cloning."""
# Encode reference audio
with open(reference_audio_path, "rb") as f:
reference_b64 = base64.b64encode(f.read()).decode()
response = requests.post(
f"{self.base_url}/clone_voice",
json={
"text": text,
"reference_audio_b64": reference_b64,
"reference_text": reference_text,
"max_duration_ms": 20000
}
)
return response.content
def generate_conversation(self, turns: List[dict]) -> bytes:
"""Generate a conversation between speakers."""
response = requests.post(
f"{self.base_url}/conversation",
json={
"turns": turns,
"max_total_duration_ms": 30000
}
)
return response.content
def batch_speak(self, texts: List[str], speaker: int = 0) -> List[bytes]:
"""Generate speech for multiple texts."""
response = requests.post(
f"{self.base_url}/batch_speak",
json={
"texts": texts,
"speaker": speaker,
"max_duration_per_text_ms": 10000
}
)
if response.status_code == 200:
b64_results = response.json()
return [base64.b64decode(b64) for b64 in b64_results]
else:
raise Exception(f"Batch TTS failed: {response.status_code}")
# Usage examples
client = TTSClient("https://myuser-my-tts.chutes.ai")
# Basic TTS
speech_audio = client.speak("Hello, world! This is synthesized speech.")
with open("hello.wav", "wb") as f:
f.write(speech_audio)
# Voice cloning
cloned_audio = client.clone_voice(
text="This is new content in the cloned voice",
reference_audio_path="reference_voice.wav",
reference_text="This was the original reference text"
)
with open("cloned.wav", "wb") as f:
f.write(cloned_audio)
# Conversation generation
conversation_turns = [
{"speaker": 0, "text": "Hello, how are you today?", "pause_ms": 1000},
{"speaker": 1, "text": "I'm doing great, thanks for asking!", "pause_ms": 800},
{"speaker": 0, "text": "That's wonderful to hear.", "pause_ms": 500}
]
conversation_audio = client.generate_conversation(conversation_turns)
with open("conversation.wav", "wb") as f:
f.write(conversation_audio)
Best Practices
1. Text Preprocessing
import re
def preprocess_text(text: str) -> str:
"""Clean and prepare text for TTS."""
# Expand common abbreviations
text = text.replace("Dr.", "Doctor")
text = text.replace("Mr.", "Mister")
text = text.replace("Mrs.", "Missus")
text = text.replace("&", "and")
# Handle numbers (basic example)
text = re.sub(r'\b(\d+)\b', lambda m: num_to_words(int(m.group(1))), text)
# Remove excessive punctuation
text = re.sub(r'[.]{2,}', '.', text)
text = re.sub(r'[!]{2,}', '!', text)
text = re.sub(r'[?]{2,}', '?', text)
return text.strip()
def num_to_words(num: int) -> str:
"""Convert numbers to words (basic implementation)."""
if num == 0:
return "zero"
elif num == 1:
return "one"
# Add more number conversions as needed
else:
return str(num) # Fallback
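A quick sanity check of these helpers (illustrative input; note that num_to_words only spells out zero and one, so larger numbers pass through as digits):

print(preprocess_text("Mr. Lee said hello... 1 time!!"))
# -> "Mister Lee said hello. one time!"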
2. Context Management
class ContextManager:
"""Manage conversation context for better continuity."""
def __init__(self, max_context_length: int = 5):
self.context_segments = []
self.max_length = max_context_length
def add_segment(self, text: str, speaker: int, audio_tensor):
"""Add a new segment to context."""
from generator import Segment
segment = Segment(text=text, speaker=speaker, audio=audio_tensor)
self.context_segments.append(segment)
# Keep only recent context
if len(self.context_segments) > self.max_length:
self.context_segments = self.context_segments[-self.max_length:]
def get_context(self) -> List:
"""Get current context for generation."""
return self.context_segments.copy()
def clear(self):
"""Clear all context."""
self.context_segments = []
# Usage in endpoint
@chute.cord(public_api_path="/contextual_speak", method="POST")
async def contextual_speak(self, args: InputArgs) -> Response:
"""Generate speech with persistent context."""
if not hasattr(self, 'context_manager'):
self.context_manager = ContextManager()
# Generate with context
audio = self.generator.generate(
text=args.text,
speaker=args.speaker,
context=self.context_manager.get_context(),
max_audio_length_ms=args.max_duration_ms)
# Add to context for future generations
self.context_manager.add_segment(args.text, args.speaker, audio)
# Return audio...
3. Quality Control
def validate_audio_quality(audio_tensor, sample_rate: int) -> bool:
"""Check generated audio quality."""
import torch
# Check for silence (all zeros)
if torch.all(audio_tensor == 0):
return False
# Check for clipping
if torch.max(torch.abs(audio_tensor)) > 0.99:
return False
# Check minimum duration (avoid too short clips)
min_duration_ms = 500
min_samples = int(min_duration_ms * sample_rate / 1000)
if len(audio_tensor) < min_samples:
return False
return True
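A few illustrative checks with synthetic tensors, assuming the generator's 24 kHz sample rate:

import torch

assert validate_audio_quality(torch.rand(24000) * 0.5, 24000)      # ~1 s of audible audio passes
assert not validate_audio_quality(torch.zeros(24000), 24000)       # pure silence fails
assert not validate_audio_quality(torch.rand(1000) * 0.5, 24000)   # under 500 ms fails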
@chute.cord(public_api_path="/quality_speak", method="POST")
async def quality_controlled_speak(self, args: InputArgs) -> Response:
"""Generate speech with quality validation."""
max_retries = 3
for attempt in range(max_retries):
audio = self.generator.generate(
text=args.text,
speaker=args.speaker,
context=[],
max_audio_length_ms=args.max_duration_ms)
if validate_audio_quality(audio, self.generator.sample_rate):
# Quality passed, return audio
break
else:
logger.warning(f"Audio quality check failed, attempt {attempt + 1}")
if attempt == max_retries - 1:
raise HTTPException(
status_code=500,
detail="Failed to generate quality audio after multiple attempts"
)
# Save and return validated audio...
Performance Optimization
Memory Management
@chute.cord(public_api_path="/optimized_speak", method="POST")
async def optimized_speak(self, args: InputArgs) -> Response:
"""Memory-optimized speech generation."""
import torch
try:
# Clear cache before generation
torch.cuda.empty_cache()
# Generate with memory efficiency
with torch.inference_mode():
audio = self.generator.generate(
text=args.text,
speaker=args.speaker,
context=args.context,
max_audio_length_ms=args.max_duration_ms)
# Process and return immediately
path = f"/tmp/{uuid.uuid4()}.wav"
self.torchaudio.save(path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
# Read and clean up immediately
with open(path, "rb") as infile:
content = infile.read()
os.remove(path)
return Response(
content=content,
media_type="audio/wav",
headers={"Content-Disposition": f"attachment; filename=speech.wav"})
finally:
# Always clear cache after generation
torch.cuda.empty_cache()
Caching for Repeated Requests
import hashlib
from typing import Dict
class TTSCache:
"""Simple cache for TTS results."""
def __init__(self, max_size: int = 100):
self.cache: Dict[str, bytes] = {}
self.max_size = max_size
def get_key(self, text: str, speaker: int) -> str:
"""Generate cache key."""
content = f"{text}_{speaker}"
return hashlib.md5(content.encode()).hexdigest()
def get(self, text: str, speaker: int) -> Optional[bytes]:
"""Get cached result."""
key = self.get_key(text, speaker)
return self.cache.get(key)
def set(self, text: str, speaker: int, audio_bytes: bytes):
"""Cache result."""
if len(self.cache) >= self.max_size:
# Remove oldest item (simple FIFO)
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
key = self.get_key(text, speaker)
self.cache[key] = audio_bytes
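The FIFO eviction above is the simplest policy; if you want least-recently-used semantics instead, a minimal sketch using collections.OrderedDict (a hypothetical variant, not part of the service above) could look like:

from collections import OrderedDict
import hashlib

class LRUTTSCache:
    """LRU variant of TTSCache (sketch)."""
    def __init__(self, max_size: int = 100):
        self.cache: "OrderedDict[str, bytes]" = OrderedDict()
        self.max_size = max_size

    def _key(self, text: str, speaker: int) -> str:
        return hashlib.md5(f"{text}_{speaker}".encode()).hexdigest()

    def get(self, text: str, speaker: int):
        key = self._key(text, speaker)
        if key in self.cache:
            self.cache.move_to_end(key)  # mark as most recently used
            return self.cache[key]
        return None

    def set(self, text: str, speaker: int, audio_bytes: bytes):
        key = self._key(text, speaker)
        self.cache[key] = audio_bytes
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # evict the least recently used entry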
# Add to chute initialization
@chute.on_startup()
async def initialize_with_cache(self):
# ... existing initialization ...
self.tts_cache = TTSCache(max_size=200)
@chute.cord(public_api_path="/cached_speak", method="POST")
async def cached_speak(self, args: InputArgs) -> Response:
"""TTS with caching for repeated requests."""
# Check cache first (only for simple requests without context)
if not args.context:
cached_result = self.tts_cache.get(args.text, args.speaker)
if cached_result:
return Response(
content=cached_result,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=cached_speech.wav"})
    # Generate new audio, converting any provided context as in /speak
    from generator import Segment
    segments = [
        Segment(text=ctx.text, speaker=ctx.speaker, audio=load_audio(self, ctx.audio_b64))
        for ctx in (args.context or [])
    ]
    audio = self.generator.generate(
        text=args.text,
        speaker=args.speaker,
        context=segments,
        max_audio_length_ms=args.max_duration_ms)
# Save to file and cache
path = f"/tmp/{uuid.uuid4()}.wav"
self.torchaudio.save(path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
with open(path, "rb") as infile:
audio_bytes = infile.read()
os.remove(path)
# Cache result
if not args.context:
self.tts_cache.set(args.text, args.speaker, audio_bytes)
return Response(
content=audio_bytes,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=speech.wav"})
Monitoring and Troubleshooting
Performance Monitoring
# Check service health
chutes chutes get csm-1b-tts
# View generation logs
chutes chutes logs csm-1b-tts --tail 100
# Monitor GPU utilization
chutes chutes metrics csm-1b-tts
Common Issues and Solutions
# Handle common TTS issues
@chute.cord(public_api_path="/robust_speak", method="POST")
async def robust_speak(self, args: InputArgs) -> Response:
"""TTS with comprehensive error handling."""
try:
# Preprocess text
processed_text = preprocess_text(args.text)
# Validate text length
if len(processed_text) > 1000:
raise HTTPException(
status_code=400,
detail="Text too long. Maximum 1000 characters allowed."
)
# Generate audio
audio = self.generator.generate(
text=processed_text,
speaker=args.speaker,
context=[],
max_audio_length_ms=args.max_duration_ms)
# Validate output
if not validate_audio_quality(audio, self.generator.sample_rate):
raise HTTPException(
status_code=500,
detail="Generated audio failed quality checks"
)
# Return successful result
path = f"/tmp/{uuid.uuid4()}.wav"
self.torchaudio.save(path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
with open(path, "rb") as infile:
content = infile.read()
os.remove(path)
return Response(
content=content,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=speech.wav"})
except torch.cuda.OutOfMemoryError:
raise HTTPException(
status_code=503,
detail="GPU memory exhausted. Please try again or reduce duration."
)
except Exception as e:
logger.error(f"TTS generation failed: {e}")
raise HTTPException(
status_code=500,
detail=f"Speech generation failed: {str(e)}"
)
Next Steps
- Custom Voice Training: Train CSM-1B on your own voice data
- Multilingual Support: Experiment with different languages
- Real-time Streaming: Implement streaming TTS for live applications
- Integration: Build voice assistants and interactive applications
For more advanced examples, see: