Music Generation with DiffRhythm
This guide shows how to build a music generation service with DiffRhythm that creates songs from style prompts and timestamped (LRC) lyrics, with optional reference audio to guide the musical style.
Overview
DiffRhythm (ASLP-lab/DiffRhythm) is a state-of-the-art music generation model that can:
- Generate music from text descriptions and style prompts
- Convert lyrics with timing information into musical performances
- Use reference audio to guide musical style
- Support multiple languages and musical genres
- Generate high-quality 44.1kHz audio output
Complete Implementation
Input Schema Design
Define comprehensive input validation for music generation:
import re
from typing import Optional
from pydantic import BaseModel
from fastapi import HTTPException, status
# Regex for validating LRC (lyric) format timestamps
LRC_RE = re.compile(r"\[(\d+):(\d+\.\d+)\]")
class InputArgs(BaseModel):
style_prompt: Optional[str] = None
lyrics: Optional[str] = None
audio_b64: Optional[str] = None # Reference audio in base64
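As a quick sanity check of the LRC_RE pattern above (it accepts only strict [mm:ss.xx] timestamps); the sample line here is illustrative:

# The regex captures minutes and fractional seconds from a strict LRC timestamp
match = LRC_RE.match("[00:05.50]In the morning light so bright")
if match:
    minutes, seconds = int(match.group(1)), float(match.group(2))
    print(minutes * 60 + seconds)  # 5.5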
Custom Image with DiffRhythm
Build a custom image with all required dependencies:
from chutes.image import Image
from chutes.chute import Chute, NodeSelector
image = (
Image(
username="myuser",
name="diffrhythm",
tag="0.0.2",
readme="Music generation with ASLP-lab/DiffRhythm")
.from_base("parachutes/base-python:3.12.9")
.set_user("root")
.run_command("apt update && apt -y install espeak-ng") # For text processing
.set_user("chutes")
.run_command("git clone https://github.com/ASLP-lab/DiffRhythm.git")
.run_command("pip install -r DiffRhythm/requirements.txt")
.run_command("pip install pybase64 py3langid") # Additional dependencies
.run_command("mv -f /app/DiffRhythm/* /app") # Move to app directory
.with_env("PYTHONPATH", "/app/infer") # Set Python path
)
Chute Configuration
Configure the service with appropriate GPU requirements:
chute = Chute(
username="myuser",
name="diffrhythm-music",
tagline="AI Music Generation with DiffRhythm",
readme="Generate music from text descriptions and lyrics using advanced AI",
image=image,
node_selector=NodeSelector(gpu_count=1), # Single GPU sufficient
)
Model Initialization
Load and initialize all required models on startup:
@chute.on_startup()
async def initialize(self):
"""
Initialize DiffRhythm models and dependencies.
"""
from huggingface_hub import snapshot_download
import torchaudio
import torch
import soundfile
from infer_utils import (
decode_audio,
get_lrc_token,
get_negative_style_prompt,
get_reference_latent,
get_style_prompt,
load_checkpoint,
CNENTokenizer)
from infer import inference
from muq import MuQMuLan
from model import DiT, CFM
import json
import os
# Download required models
revision = "613846abae8e5b869b3845a5dfabc9ecc37ecdab"
repo_id = "ASLP-lab/DiffRhythm-full"
path = snapshot_download(repo_id, revision=revision)
vae_path = snapshot_download(
"ASLP-lab/DiffRhythm-vae",
revision="4656f626776f5f924c03471acb25bea6734e774f"
)
# Load model configuration
dit_config_path = "/app/config/diffrhythm-1b.json"
with open(dit_config_path) as f:
model_config = json.load(f)
# Initialize models
dit_model_cls = DiT
self.max_frames = 6144
# CFM (Conditional Flow Matching) model
self.cfm = CFM(
transformer=dit_model_cls(**model_config["model"], max_frames=self.max_frames),
num_channels=model_config["model"]["mel_dim"],
max_frames=self.max_frames
).to("cuda")
# Load trained weights
self.cfm = load_checkpoint(
self.cfm,
os.path.join(path, "cfm_model.pt"),
device="cuda",
use_ema=False
)
# Initialize tokenizer and style model
self.tokenizer = CNENTokenizer()
self.muq = MuQMuLan.from_pretrained(
"OpenMuQ/MuQ-MuLan-large",
revision="8a081dbcf84edd47ea7db3c4ecb8fd1ec1ddacfe"
).to("cuda")
# Load VAE for audio decoding
vae_ckpt_path = os.path.join(vae_path, "vae_model.pt")
self.vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to("cuda")
# Store utilities first; the warmup below relies on self.torch and self.torchaudio
self.torchaudio = torchaudio
self.torch = torch
self.soundfile = soundfile
self.decode_audio = decode_audio
self.inference = inference
self.get_lrc_token = get_lrc_token
self.get_reference_latent = get_reference_latent
self.get_style_prompt = get_style_prompt
# Warmup with example generation
await self._warmup_model()
async def _warmup_model(self):
"""Perform warmup generation to load models into memory."""
from infer_utils import get_lrc_token, get_negative_style_prompt, get_reference_latent, get_style_prompt
from infer import inference
# Load example lyrics
with open("/app/infer/example/eg_en_full.lrc", "r", encoding="utf-8") as infile:
lrc = infile.read()
# Prepare warmup data
lrc_prompt, start_time = get_lrc_token(self.max_frames, lrc, self.tokenizer, "cuda")
self.negative_style_prompt = get_negative_style_prompt("cuda")
self.latent_prompt = get_reference_latent("cuda", self.max_frames)
style_prompt = get_style_prompt(self.muq, prompt="classical genres, hopeful mood, piano.")
# Perform warmup generation
with self.torch.no_grad():
generated_song = inference(
cfm_model=self.cfm,
vae_model=self.vae,
cond=self.latent_prompt,
text=lrc_prompt,
duration=self.max_frames,
style_prompt=style_prompt,
negative_style_prompt=self.negative_style_prompt,
start_time=start_time,
chunked=True)
# Save warmup output
output_path = "/app/warmup.mp3"
self.torchaudio.save(output_path, generated_song, sample_rate=44100, format="mp3")
Audio Processing Utilities
Add utilities for handling audio input:
import pybase64 as base64
import tempfile
from io import BytesIO
from loguru import logger
def load_audio(self, audio_b64):
"""
Decode base64-encoded audio to a temporary WAV file for style extraction.
"""
try:
audio_bytes = BytesIO(base64.b64decode(audio_b64))
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
temp_file.write(audio_bytes.getvalue())
temp_path = temp_file.name
# Load once to confirm the bytes decode as audio; callers only need the file path
waveform, sample_rate = self.torchaudio.load(temp_path)
return temp_path
except Exception as exc:
logger.error(f"Error loading audio: {exc}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid input audio_b64 provided: {exc}")
Lyrics Validation
Implement comprehensive lyrics validation with timing:
def validate_lyrics(lyrics: str, total_length: int):
"""
Validate LRC-format lyrics for timestamp format, chronological order, and total duration (total_length is in seconds).
"""
def format_time(seconds: float) -> str:
minutes = int(seconds // 60)
remaining_seconds = seconds % 60
return f"{minutes:02d}:{remaining_seconds:05.2f}"
previous_time = -1.0
last_timestamp = 0.0
try:
for line_num, line in enumerate(lyrics.splitlines()):
if not line.strip():
continue
# Check line length
if len(line) > 256:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Line {line_num} exceeds 256 characters: {len(line)} chars")
# Validate timestamp format
valid_match = LRC_RE.match(line)
if valid_match:
minutes = int(valid_match.group(1))
seconds = float(valid_match.group(2))
current_time = minutes * 60 + seconds
last_timestamp = max(last_timestamp, current_time)
# Check chronological order
if current_time < previous_time:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Line {line_num}: Timestamp {format_time(current_time)} "
f"is before previous timestamp {format_time(previous_time)}")
previous_time = current_time
except HTTPException:
raise
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error validating lyrics: {exc}")
# Check total duration
if last_timestamp > total_length:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Total duration ({format_time(last_timestamp)}) "
f"exceeds maximum allowed length ({format_time(total_length)})")
Music Generation Endpoint
Create the main generation endpoint:
import uuid
import os
from fastapi.responses import Response
@chute.cord(
public_api_path="/generate",
public_api_method="POST",
stream=False,
output_content_type="audio/mp3")
async def generate(self, args: InputArgs) -> Response:
"""
Generate music from style prompts and/or lyrics.
"""
input_path = None
inference_kwargs = dict(
cfm_model=self.cfm,
vae_model=self.vae,
cond=self.latent_prompt,
duration=self.max_frames,
negative_style_prompt=self.negative_style_prompt,
chunked=True)
# Extract style from prompt or reference audio
style_prompt = None
if args.style_prompt:
style_prompt = self.get_style_prompt(self.muq, prompt=args.style_prompt)
elif args.audio_b64:
input_path = load_audio(self, args.audio_b64)
try:
style_prompt = self.get_style_prompt(self.muq, input_path)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid input audio: {exc}")
finally:
if input_path and os.path.exists(input_path):
os.remove(input_path)
if style_prompt is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="You must provide either style_prompt or audio_b64!")
inference_kwargs["style_prompt"] = style_prompt
# Process lyrics if provided
if args.lyrics:
validate_lyrics(args.lyrics, 285) # Max ~4.75 minutes
lrc_prompt, start_time = self.get_lrc_token(
self.max_frames, args.lyrics or "", self.tokenizer, "cuda"
)
inference_kwargs["text"] = lrc_prompt
inference_kwargs["start_time"] = start_time
# Generate the music
output_path = f"/tmp/{uuid.uuid4()}.mp3"
try:
with self.torch.no_grad():
generated_song = self.inference(**inference_kwargs)
self.torchaudio.save(
output_path, generated_song, sample_rate=44100, format="mp3"
)
# Return audio file
with open(output_path, "rb") as infile:
return Response(
content=infile.read(),
media_type="audio/mp3",
headers={
"Content-Disposition": f"attachment; filename={uuid.uuid4()}.mp3",
})
finally:
if os.path.exists(output_path):
os.remove(output_path)
Advanced Features
Style-Guided Generation
Create endpoint for style-specific music generation:
class StyleRequest(BaseModel):
style_description: str
mood: Optional[str] = "neutral"
genre: Optional[str] = "pop"
instruments: Optional[str] = "piano, guitar"
tempo: Optional[str] = "medium"
@chute.cord(public_api_path="/style_generate", method="POST")
async def generate_with_style(self, request: StyleRequest) -> Response:
"""Generate music with detailed style control."""
# Construct detailed style prompt
style_prompt = f"{request.genre} genre, {request.mood} mood, {request.instruments}"
if request.tempo:
style_prompt += f", {request.tempo} tempo"
if request.style_description:
style_prompt += f", {request.style_description}"
# Generate using style prompt
args = InputArgs(style_prompt=style_prompt)
return await self.generate(args)
Lyrics-to-Music with Timing
Example of properly formatted lyrics with timestamps:
# Example LRC format lyrics
example_lyrics = """
[00:00.00]Verse 1
[00:05.50]In the morning light so bright
[00:10.00]I can see a better sight
[00:15.50]Dreams are calling out my name
[00:20.00]Nothing will be quite the same
[00:25.00]Chorus
[00:27.50]We are rising with the sun
[00:32.00]A new journey has begun
[00:37.50]Every step we take today
[00:42.00]Leads us down a brighter way
[00:47.00]Verse 2
[00:50.00]Through the valleys and the hills
[00:55.50]We will chase away our fears
[01:00.00]With the music in our hearts
[01:05.50]We will make a brand new start
"""
class LyricsRequest(BaseModel):
lyrics: str
style_prompt: str = "uplifting pop song, piano and strings"
@chute.cord(public_api_path="/lyrics_to_music", method="POST")
async def lyrics_to_music(self, request: LyricsRequest) -> Response:
"""Convert timestamped lyrics into a complete song."""
args = InputArgs(
style_prompt=request.style_prompt,
lyrics=request.lyrics
)
return await self.generate(args)
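Before submitting, it can help to confirm that the last timestamp in the lyrics stays under the 285-second cap enforced by /generate; a small client-side sketch (last_timestamp_seconds is a hypothetical helper, re-declaring the timestamp regex so it is self-contained):

import re

LRC_TS = re.compile(r"\[(\d+):(\d+\.\d+)\]")

def last_timestamp_seconds(lyrics: str) -> float:
    """Return the largest timestamp (in seconds) found in LRC lyrics."""
    times = [int(m.group(1)) * 60 + float(m.group(2)) for m in LRC_TS.finditer(lyrics)]
    return max(times, default=0.0)

print(last_timestamp_seconds(example_lyrics))  # 65.5 for the example above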
Reference Audio Style Transfer
Extract musical style from uploaded audio:
from pydantic import Field

class StyleTransferRequest(BaseModel):
reference_audio_b64: str
new_lyrics: Optional[str] = None
style_blend: float = Field(default=1.0, ge=0.1, le=1.0)  # accepted but not yet consumed by the generation flow below
@chute.cord(public_api_path="/style_transfer", method="POST")
async def style_transfer(self, request: StyleTransferRequest) -> Response:
"""Generate music using the style from reference audio."""
args = InputArgs(
audio_b64=request.reference_audio_b64,
lyrics=request.new_lyrics
)
return await self.generate(args)
Deployment and Usage
Deploy the Service
# Build and deploy the music generation service
chutes deploy my_music_gen:chute
# Monitor the deployment
chutes chutes get my-music-gen
Using the API
Generate with Style Prompt
curl -X POST "https://myuser-my-music-gen.chutes.ai/generate" \
-H "Content-Type: application/json" \
-d '{
"style_prompt": "upbeat electronic dance music, synthesizers, energetic"
}' \
--output generated_music.mp3
Generate with Lyrics
curl -X POST "https://myuser-my-music-gen.chutes.ai/lyrics_to_music" \
-H "Content-Type: application/json" \
-d '{
"lyrics": "[00:00.00]Hello world\n[00:05.00]This is my song\n[00:10.00]Made with AI",
"style_prompt": "acoustic folk, guitar and violin, heartfelt"
}' \
--output lyrical_song.mp3
Python Client Example
import requests
import base64
class MusicGenerator:
def __init__(self, base_url):
self.base_url = base_url
def generate_from_style(self, style_prompt):
"""Generate music from style description."""
response = requests.post(
f"{self.base_url}/generate",
json={"style_prompt": style_prompt}
)
if response.status_code == 200:
return response.content
else:
raise Exception(f"Generation failed: {response.status_code}")
def generate_from_lyrics(self, lyrics, style="pop"):
"""Generate music from timestamped lyrics."""
response = requests.post(
f"{self.base_url}/lyrics_to_music",
json={
"lyrics": lyrics,
"style_prompt": f"{style} style, full band arrangement"
}
)
return response.content
def style_transfer(self, reference_audio_path, new_lyrics=None):
"""Generate music using style from reference audio."""
with open(reference_audio_path, "rb") as f:
audio_b64 = base64.b64encode(f.read()).decode()
payload = {"reference_audio_b64": audio_b64}
if new_lyrics:
payload["new_lyrics"] = new_lyrics
response = requests.post(
f"{self.base_url}/style_transfer",
json=payload
)
return response.content
# Usage example
generator = MusicGenerator("https://myuser-my-music-gen.chutes.ai")
# Generate upbeat electronic music
music = generator.generate_from_style(
"energetic electronic dance music, heavy bass, futuristic sounds"
)
with open("edm_track.mp3", "wb") as f:
f.write(music)
# Generate from lyrics
lyrics = """
[00:00.00]Verse 1
[00:03.00]AI creates the beat
[00:06.00]Technology so sweet
[00:09.00]Music from the future
[00:12.00]Is here to greet ya
"""
song = generator.generate_from_lyrics(lyrics, "electronic pop")
with open("ai_song.mp3", "wb") as f:
f.write(song)
Best Practices
1. Lyrics Formatting
# Good LRC format - clear timing and structure
good_lyrics = """
[00:00.00]Intro
[00:08.00]Verse 1
[00:10.50]Walking down the street tonight
[00:15.00]City lights are shining bright
[00:20.50]Every step I take feels right
[00:25.00]In this neon-colored light
[00:30.00]Chorus
[00:32.50]We are alive, we are free
[00:37.00]This is who we're meant to be
[00:42.50]Dancing through eternity
[00:47.00]In perfect harmony
"""
# Bad format - inconsistent timing
bad_lyrics = """
[00:00]Start
[0:5]Some lyrics here
[15.5]More lyrics without proper format
Random text without timestamp
"""
2. Style Prompt Engineering
# Effective style prompts are specific and descriptive
effective_styles = [
"jazz ballad, piano and saxophone, slow tempo, romantic mood",
"rock anthem, electric guitars, powerful drums, energetic",
"classical orchestral, strings and brass, dramatic, cinematic",
"ambient electronic, synthesizers, dreamy, ethereal atmosphere",
"country folk, acoustic guitar, harmonica, storytelling style"
]
# Avoid vague prompts
vague_styles = [
"good music",
"nice song",
"popular style"
]
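These prompts can be exercised in a batch against the service, for example with the MusicGenerator client from the Python Client Example above (output filenames are illustrative):

# Generate one track per effective style prompt
generator = MusicGenerator("https://myuser-my-music-gen.chutes.ai")
for i, style in enumerate(effective_styles):
    audio = generator.generate_from_style(style)
    with open(f"style_{i}.mp3", "wb") as f:
        f.write(audio)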
3. Audio Quality Optimization
# For highest quality output
@chute.cord(public_api_path="/hq_generate", method="POST")
async def high_quality_generate(self, args: InputArgs) -> Response:
"""Generate high-quality music with extended processing."""
# Use maximum duration for better quality
inference_kwargs = dict(
cfm_model=self.cfm,
vae_model=self.vae,
cond=self.latent_prompt,
duration=self.max_frames, # Use full duration
negative_style_prompt=self.negative_style_prompt,
chunked=False, # Don't chunk for better coherence
)
# ... rest of generation logic
4. Error Handling and Validation
def validate_audio_input(audio_b64: str, max_size_mb: int = 10):
"""Validate audio input size and format."""
try:
audio_data = base64.b64decode(audio_b64)
size_mb = len(audio_data) / (1024 * 1024)
if size_mb > max_size_mb:
raise HTTPException(
status_code=400,
detail=f"Audio file too large: {size_mb:.1f}MB (max: {max_size_mb}MB)"
)
return audio_data
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Invalid audio data: {str(e)}"
)
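One way this check might be wired into the style-transfer path, shown as a sketch (the /style_transfer_checked path and function name are hypothetical):

@chute.cord(public_api_path="/style_transfer_checked", public_api_method="POST")
async def style_transfer_checked(self, request: StyleTransferRequest) -> Response:
    """Reject oversized or undecodable reference audio before running generation."""
    validate_audio_input(request.reference_audio_b64, max_size_mb=10)
    args = InputArgs(audio_b64=request.reference_audio_b64, lyrics=request.new_lyrics)
    return await self.generate(args)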
Performance and Scaling
Memory Optimization
# Clear GPU memory between generations (use a distinct path to avoid clashing with /generate above)
@chute.cord(public_api_path="/generate_optimized", public_api_method="POST")
async def generate_optimized(self, args: InputArgs) -> Response:
"""Memory-optimized generation."""
try:
# Clear cache before generation
if hasattr(self, 'torch'):
self.torch.cuda.empty_cache()
# Generate music
result = await self.generate(args)
return result
finally:
# Clean up after generation
if hasattr(self, 'torch'):
self.torch.cuda.empty_cache()
Concurrent Processing
# Configure for multiple concurrent generations
chute = Chute(
username="myuser",
name="diffrhythm-music",
image=image,
node_selector=NodeSelector(
gpu_count=2, # Multiple GPUs for parallel processing
min_vram_gb_per_gpu=24
),
concurrency=4, # Handle multiple requests
)
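On the client side, this extra concurrency can be exercised with parallel requests; a minimal sketch using a thread pool (the URL and prompts are illustrative):

import concurrent.futures
import requests

API_URL = "https://myuser-my-music-gen.chutes.ai/generate"
style_prompts = [
    "lofi hip hop, mellow, vinyl crackle",
    "orchestral score, dramatic strings",
    "synthwave, retro, pulsing bass",
    "jazz trio, smooth, late night",
]

def generate_track(prompt: str) -> bytes:
    """Request one track from the /generate endpoint."""
    resp = requests.post(API_URL, json={"style_prompt": prompt}, timeout=600)
    resp.raise_for_status()
    return resp.content

# Issue up to four requests at once, matching concurrency=4 above
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    for i, audio in enumerate(pool.map(generate_track, style_prompts)):
        with open(f"track_{i}.mp3", "wb") as f:
            f.write(audio)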
Monitoring and Troubleshooting
Service Health and Logs
# Check service health
chutes chutes get my-music-gen
# View generation logs
chutes chutes logs my-music-gen --tail 50
# Monitor GPU utilization
chutes chutes metrics my-music-gen
Performance Monitoring
import time
from loguru import logger
@chute.cord(public_api_path="/generate_timed", method="POST")
async def generate_with_timing(self, args: InputArgs) -> Response:
"""Generation with performance monitoring."""
start_time = time.time()
try:
result = await self.generate(args)
generation_time = time.time() - start_time
logger.info(f"Generation completed in {generation_time:.2f} seconds")
return result
except Exception as e:
error_time = time.time() - start_time
logger.error(f"Generation failed after {error_time:.2f} seconds: {e}")
raise
Next Steps
- Custom Models: Train DiffRhythm on your own musical datasets
- Style Control: Experiment with different musical genres and moods
- Integration: Build music creation apps and platforms
- Real-time: Implement streaming music generation
For more advanced examples, see: