Machine Learning Integration
The spectrograms library integrates seamlessly with deep learning frameworks through the DLPack protocol, enabling tensor exchange with PyTorch, JAX, TensorFlow, and other ML libraries.
Overview
Key Features:
Zero-copy exchange: Direct memory sharing without data duplication
Framework support: PyTorch, JAX, and any DLPack-compatible library
Convenience wrappers: High-level .to_torch() and .to_jax() methods
Metadata preservation: Optional retention of frequency/time axes and parameters
Batching utilities: Efficient multi-spectrogram batching for training
Device flexibility: CPU and GPU support (framework-dependent)
The library provides two integration approaches:
Direct DLPack (torch.from_dlpack()): Universal, standard
Convenience modules (spectrograms.torch, spectrograms.jax): Enhanced ergonomics with metadata
DLPack Protocol
All spectrogram and chromagram objects implement the Python DLPack protocol, which enables tensor exchange between libraries.
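The protocol comes down to two methods, `__dlpack__` and `__dlpack_device__`, which the consuming library calls on the producing object. The sketch below uses a NumPy array as a stand-in producer, an assumption made so the snippet runs without the spectrograms library; a spectrogram object would play the same role.

```python
import numpy as np

# Stand-in producer: a NumPy array exposes the same two protocol methods
# that the text attributes to spectrogram and chromagram objects.
producer = np.arange(6, dtype=np.float64).reshape(2, 3)

# __dlpack_device__ reports where the data lives as (device_type, device_id).
# Device type 1 is the DLPack code for CPU.
device_type, device_id = producer.__dlpack_device__()

# The consumer calls producer.__dlpack__() internally and wraps the buffer.
consumer = np.from_dlpack(producer)

# No copy was made: both arrays refer to the same memory.
shares_memory = np.shares_memory(producer, consumer)
```

Any library with a `from_dlpack` entry point follows the same handshake, which is why a single producer implementation serves several frameworks.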
Basic Usage
import spectrograms as sg
import torch
# Compute a spectrogram
spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
# Convert to PyTorch tensor (zero-copy)
tensor = torch.from_dlpack(spec)
print(tensor.shape) # (n_mels, n_frames)
This works with any DLPack-compatible framework:
# PyTorch
import torch
tensor = torch.from_dlpack(spec)
# JAX
import jax.dlpack
array = jax.dlpack.from_dlpack(spec)
# TensorFlow (v2.15+)
import tensorflow as tf
tensor = tf.experimental.dlpack.from_dlpack(spec)
# CuPy
import cupy
array = cupy.from_dlpack(spec)
Memory Efficiency
DLPack creates a view of the spectrogram data without copying:
spec = sg.compute_mel_db_spectrogram(samples, params, mel_params, db_params)
tensor = torch.from_dlpack(spec)
# Both share the same underlying memory
spec.data[0, 0] = 999.0
print(tensor[0, 0]) # 999.0
Important: Keep the original Spectrogram object alive while using the tensor, as they share memory.
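The consequences of shared memory can be tried out without the library. This is a minimal sketch that uses a NumPy array as a stand-in for the spectrogram (an assumption, not the library's own objects): writes on one side are visible on the other until an explicit copy is taken.

```python
import numpy as np

source = np.zeros((2, 3))        # stand-in for the spectrogram's buffer
view = np.from_dlpack(source)    # zero-copy view obtained through DLPack

source[0, 0] = 999.0
seen_through_view = float(view[0, 0])   # the write is visible through the view

independent = view.copy()        # explicit copy: owns its own memory
source[0, 0] = -1.0              # later writes no longer reach the copy
```

Taking a copy costs one extra allocation but removes any dependence on the lifetime or later mutation of the original object.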
Supported Types
The DLPack protocol works with all spectrogram types:
Spectrogram (from any compute_* function)
Chromagram (from compute_chromagram)
Fft2dResult (2D FFT results)
PyTorch Integration
For enhanced ergonomics with PyTorch, import the spectrograms.torch module:
import spectrograms as sg
import spectrograms.torch # Adds .to_torch() method
Basic Conversion
spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
# Simple conversion (returns torch.Tensor)
tensor = spec.to_torch()
# GPU conversion
tensor = spec.to_torch(device='cuda')
# With specific dtype
tensor = spec.to_torch(device='cpu', dtype=torch.float32)
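Choosing a dtype affects both memory use and whether the conversion can stay zero-copy. The sketch below assumes the spectrogram data is natively float64, which this page does not state, and uses NumPy arrays as stand-ins to show the size difference.

```python
import numpy as np

# Stand-in for a (n_mels, n_frames) spectrogram, assumed to be float64
spec64 = np.zeros((128, 1000), dtype=np.float64)

# Changing the element type cannot reuse the buffer, so this allocates a copy
spec32 = spec64.astype(np.float32)

bytes64 = spec64.nbytes   # 128 * 1000 * 8 bytes
bytes32 = spec32.nbytes   # 128 * 1000 * 4 bytes
```

A conversion that changes the element type always produces new memory, so the shared-memory behaviour described for plain DLPack exchange does not carry over to a tensor cast to a different dtype.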
With Metadata
Preserve frequency/time axes and computation parameters:
result = spec.to_torch(device='cuda', with_metadata=True)
# Access tensor and metadata
result.tensor # torch.Tensor on GPU
result.frequencies # np.ndarray of frequency values
result.times # np.ndarray of time values
result.params # SpectrogramParams object
result.shape # (n_bins, n_frames)
result.db_range # (min_db, max_db) if applicable
# Move to different device
result_cpu = result.cpu()
result_gpu = result.cuda(device=0)
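One use for the preserved axes is selecting a frequency band by value rather than by bin index. The following is a sketch with stand-in arrays; it assumes `result.frequencies` is a 1-D array aligned with the first axis of the tensor, and the band limits are illustrative.

```python
import numpy as np

# Stand-ins for result.frequencies and the tensor data (hypothetical values)
frequencies = np.linspace(0.0, 8000.0, 128)              # one value per bin
values = np.random.default_rng(0).random((128, 50))      # (n_bins, n_frames)

# Boolean mask over bins that fall inside the band of interest
band = (frequencies >= 300.0) & (frequencies <= 3400.0)

# Keep only the rows (bins) inside the band; the time axis is untouched
band_values = values[band, :]
```

The same mask can index a framework tensor after converting it to that framework's boolean tensor type.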
Batching for Training
Create batched tensors from multiple spectrograms:
import spectrograms.torch as sgt
# Compute spectrograms for batch of audio samples
specs = [
    sg.compute_mel_power_spectrogram(audio, params, mel_params)
    for audio in audio_batch
]
# Stack into batch tensor (batch_size, n_mels, n_frames)
batch_tensor = sgt.batch(specs, device='cuda')
# With automatic padding for variable lengths
batch_tensor = sgt.batch(specs, device='cuda', pad=True)
# Preserve metadata for each sample
batch_tensor, metadata = sgt.batch_with_metadata(specs, device='cuda')
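Padding variable-length spectrograms to a common frame count can be pictured as follows. This is a sketch only: the helper `pad_and_stack` is hypothetical, and the pad value (zero) and padding side (right) are assumptions, since this page does not say how `pad=True` behaves.

```python
import numpy as np

def pad_and_stack(arrays, pad_value=0.0):
    """Right-pad 2-D arrays along the frame axis and stack them into a batch."""
    n_bins = arrays[0].shape[0]
    max_frames = max(a.shape[1] for a in arrays)
    batch = np.full((len(arrays), n_bins, max_frames), pad_value,
                    dtype=arrays[0].dtype)
    lengths = []
    for i, a in enumerate(arrays):
        batch[i, :, :a.shape[1]] = a   # copy real frames, leave the rest padded
        lengths.append(a.shape[1])
    return batch, lengths

# Two stand-in spectrograms with the same bin count but different lengths
specs = [np.ones((4, 3)), np.ones((4, 5))]
batch, lengths = pad_and_stack(specs)
```

Keeping the original lengths alongside the batch lets a model or loss function mask out the padded frames.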
Neural Network Pipeline
Complete example for audio classification:
import torch
import torch.nn as nn
import spectrograms as sg
import spectrograms.torch as sgt
class AudioClassifier(nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(64 * 4 * 4, n_classes)
    def forward(self, x):
        # x: (batch, n_mels, n_frames)
        x = x.unsqueeze(1)  # Add channel dim: (batch, 1, n_mels, n_frames)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(1)
        return self.fc(x)
# Training setup
model = AudioClassifier(n_classes=10).cuda()
criterion = nn.CrossEntropyLoss()
# Optimizer and learning rate are illustrative choices
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
stft = sg.StftParams(n_fft=512, hop_size=256, window=sg.WindowType.hanning)
params = sg.SpectrogramParams(stft, sample_rate=16000.0)
mel_params = sg.MelParams(n_mels=128, f_min=0.0, f_max=8000.0)
db_params = sg.LogParams(floor_db=-80.0)
# Training loop (dataloader is assumed to yield (audio_batch, labels) pairs)
for audio_batch, labels in dataloader:
    # Compute spectrograms on CPU (fast with Rust backend)
    specs = [
        sg.compute_mel_db_spectrogram(audio, params, mel_params, db_params)
        for audio in audio_batch
    ]
    # Batch and transfer to GPU
    inputs = sgt.batch(specs, device='cuda', dtype=torch.float32)
    labels = labels.cuda()
    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # Backward and optimize (clear gradients left over from the previous step first)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
JAX Integration
For JAX workflows, import the spectrograms.jax module:
import spectrograms as sg
import spectrograms.jax # Adds .to_jax() method
Basic Conversion
spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
# Simple conversion (returns jax.Array)
array = spec.to_jax()
# GPU placement
array = spec.to_jax(device='gpu')
# CPU placement
array = spec.to_jax(device='cpu')
With Metadata
result = spec.to_jax(device='gpu', with_metadata=True)
result.array # jax.Array on GPU
result.frequencies # np.ndarray
result.times # np.ndarray
result.params # SpectrogramParams
result.shape # (n_bins, n_frames)
# Move between devices
result_cpu = result.cpu()
result_gpu = result.gpu()
Batching
import spectrograms.jax as sgj
specs = [
    sg.compute_mel_power_spectrogram(audio, params, mel_params)
    for audio in audio_batch
]
# Stack into batch array
batch_array = sgj.batch(specs, device='gpu')
JAX Training Example
import jax
import jax.numpy as jnp
import optax
import spectrograms as sg
import spectrograms.jax as sgj
from flax import linen as nn
class AudioCNN(nn.Module):
    n_classes: int = 10
    @nn.compact
    def __call__(self, x):
        # x: (batch, n_mels, n_frames)
        x = x[..., jnp.newaxis]  # Add channel: (batch, n_mels, n_frames, 1)
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = jnp.mean(x, axis=(1, 2))  # Global average pooling
        return nn.Dense(self.n_classes)(x)
# Initialize model
model = AudioCNN(n_classes=10)
rng = jax.random.PRNGKey(0)
params_model = model.init(rng, jnp.ones((1, 128, 100)))
# Training step
@jax.jit
def train_step(params_model, inputs, labels):
    def loss_fn(params_model):
        logits = model.apply(params_model, inputs)
        # Expects one-hot labels; for integer class ids use
        # optax.softmax_cross_entropy_with_integer_labels instead
        return jnp.mean(optax.softmax_cross_entropy(logits, labels))
    loss, grads = jax.value_and_grad(loss_fn)(params_model)
    return loss, grads
# Compute spectrograms and convert to JAX
stft = sg.StftParams(n_fft=512, hop_size=256, window=sg.WindowType.hanning)
spec_params = sg.SpectrogramParams(stft, sample_rate=16000.0)
mel_params = sg.MelParams(n_mels=128, f_min=0.0, f_max=8000.0)
for audio_batch, labels in dataloader:
    specs = [
        sg.compute_mel_power_spectrogram(audio, spec_params, mel_params)
        for audio in audio_batch
    ]
    inputs = sgj.batch(specs, device='gpu')
    loss, grads = train_step(params_model, inputs, labels)
    # Apply grads with an optimizer (for example from optax) to update params_model
API Reference
Spectrogram Methods
All Spectrogram and Chromagram objects provide:
DLPack Protocol:
__dlpack__(*, stream=None, max_version=None, dl_device=None, copy=None)
__dlpack_device__() -> tuple[int, int]
PyTorch (when spectrograms.torch is imported):
.to_torch(device='cpu', with_metadata=False, dtype=None) -> torch.Tensor | TorchSpectrogram
JAX (when spectrograms.jax is imported):
.to_jax(device='cpu', with_metadata=False) -> jax.Array | JaxSpectrogram
Module Functions
spectrograms.torch:
batch(specs, device='cpu', dtype=None, pad=False) -> torch.Tensor
batch_with_metadata(specs, device='cpu', dtype=None, pad=False) -> tuple[torch.Tensor, list[dict]]
spectrograms.jax:
batch(specs, device='cpu', pad=False) -> jax.Array
batch_with_metadata(specs, device='cpu', pad=False) -> tuple[jax.Array, list[dict]]
Best Practices
Memory Management
When using DLPack, the tensor shares memory with the original spectrogram:
# ✓ Good: Keep spectrogram alive
spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
tensor = torch.from_dlpack(spec)
process(tensor) # spec is still in scope
# ✗ Bad: Spectrogram may be garbage collected
tensor = torch.from_dlpack(
    sg.compute_mel_power_spectrogram(samples, params, mel_params)
)
# Original data may be freed!
If you need independent memory, use copy=True:
spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
tensor = torch.from_dlpack(spec.__dlpack__(copy=True))
Performance Tips
Batch on CPU, transfer once: Compute spectrograms with CPU threads, then batch transfer to GPU
# Parallel CPU computation (GIL-free)
specs = [compute_spec(audio) for audio in batch]
# Single GPU transfer
gpu_batch = sgt.batch(specs, device='cuda')
Reuse plans for batches: Use planner for consistent parameters
planner = sg.SpectrogramPlanner()
plan = planner.mel_db_plan(params, mel_params, db_params)
for batch in dataloader:
    specs = [plan.compute(audio) for audio in batch]
    gpu_batch = sgt.batch(specs, device='cuda')
Choose appropriate dtype: Use float32 for ML to reduce memory
tensor = spec.to_torch(device='cuda', dtype=torch.float32)
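A plain list comprehension processes clips one after another. One way to spread the per-clip work across CPU threads is a thread pool, which only pays off if the compute call releases the GIL, as the "GIL-free" comment above suggests. In the sketch below, `compute_spec` is a hypothetical NumPy stand-in for a spectrograms compute call, not the library's API.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

N_FFT = 512

def compute_spec(audio):
    """Hypothetical stand-in: power spectrum of non-overlapping frames."""
    frames = audio.reshape(-1, N_FFT)                       # (n_frames, N_FFT)
    return (np.abs(np.fft.rfft(frames, axis=1)) ** 2).T     # (n_bins, n_frames)

rng = np.random.default_rng(0)
audio_batch = [rng.standard_normal(N_FFT * 8) for _ in range(4)]

# executor.map preserves input order, so specs[i] belongs to audio_batch[i]
with ThreadPoolExecutor(max_workers=4) as executor:
    specs = list(executor.map(compute_spec, audio_batch))

# Stack once on the CPU so the batch can go to the device in a single transfer
batch = np.stack(specs)   # (batch, n_bins, n_frames)
```

Because results come back in input order, labels from the dataloader stay aligned with the stacked batch.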
Device Support
Current limitations:
Spectrograms are computed on CPU only (using FFTW)
DLPack transfers to GPU happen after computation
GPU FFT computation is not yet supported
This is typically fine because:
CPU spectrogram computation is very fast (Rust + FFTW)
GPU is used for model training/inference where it’s most beneficial
Data transfer is a one-time cost per batch
See Also
Performance and Benchmarks - General performance optimization
Batch Processing - Batch processing with reusable plans