sqrtspace-python/src/sqrtspace_spacetime/ml/checkpointing.py

"""
Gradient checkpointing utilities for memory-efficient training.
"""
import math
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
# Framework imports
try:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
try:
import tensorflow as tf
HAS_TF = True
except ImportError:
HAS_TF = False
class CheckpointStrategy(Enum):
"""Checkpointing strategies."""
SQRT_N = "sqrt_n" # Checkpoint every √n layers
UNIFORM = "uniform" # Uniform intervals
MEMORY_BASED = "memory" # Based on memory usage
SELECTIVE = "selective" # Only expensive layers
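
# Intuition for SQRT_N (illustrative numbers): with n = 100 leaf layers, a
# checkpoint is placed every int(sqrt(100)) = 10 layers, so only ~10 stored
# activations plus the ~10 activations of the segment currently being
# recomputed are live at once, instead of all 100. The price is roughly one
# extra forward pass worth of recomputation during backward.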


class GradientCheckpointer:
    """
    Gradient checkpointing for memory-efficient training.

    Implements Williams' √n strategy for an optimal space-time tradeoff.
    """

    def __init__(self, strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N):
        self.strategy = strategy

    def apply_checkpointing(self,
                            model: Any,
                            checkpoint_layers: Optional[List[str]] = None) -> Any:
        """
        Apply gradient checkpointing to a model.

        Args:
            model: Neural network model
            checkpoint_layers: Specific layer names to checkpoint (None for automatic selection)

        Returns:
            Model with checkpointing applied
        """
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._apply_torch_checkpointing(model, checkpoint_layers)
        elif HAS_TF:
            return self._apply_tf_checkpointing(model, checkpoint_layers)
        else:
            print("Warning: no supported framework found for checkpointing")
            return model

    def _apply_torch_checkpointing(self,
                                   model: nn.Module,
                                   checkpoint_layers: Optional[List[str]] = None) -> nn.Module:
        """Apply checkpointing to a PyTorch model."""
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_torch(model)

        # Wrap the forward methods of the selected layers
        for name, module in model.named_modules():
            if name in checkpoint_layers:
                self._wrap_module_torch(module)

        return model

    def _wrap_module_torch(self, module: nn.Module) -> None:
        """Wrap a PyTorch module's forward pass with gradient checkpointing."""
        original_forward = module.forward

        def checkpointed_forward(*args, **kwargs):
            # Use PyTorch's checkpoint function during training only;
            # inference runs the original forward pass unchanged
            if module.training:
                return checkpoint(original_forward, *args, **kwargs)
            else:
                return original_forward(*args, **kwargs)

        module.forward = checkpointed_forward

    def _apply_tf_checkpointing(self,
                                model: Any,
                                checkpoint_layers: Optional[List[str]] = None) -> Any:
        """Apply checkpointing to a TensorFlow model."""
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_tf(model)

        # TensorFlow implementation
        # Note: TF2 uses a different checkpointing mechanism (tf.recompute_grad);
        # layers are only selected and reported here, the model is returned unchanged
        print(f"TensorFlow checkpointing selected {len(checkpoint_layers)} layers")
        return model

    def _select_checkpoint_layers_torch(self, model: nn.Module) -> List[str]:
        """Select layers to checkpoint for a PyTorch model."""
        layers = []

        # Get all leaf modules (modules without children)
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:
                layers.append((name, module))

        if self.strategy == CheckpointStrategy.SQRT_N:
            # Select √n evenly spaced layers
            n = len(layers)
            if n == 0:
                return []

            interval = max(1, int(math.sqrt(n)))
            selected = []
            for i in range(0, n, interval):
                name, module = layers[i]
                if self._can_checkpoint_module(module):
                    selected.append(name)
            return selected

        elif self.strategy == CheckpointStrategy.MEMORY_BASED:
            # Select layers with large activation memory
            memory_layers = []
            for name, module in layers:
                memory = self._estimate_module_memory(module)
                memory_layers.append((name, memory))

            # Sort by memory and select the top √n
            memory_layers.sort(key=lambda x: x[1], reverse=True)
            n_checkpoint = max(1, int(math.sqrt(len(memory_layers))))
            return [name for name, _ in memory_layers[:n_checkpoint]]

        else:
            # Default: checkpoint all eligible layers
            return [name for name, module in layers if self._can_checkpoint_module(module)]

    def _select_checkpoint_layers_tf(self, model: Any) -> List[str]:
        """Select layers to checkpoint for a TensorFlow model."""
        if not hasattr(model, 'layers'):
            return []

        layers = [(layer.name, layer) for layer in model.layers]

        if self.strategy == CheckpointStrategy.SQRT_N:
            n = len(layers)
            interval = max(1, int(math.sqrt(n)))
            selected = []
            for i in range(0, n, interval):
                name, layer = layers[i]
                selected.append(name)
            return selected

        return [name for name, _ in layers]

    def _can_checkpoint_module(self, module: Any) -> bool:
        """Check if a module can be safely checkpointed."""
        if HAS_TORCH:
            # Avoid checkpointing modules with randomness
            no_checkpoint = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
            return not isinstance(module, no_checkpoint)
        return True

    def _estimate_module_memory(self, module: Any) -> int:
        """Estimate the activation memory of a module, in bytes."""
        if HAS_TORCH and isinstance(module, nn.Module):
            # Estimate based on output size (batch size ignored)
            if isinstance(module, nn.Linear):
                return module.out_features * 4  # 4 bytes per FP32 value
            elif isinstance(module, nn.Conv2d):
                # Rough estimate assuming a 100x100 output feature map
                return module.out_channels * 100 * 100 * 4
            else:
                # Default estimate: proportional to parameter count
                params = sum(p.numel() for p in module.parameters())
                return params * 4
        return 0

    @staticmethod
    def create_checkpoint_segments(model: Any,
                                   n_segments: Optional[int] = None) -> List[List[str]]:
        """
        Create checkpoint segments using the √n strategy.

        Args:
            model: Neural network model
            n_segments: Number of segments (None for √n)

        Returns:
            List of layer-name segments
        """
        # Get all leaf layers
        if HAS_TORCH and isinstance(model, nn.Module):
            all_layers = [name for name, module in model.named_modules()
                          if len(list(module.children())) == 0]
        elif HAS_TF and hasattr(model, 'layers'):
            all_layers = [layer.name for layer in model.layers]
        else:
            return []

        n = len(all_layers)
        if n == 0:
            return []

        # Use √n segments by default
        if n_segments is None:
            n_segments = max(1, int(math.sqrt(n)))

        # Create segments of roughly equal size
        segment_size = max(1, n // n_segments)
        segments = []
        for i in range(0, n, segment_size):
            segment = all_layers[i:i + segment_size]
            if segment:
                segments.append(segment)

        return segments
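

# Example usage (an illustrative sketch, not part of the library API): the toy
# nn.Sequential model and layer sizes below are made up for demonstration;
# any nn.Module works the same way.
def _example_gradient_checkpointer():  # pragma: no cover
    if not HAS_TORCH:
        return

    # A small stack of linear layers; 16 leaf modules in total
    model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(16)])
    model.train()

    # SQRT_N selects and wraps every int(sqrt(16)) = 4th layer
    checkpointer = GradientCheckpointer(CheckpointStrategy.SQRT_N)
    model = checkpointer.apply_checkpointing(model)

    # The same model split into sqrt(n) segments of layer names
    segments = GradientCheckpointer.create_checkpoint_segments(model)
    print(f"{len(segments)} segments: {segments}")

    # Training step: wrapped layers recompute their activations during backward
    x = torch.randn(8, 256, requires_grad=True)
    model(x).sum().backward()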


def checkpoint_sequential(modules: List[Any],
                          input: Any,
                          segments: Optional[int] = None) -> Any:
    """
    Checkpoint a sequential model using √n segments.

    Args:
        modules: List of modules to execute sequentially
        input: Input tensor
        segments: Number of checkpoint segments (None for √n)

    Returns:
        Output tensor
    """
    if not HAS_TORCH:
        # Fallback to normal execution
        x = input
        for module in modules:
            x = module(x)
        return x

    n = len(modules)
    if n == 0:
        return input

    # Use √n segments by default
    if segments is None:
        segments = max(1, int(math.sqrt(n)))
    segment_size = max(1, n // segments)

    # Execute with checkpointing
    x = input
    for i in range(0, n, segment_size):
        segment = modules[i:i + segment_size]

        if len(segment) == 1:
            # Single module
            if segment[0].training:
                x = checkpoint(segment[0], x)
            else:
                x = segment[0](x)
        else:
            # Multiple modules - run them as one checkpointed segment
            def run_segment(inp, *segment_modules):
                for module in segment_modules:
                    inp = module(inp)
                return inp

            if segment[0].training:
                x = checkpoint(run_segment, x, *segment)
            else:
                x = run_segment(x, *segment)

    return x
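

# Example usage (an illustrative sketch, not part of the library API): nine toy
# linear layers run through checkpoint_sequential, which defaults here to
# sqrt(9) = 3 segments of 3 layers each.
def _example_checkpoint_sequential():  # pragma: no cover
    if not HAS_TORCH:
        return

    layers = [nn.Linear(64, 64) for _ in range(9)]
    for layer in layers:
        layer.train()

    x = torch.randn(32, 64, requires_grad=True)
    out = checkpoint_sequential(layers, x)  # 3 checkpointed segments
    out.sum().backward()                    # segments are recomputed on backward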