Machine Learning Integration
============================

The spectrograms library integrates seamlessly with deep learning frameworks
through the **DLPack protocol**, enabling zero-copy tensor exchange with
PyTorch, JAX, TensorFlow, and other ML libraries.

Overview
--------

**Key features:**

- **Zero-copy exchange**: Direct memory sharing without data duplication
- **Framework support**: PyTorch, JAX, and any DLPack-compatible library
- **Convenience wrappers**: High-level ``.to_torch()`` and ``.to_jax()`` methods
- **Metadata preservation**: Optional retention of frequency/time axes and parameters
- **Batching utilities**: Efficient multi-spectrogram batching for training
- **Device flexibility**: CPU and GPU support (framework-dependent)

The library provides two integration approaches:

1. **Direct DLPack** (``torch.from_dlpack()``): Universal and standard
2. **Convenience modules** (``spectrograms.torch``, ``spectrograms.jax``): Enhanced ergonomics with metadata

DLPack Protocol
---------------

All spectrogram and chromagram objects implement the Python
`DLPack protocol <https://dmlc.github.io/dlpack/latest/>`_, which enables
tensor exchange between libraries.

Basic Usage
~~~~~~~~~~~

.. code-block:: python

   import spectrograms as sg
   import torch

   # Compute a spectrogram
   spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)

   # Convert to a PyTorch tensor (zero-copy)
   tensor = torch.from_dlpack(spec)
   print(tensor.shape)  # (n_mels, n_frames)

This works with **any DLPack-compatible framework**:

.. code-block:: python

   # PyTorch
   import torch
   tensor = torch.from_dlpack(spec)

   # JAX
   import jax.dlpack
   array = jax.dlpack.from_dlpack(spec)

   # TensorFlow (v2.15+)
   import tensorflow as tf
   tensor = tf.experimental.dlpack.from_dlpack(spec)

   # CuPy
   import cupy
   array = cupy.from_dlpack(spec)
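
The protocol can also be exercised without any deep learning framework: recent
NumPy (1.23+) implements both the producer and consumer sides, which makes the
mechanics easy to inspect. A small sketch using a plain NumPy array as a
stand-in for a spectrogram object:

.. code-block:: python

   import numpy as np

   # A plain array standing in for spectrogram data.
   a = np.arange(6, dtype=np.float32).reshape(2, 3)

   # __dlpack_device__() reports (device_type, device_id);
   # 1 is kDLCPU in the DLPack device enum.
   assert a.__dlpack_device__() == (1, 0)

   # from_dlpack() consumes any object exposing __dlpack__()
   # and shares memory rather than copying.
   b = np.from_dlpack(a)
   b[0, 0] = 999.0
   assert a[0, 0] == 999.0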

Memory Efficiency
~~~~~~~~~~~~~~~~~

DLPack creates a **view** of the spectrogram data without copying:

.. code-block:: python

   spec = sg.compute_mel_db_spectrogram(samples, params, mel_params, db_params)
   tensor = torch.from_dlpack(spec)

   # Both share the same underlying memory
   spec.data[0, 0] = 999.0
   print(tensor[0, 0])  # 999.0

**Important**: Keep the original ``Spectrogram`` object alive while using the
tensor, as they share memory.

Supported Types
~~~~~~~~~~~~~~~

The DLPack protocol works with all spectrogram types:

- ``Spectrogram`` (from any ``compute_*`` function)
- ``Chromagram`` (from ``compute_chromagram``)
- ``Fft2dResult`` (2D FFT results)

PyTorch Integration
-------------------

For enhanced ergonomics with PyTorch, import the ``spectrograms.torch`` module:

.. code-block:: python

   import spectrograms as sg
   import spectrograms.torch  # Adds the .to_torch() method

Basic Conversion
~~~~~~~~~~~~~~~~

.. code-block:: python

   spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)

   # Simple conversion (returns torch.Tensor)
   tensor = spec.to_torch()

   # GPU conversion
   tensor = spec.to_torch(device='cuda')

   # With a specific dtype
   tensor = spec.to_torch(device='cpu', dtype=torch.float32)

With Metadata
~~~~~~~~~~~~~

Preserve frequency/time axes and computation parameters:

.. code-block:: python

   result = spec.to_torch(device='cuda', with_metadata=True)

   # Access the tensor and metadata
   result.tensor       # torch.Tensor on GPU
   result.frequencies  # np.ndarray of frequency values
   result.times        # np.ndarray of time values
   result.params       # SpectrogramParams object
   result.shape        # (n_bins, n_frames)
   result.db_range     # (min_db, max_db) if applicable

   # Move to a different device
   result_cpu = result.cpu()
   result_gpu = result.cuda(device=0)
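
The preserved axes can be sanity-checked against the parameters. Assuming
conventional STFT indexing (frame ``k`` begins at sample ``k * hop_size``; an
assumption about this library's layout rather than a documented guarantee),
the time axis follows directly. For linear-frequency spectrograms the bin
frequencies do too; mel results carry band centers instead. A sketch with the
parameter values used throughout this guide:

.. code-block:: python

   import numpy as np

   # Illustrative values matching the examples in this guide.
   sample_rate, n_fft, hop_size = 16000.0, 512, 256
   n_frames = 100

   times = np.arange(n_frames) * hop_size / sample_rate       # seconds
   freqs = np.arange(n_fft // 2 + 1) * sample_rate / n_fft    # Hz (linear bins)

   assert np.isclose(times[1] - times[0], hop_size / sample_rate)
   assert freqs[-1] == sample_rate / 2   # last rfft bin is Nyquist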

Batching for Training
~~~~~~~~~~~~~~~~~~~~~

Create batched tensors from multiple spectrograms:

.. code-block:: python

   import spectrograms.torch as sgt

   # Compute spectrograms for a batch of audio samples
   specs = [
       sg.compute_mel_power_spectrogram(audio, params, mel_params)
       for audio in audio_batch
   ]

   # Stack into a batch tensor (batch_size, n_mels, n_frames)
   batch_tensor = sgt.batch(specs, device='cuda')

   # With automatic padding for variable lengths
   batch_tensor = sgt.batch(specs, device='cuda', pad=True)

   # Preserve metadata for each sample
   batch_tensor, metadata = sgt.batch_with_metadata(specs, device='cuda')

Neural Network Pipeline
~~~~~~~~~~~~~~~~~~~~~~~

Complete example for audio classification:

.. code-block:: python

   import torch
   import torch.nn as nn
   import spectrograms as sg
   import spectrograms.torch as sgt

   class AudioClassifier(nn.Module):
       def __init__(self, n_classes=10):
           super().__init__()
           self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
           self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
           self.pool = nn.AdaptiveAvgPool2d((4, 4))
           self.fc = nn.Linear(64 * 4 * 4, n_classes)

       def forward(self, x):
           # x: (batch, n_mels, n_frames)
           x = x.unsqueeze(1)  # Add channel dim: (batch, 1, n_mels, n_frames)
           x = torch.relu(self.conv1(x))
           x = torch.relu(self.conv2(x))
           x = self.pool(x)
           x = x.flatten(1)
           return self.fc(x)

   # Training setup
   model = AudioClassifier(n_classes=10).cuda()
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
   stft = sg.StftParams(n_fft=512, hop_size=256, window=sg.WindowType.hanning)
   params = sg.SpectrogramParams(stft, sample_rate=16000.0)
   mel_params = sg.MelParams(n_mels=128, f_min=0.0, f_max=8000.0)
   db_params = sg.LogParams(floor_db=-80.0)

   # Training loop
   for audio_batch, labels in dataloader:
       # Compute spectrograms on CPU (fast with the Rust backend)
       specs = [
           sg.compute_mel_db_spectrogram(audio, params, mel_params, db_params)
           for audio in audio_batch
       ]

       # Batch and transfer to GPU
       inputs = sgt.batch(specs, device='cuda', dtype=torch.float32)
       labels = labels.cuda()

       # Forward pass
       outputs = model(inputs)
       loss = nn.CrossEntropyLoss()(outputs, labels)

       # Backward and optimize
       optimizer.zero_grad()
       loss.backward()
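       # Optional extra step (not part of the library example): bound the
       # gradient norm before the update; max_norm=1.0 is an illustrative
       # value, not a recommendation from this library.
       torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)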
       optimizer.step()

JAX Integration
---------------

For JAX workflows, import the ``spectrograms.jax`` module:

.. code-block:: python

   import spectrograms as sg
   import spectrograms.jax  # Adds the .to_jax() method

Basic Conversion
~~~~~~~~~~~~~~~~

.. code-block:: python

   spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)

   # Simple conversion (returns jax.Array)
   array = spec.to_jax()

   # GPU placement
   array = spec.to_jax(device='gpu')

   # CPU placement
   array = spec.to_jax(device='cpu')

With Metadata
~~~~~~~~~~~~~

.. code-block:: python

   result = spec.to_jax(device='gpu', with_metadata=True)

   result.array        # jax.Array on GPU
   result.frequencies  # np.ndarray
   result.times        # np.ndarray
   result.params       # SpectrogramParams
   result.shape        # (n_bins, n_frames)

   # Move between devices
   result_cpu = result.cpu()
   result_gpu = result.gpu()

Batching
~~~~~~~~

.. code-block:: python

   import spectrograms.jax as sgj

   specs = [
       sg.compute_mel_power_spectrogram(audio, params, mel_params)
       for audio in audio_batch
   ]

   # Stack into a batch array
   batch_array = sgj.batch(specs, device='gpu')
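
When frame counts differ across a batch, a batching helper has to pad before
stacking. A framework-free sketch of the idea (the shapes and the zero pad
value are illustrative; the library's ``pad=True`` may differ in such details):

.. code-block:: python

   import numpy as np

   # Hypothetical stand-ins for spectrograms with different frame counts.
   specs = [np.ones((128, n)) for n in (90, 100, 75)]

   # Right-pad every array to the longest frame count, then stack along
   # a new leading batch axis.
   max_frames = max(s.shape[1] for s in specs)
   padded = [
       np.pad(s, ((0, 0), (0, max_frames - s.shape[1])))  # zero-pads
       for s in specs
   ]
   batch = np.stack(padded)  # (batch, n_mels, n_frames)

   assert batch.shape == (3, 128, 100)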

JAX Training Example
~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import optax
   import spectrograms as sg
   import spectrograms.jax as sgj
   from flax import linen as nn

   class AudioCNN(nn.Module):
       n_classes: int = 10

       @nn.compact
       def __call__(self, x):
           # x: (batch, n_mels, n_frames)
           x = x[..., jnp.newaxis]  # Add channel: (batch, n_mels, n_frames, 1)
           x = nn.Conv(features=32, kernel_size=(3, 3))(x)
           x = nn.relu(x)
           x = nn.Conv(features=64, kernel_size=(3, 3))(x)
           x = nn.relu(x)
           x = jnp.mean(x, axis=(1, 2))  # Global average pooling
           return nn.Dense(self.n_classes)(x)

   # Initialize the model
   model = AudioCNN(n_classes=10)
   rng = jax.random.PRNGKey(0)
   params_model = model.init(rng, jnp.ones((1, 128, 100)))

   # Training step
   @jax.jit
   def train_step(params_model, inputs, labels):
       def loss_fn(params_model):
           logits = model.apply(params_model, inputs)
           return jnp.mean(
               optax.softmax_cross_entropy_with_integer_labels(logits, labels)
           )
       loss, grads = jax.value_and_grad(loss_fn)(params_model)
       return loss, grads

   # Compute spectrograms and convert to JAX
   stft = sg.StftParams(n_fft=512, hop_size=256, window=sg.WindowType.hanning)
   spec_params = sg.SpectrogramParams(stft, sample_rate=16000.0)
   mel_params = sg.MelParams(n_mels=128, f_min=0.0, f_max=8000.0)

   for audio_batch, labels in dataloader:
       specs = [
           sg.compute_mel_power_spectrogram(audio, spec_params, mel_params)
           for audio in audio_batch
       ]
       inputs = sgj.batch(specs, device='gpu')
       loss, grads = train_step(params_model, inputs, labels)
       # Apply a plain SGD update (any optax optimizer works equally well)
       params_model = jax.tree_util.tree_map(
           lambda p, g: p - 1e-3 * g, params_model, grads
       )

API Reference
-------------

Spectrogram Methods
~~~~~~~~~~~~~~~~~~~

All ``Spectrogram`` and ``Chromagram`` objects provide:

**DLPack Protocol:**

.. code-block:: python

   __dlpack__(*, stream=None, max_version=None, dl_device=None, copy=None)
   __dlpack_device__() -> tuple[int, int]

**PyTorch** (when ``spectrograms.torch`` is imported):

.. code-block:: python

   .to_torch(device='cpu', with_metadata=False, dtype=None) -> torch.Tensor | TorchSpectrogram

**JAX** (when ``spectrograms.jax`` is imported):

.. code-block:: python

   .to_jax(device='cpu', with_metadata=False) -> jax.Array | JaxSpectrogram

Module Functions
~~~~~~~~~~~~~~~~

**spectrograms.torch:**

- ``batch(specs, device='cpu', dtype=None, pad=False) -> torch.Tensor``
- ``batch_with_metadata(specs, device='cpu', dtype=None, pad=False) -> tuple[torch.Tensor, list[dict]]``

**spectrograms.jax:**

- ``batch(specs, device='cpu', pad=False) -> jax.Array``
- ``batch_with_metadata(specs, device='cpu', pad=False) -> tuple[jax.Array, list[dict]]``

Best Practices
--------------

Memory Management
~~~~~~~~~~~~~~~~~

When using DLPack, the tensor shares memory with the original spectrogram:

.. code-block:: python

   # ✓ Good: Keep the spectrogram alive
   spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
   tensor = torch.from_dlpack(spec)
   process(tensor)  # spec is still in scope

   # ✗ Bad: The spectrogram may be garbage collected
   tensor = torch.from_dlpack(
       sg.compute_mel_power_spectrogram(samples, params, mel_params)
   )  # Original data may be freed!

If you need independent memory, use ``copy=True``:

.. code-block:: python

   spec = sg.compute_mel_power_spectrogram(samples, params, mel_params)
   tensor = torch.from_dlpack(spec.__dlpack__(copy=True))

Performance Tips
~~~~~~~~~~~~~~~~

1. **Batch on CPU, transfer once**: Compute spectrograms with CPU threads, then
   transfer the whole batch to the GPU in one step.

   .. code-block:: python

      # Parallel CPU computation (GIL-free)
      specs = [compute_spec(audio) for audio in batch]

      # Single GPU transfer
      gpu_batch = sgt.batch(specs, device='cuda')

2. **Reuse plans for batches**: Use a planner for consistent parameters.

   .. code-block:: python

      planner = sg.SpectrogramPlanner()
      plan = planner.mel_db_plan(params, mel_params, db_params)

      for batch in dataloader:
          specs = [plan.compute(audio) for audio in batch]
          gpu_batch = sgt.batch(specs, device='cuda')

3. **Choose an appropriate dtype**: Use ``float32`` for ML to reduce memory.

   .. code-block:: python

      tensor = spec.to_torch(device='cuda', dtype=torch.float32)

Device Support
~~~~~~~~~~~~~~

**Current limitations:**

- Spectrograms are computed on **CPU only** (using FFTW)
- DLPack transfers to the GPU happen **after** computation
- GPU FFT computation is not yet supported

This is typically fine because:

- CPU spectrogram computation is very fast (Rust + FFTW)
- The GPU is used for model training/inference, where it is most beneficial
- Data transfer is a one-time cost per batch

See Also
--------

- :doc:`performance` - General performance optimization
- :doc:`planner_guide` - Batch processing with reusable plans
- `DLPack Specification <https://dmlc.github.io/dlpack/latest/>`_
- `PyTorch DLPack docs <https://pytorch.org/docs/stable/dlpack.html>`_
- `JAX DLPack docs <https://jax.readthedocs.io/en/latest/jax.dlpack.html>`_