Initial
This commit is contained in:
286
src/sqrtspace_spacetime/ml/checkpointing.py
Normal file
286
src/sqrtspace_spacetime/ml/checkpointing.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Gradient checkpointing utilities for memory-efficient training.
|
||||
"""
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
# Framework imports
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
HAS_TF = True
|
||||
except ImportError:
|
||||
HAS_TF = False
|
||||
|
||||
|
||||
class CheckpointStrategy(Enum):
    """Policies for choosing which layers get gradient checkpointing.

    Each member's value is the short string identifier of the policy.
    """

    # Checkpoint every √n layers.
    SQRT_N = "sqrt_n"
    # Uniform intervals.
    UNIFORM = "uniform"
    # Based on memory usage.
    MEMORY_BASED = "memory"
    # Only expensive layers.
    SELECTIVE = "selective"
|
||||
|
||||
|
||||
class GradientCheckpointer:
    """
    Gradient checkpointing for memory-efficient training.

    Implements Williams' √n strategy for optimal space-time tradeoff.

    The PyTorch path replaces the ``forward`` of selected modules with a
    wrapper that routes through ``torch.utils.checkpoint.checkpoint`` while
    the module is in training mode. The TensorFlow path currently only
    selects layers and reports how many were chosen; it does not modify the
    model.
    """

    # Name of the marker attribute set on modules whose ``forward`` has
    # already been wrapped, so repeated calls do not nest checkpoint wrappers.
    _WRAPPED_FLAG = "_gradient_checkpoint_wrapped"

    def __init__(self, strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N):
        """Store the layer-selection strategy used when no layers are given."""
        self.strategy = strategy

    def apply_checkpointing(self,
                            model: Any,
                            checkpoint_layers: Optional[List[str]] = None) -> Any:
        """
        Apply gradient checkpointing to model.

        Args:
            model: Neural network model
            checkpoint_layers: Specific layers to checkpoint (None for auto)

        Returns:
            Model with checkpointing applied
        """
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._apply_torch_checkpointing(model, checkpoint_layers)
        if HAS_TF:
            return self._apply_tf_checkpointing(model, checkpoint_layers)

        print("Warning: No supported framework found for checkpointing")
        return model

    # NOTE: ``nn.Module`` annotations below are written as strings so that
    # defining this class does not raise NameError when torch is not installed.
    def _apply_torch_checkpointing(self,
                                   model: "nn.Module",
                                   checkpoint_layers: Optional[List[str]] = None) -> "nn.Module":
        """Apply checkpointing to PyTorch model.

        Args:
            model: Module whose sub-modules are wrapped in place.
            checkpoint_layers: Names (as produced by ``named_modules``) of the
                sub-modules to wrap; selected via the strategy when None.

        Returns:
            The same ``model`` object, with selected sub-modules wrapped.
        """
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_torch(model)

        # Build the lookup once; membership is tested for every sub-module.
        wanted = set(checkpoint_layers)

        # Wrap forward methods of selected layers
        for name, module in model.named_modules():
            if name in wanted:
                self._wrap_module_torch(module)

        return model

    def _wrap_module_torch(self, module: "nn.Module") -> None:
        """Wrap PyTorch module with gradient checkpointing.

        The wrapper only checkpoints while ``module.training`` is true;
        otherwise it calls the original ``forward`` directly. Wrapping an
        already-wrapped module is a no-op.
        """
        if getattr(module, self._WRAPPED_FLAG, False):
            return

        original_forward = module.forward

        def checkpointed_forward(*args, **kwargs):
            if module.training:
                # Non-reentrant mode accepts keyword arguments and still
                # produces parameter gradients when none of the inputs
                # require grad (e.g. the first layer of a network).
                return checkpoint(original_forward, *args,
                                  use_reentrant=False, **kwargs)
            return original_forward(*args, **kwargs)

        module.forward = checkpointed_forward
        setattr(module, self._WRAPPED_FLAG, True)

    def _apply_tf_checkpointing(self,
                                model: Any,
                                checkpoint_layers: Optional[List[str]] = None) -> Any:
        """Apply checkpointing to TensorFlow model.

        Currently only selects layers and prints how many were selected; the
        model is returned unmodified.
        """
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_tf(model)

        # TensorFlow implementation
        # Note: TF2 has different checkpointing mechanism
        print(f"TensorFlow checkpointing selected {len(checkpoint_layers)} layers")

        return model

    @staticmethod
    def _is_leaf(module: Any) -> bool:
        """Return True when ``module`` has no child modules."""
        return next(iter(module.children()), None) is None

    def _select_checkpoint_layers_torch(self, model: "nn.Module") -> List[str]:
        """Select layers to checkpoint for PyTorch model.

        Only leaf modules (modules without children) are considered.

        Returns:
            Names of the selected modules, as produced by ``named_modules``.
        """
        layers = [(name, module) for name, module in model.named_modules()
                  if self._is_leaf(module)]

        if self.strategy == CheckpointStrategy.SQRT_N:
            # Select √n evenly spaced layers
            n = len(layers)
            if n == 0:
                return []

            interval = max(1, int(math.sqrt(n)))
            return [name for name, module in layers[::interval]
                    if self._can_checkpoint_module(module)]

        if self.strategy == CheckpointStrategy.MEMORY_BASED:
            # Select layers with large activation memory, skipping modules
            # the other strategies also refuse to checkpoint.
            memory_layers = [(name, self._estimate_module_memory(module))
                             for name, module in layers
                             if self._can_checkpoint_module(module)]

            # Sort by memory (stable, so ties keep model order) and select
            # the top √n, where n is the number of leaf layers.
            memory_layers.sort(key=lambda item: item[1], reverse=True)
            n_checkpoint = max(1, int(math.sqrt(len(layers))))

            return [name for name, _ in memory_layers[:n_checkpoint]]

        # Default (any other strategy): checkpoint all eligible layers
        return [name for name, module in layers
                if self._can_checkpoint_module(module)]

    def _select_checkpoint_layers_tf(self, model: Any) -> List[str]:
        """Select layers to checkpoint for TensorFlow model.

        Returns an empty list when ``model`` has no ``layers`` attribute.
        """
        if not hasattr(model, 'layers'):
            return []

        names = [layer.name for layer in model.layers]

        if self.strategy == CheckpointStrategy.SQRT_N:
            interval = max(1, int(math.sqrt(len(names))))
            return names[::interval]

        return names

    def _can_checkpoint_module(self, module: Any) -> bool:
        """Check if module can be safely checkpointed."""
        if HAS_TORCH:
            # Avoid checkpointing modules with randomness
            no_checkpoint = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
            return not isinstance(module, no_checkpoint)
        return True

    def _estimate_module_memory(self, module: Any) -> int:
        """Estimate memory usage of module activations.

        Returns:
            A rough byte count used only to rank modules against each other;
            0 when torch is unavailable or ``module`` is not an ``nn.Module``.
        """
        if HAS_TORCH and isinstance(module, nn.Module):
            # Estimate based on output size
            if isinstance(module, nn.Linear):
                return module.out_features * 4  # FP32
            if isinstance(module, nn.Conv2d):
                # Rough estimate; presumably assumes a 100x100 FP32 output
                # feature map -- TODO confirm the intended spatial size.
                return module.out_channels * 100 * 100 * 4
            # Default estimate: 4 bytes per parameter element
            params = sum(p.numel() for p in module.parameters())
            return params * 4
        return 0

    @staticmethod
    def create_checkpoint_segments(model: Any,
                                   n_segments: Optional[int] = None) -> List[List[str]]:
        """
        Create checkpoint segments using √n strategy.

        Layers are split into contiguous segments whose sizes differ by at
        most one, so exactly ``n_segments`` segments are returned (clamped to
        the range 1..number of layers).

        Args:
            model: Neural network model
            n_segments: Number of segments (None for √n)

        Returns:
            List of layer name segments
        """
        # Get all layers
        if HAS_TORCH and isinstance(model, nn.Module):
            all_layers = [name for name, module in model.named_modules()
                          if GradientCheckpointer._is_leaf(module)]
        elif HAS_TF and hasattr(model, 'layers'):
            all_layers = [layer.name for layer in model.layers]
        else:
            return []

        n = len(all_layers)
        if n == 0:
            return []

        # Use √n segments by default
        if n_segments is None:
            n_segments = max(1, int(math.sqrt(n)))

        # Clamp so zero/negative requests cannot divide by zero and no
        # segment is ever empty.
        n_segments = min(max(1, n_segments), n)

        # The first `extra` segments carry one additional layer.
        base, extra = divmod(n, n_segments)
        segments = []
        start = 0
        for index in range(n_segments):
            stop = start + base + (1 if index < extra else 0)
            segments.append(all_layers[start:stop])
            start = stop

        return segments
|
||||
|
||||
|
||||
def checkpoint_sequential(modules: List[Any],
                          input: Any,
                          segments: Optional[int] = None) -> Any:
    """
    Checkpoint a sequential model using √n segments.

    The modules are split into contiguous segments whose sizes differ by at
    most one. A segment is run under ``torch.utils.checkpoint.checkpoint``
    when at least one of its modules reports ``training`` as true; otherwise
    it is executed directly. Without torch, all modules are simply executed
    in order.

    Args:
        modules: List of modules to execute sequentially
        input: Input tensor
        segments: Number of checkpoint segments (None for √n); values are
            clamped to the range 1..len(modules)

    Returns:
        Output tensor
    """
    def run_segment(x, *segment_modules):
        # Modules arrive as arguments rather than via closure so each
        # checkpoint call recomputes exactly the segment it was given.
        for module in segment_modules:
            x = module(x)
        return x

    if not HAS_TORCH:
        # Fallback to normal execution
        return run_segment(input, *modules)

    n = len(modules)
    if n == 0:
        return input

    # Use √n segments
    if segments is None:
        segments = max(1, int(math.sqrt(n)))

    # Clamp so zero/negative requests cannot divide by zero and no segment
    # is ever empty.
    segments = min(max(1, segments), n)

    # The first `extra` segments carry one additional module.
    base, extra = divmod(n, segments)

    # Execute with checkpointing
    x = input
    start = 0
    for index in range(segments):
        stop = start + base + (1 if index < extra else 0)
        segment = modules[start:stop]
        start = stop

        # Plain callables have no `training` attribute; treat them as eval.
        if any(getattr(module, "training", False) for module in segment):
            # Non-reentrant mode still yields parameter gradients when the
            # segment input itself does not require grad.
            x = checkpoint(run_segment, x, *segment, use_reentrant=False)
        else:
            x = run_segment(x, *segment)

    return x
|
||||
Reference in New Issue
Block a user