# sqrtspace-python/src/sqrtspace_spacetime/checkpoint/decorators.py
# 2025-07-20 04:11:04 -04:00
#
# 295 lines
# 10 KiB
# Python

"""
Decorators for automatic checkpointing.
"""
import functools
import inspect
from typing import Any, Callable, List, Optional, Union
from sqrtspace_spacetime.checkpoint.manager import (
CheckpointManager,
CheckpointConfig,
CheckpointStrategy
)
def auto_checkpoint(
    total_iterations: Optional[int] = None,
    strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE,
    checkpoint_vars: Optional[List[str]] = None,
    checkpoint_dir: str = ".checkpoints",
    verbose: bool = True
) -> Callable:
    """
    Decorator to automatically checkpoint long-running functions.

    Generator functions are handed to ``_checkpoint_generator`` together with
    a fresh ``CheckpointManager``. Regular functions are called once when
    ``total_iterations`` is ``None`` or ``0``, otherwise once per iteration
    (``total_iterations`` calls in total, the last result being returned);
    before each call the manager is asked whether a checkpoint is due.

    The decorated function accepts one extra keyword argument at call time,
    ``resume_checkpoint``. It is removed from ``kwargs`` and, when truthy,
    passed to ``manager.load``. If the caller also supplied an
    ``update_state`` keyword, it is invoked with the loaded state; note that
    ``update_state`` itself stays in ``kwargs`` and is forwarded to the
    wrapped function.

    After each call the manager used is available as
    ``wrapper.checkpoint_manager`` (``None`` until the first call), and the
    configuration used to build managers is mirrored on
    ``wrapper.checkpoint_config``.

    Args:
        total_iterations: Total iterations (for √n strategy)
        strategy: Checkpointing strategy
        checkpoint_vars: Variables to checkpoint (None for auto-detect)
        checkpoint_dir: Directory for checkpoints
        verbose: Print checkpoint info

    Returns:
        A decorator that wraps the target function with checkpointing.

    Example:
        @auto_checkpoint(total_iterations=1000000)
        def process_data(data):
            processed = 0
            for i, item in enumerate(data):
                # Process item
                processed += 1
                checkpoint_state = {'i': i, 'processed': processed}
                yield checkpoint_state
    """
    def _make_config() -> CheckpointConfig:
        # Single source of truth so the config advertised on the wrapper
        # matches the one every per-call manager is built with.
        return CheckpointConfig(
            strategy=strategy,
            checkpoint_dir=checkpoint_dir,
            verbose=verbose
        )

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A new manager is created for every invocation.
            manager = CheckpointManager(config=_make_config())
            wrapper.checkpoint_manager = manager

            if total_iterations:
                manager.set_total_iterations(total_iterations)

            # Check if resuming from checkpoint
            resume_checkpoint = kwargs.pop('resume_checkpoint', None)
            if resume_checkpoint:
                state, metadata = manager.load(resume_checkpoint)
                if verbose:
                    print(f"Resuming from checkpoint at iteration {metadata.iteration}")
                # Let the caller restore its own state from the checkpoint.
                if 'update_state' in kwargs:
                    kwargs['update_state'](state)

            # Generator functions are checkpointed lazily, per yielded item.
            if inspect.isgeneratorfunction(func):
                return _checkpoint_generator(func, manager, checkpoint_vars,
                                             *args, **kwargs)

            # Regular functions: ``total_iterations or 1`` yields exactly one
            # pass when no iteration count was given.
            result = None
            for i in range(total_iterations or 1):
                if manager.should_checkpoint(i):
                    # Prefer a state provider attached to the function.
                    state_provider = getattr(func, 'get_checkpoint_state', None)
                    if state_provider is not None:
                        state = state_provider()
                    else:
                        state = {'iteration': i, 'args': args, 'kwargs': kwargs}
                    manager.save(state)

                result = func(*args, **kwargs)

            return result

        # Store checkpoint info on function
        wrapper.checkpoint_manager = None
        wrapper.checkpoint_config = _make_config()
        return wrapper

    return decorator
def checkpoint_method(
    checkpoint_attrs: Optional[List[str]] = None,
    strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE
) -> Callable:
    """
    Decorator that adds checkpointing to an instance method.

    The first decorated call on an instance creates a ``CheckpointManager``
    and stores it on the instance as ``_checkpoint_manager``; later calls
    reuse whatever is stored there. Generator methods are delegated to
    ``_checkpoint_method_generator``; for plain methods the manager is
    consulted once after the method returns.

    Args:
        checkpoint_attrs: Instance attributes to checkpoint
        strategy: Checkpointing strategy

    Example:
        class DataProcessor:
            @checkpoint_method(checkpoint_attrs=['processed_count', 'results'])
            def process_batch(self, batch):
                for item in batch:
                    self.process_item(item)
                    self.processed_count += 1
    """
    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            # Lazily attach a manager to the instance on first use.
            if not hasattr(self, '_checkpoint_manager'):
                self._checkpoint_manager = CheckpointManager(
                    config=CheckpointConfig(strategy=strategy)
                )

            # Generator methods get per-item checkpointing.
            if inspect.isgeneratorfunction(method):
                return _checkpoint_method_generator(
                    method,
                    self,
                    self._checkpoint_manager,
                    checkpoint_attrs,
                    *args,
                    **kwargs
                )

            # Plain method: run it first, then snapshot if a checkpoint is due.
            outcome = method(self, *args, **kwargs)
            active_manager = self._checkpoint_manager
            if active_manager.should_checkpoint():
                snapshot = _get_instance_state(self, checkpoint_attrs)
                active_manager.save(snapshot)
            return outcome

        return wrapper

    return decorator
def resumable(
    checkpoint_dir: str = ".checkpoints",
    auto_resume: bool = True
) -> Callable:
    """
    Make function resumable from checkpoints.

    Each call builds a ``CheckpointManager`` whose ``checkpoint_id`` is
    ``"<module>.<name>"`` of the wrapped function. If the manager lists any
    checkpoints and ``auto_resume`` is true, the checkpoint is loaded and
    the function is called with two extra keyword arguments,
    ``resume_state`` and ``resume_iteration``; otherwise it is called
    unchanged.

    The returned wrapper also exposes ``save_checkpoint(state)``,
    ``list_checkpoints()`` and ``cleanup_checkpoints()``. They operate on the
    manager of the most recent call, or on a newly built one if the function
    has not been called yet.

    Args:
        checkpoint_dir: Directory for checkpoints
        auto_resume: Automatically resume from latest checkpoint

    Returns:
        A decorator that wraps the target function with resume support.

    Example:
        @resumable()
        def long_computation(resume_state=None, resume_iteration=0):
            for i in range(resume_iteration, 1000000):
                # Computation
                if i % 1000 == 0:
                    long_computation.save_checkpoint({'i': i})
    """
    def decorator(func: Callable) -> Callable:
        # Manager of the most recent invocation. The helper attributes below
        # must not reference the wrapper's local variable: it is out of scope
        # for them and previously raised NameError on every use.
        current_manager = None

        def _new_manager():
            return CheckpointManager(
                checkpoint_id=f"{func.__module__}.{func.__name__}",
                config=CheckpointConfig(checkpoint_dir=checkpoint_dir)
            )

        def _manager():
            # Reuse the active manager, creating one on demand so the helpers
            # also work before the function has ever been called.
            nonlocal current_manager
            if current_manager is None:
                current_manager = _new_manager()
            return current_manager

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal current_manager
            # A new manager per call, as before; remember it for the helpers.
            manager = _new_manager()
            current_manager = manager

            # Check for existing checkpoints
            checkpoints = manager.list_checkpoints()
            if checkpoints and auto_resume:
                latest = checkpoints[-1]
                print(f"Found checkpoint at iteration {latest.iteration}")
                # Resume from checkpoint
                state, metadata = manager.load()
                # Call function with resume state
                return func(*args, resume_state=state,
                            resume_iteration=metadata.iteration, **kwargs)

            # Normal execution
            return func(*args, **kwargs)

        # Add checkpoint methods to function
        wrapper.save_checkpoint = lambda state: _manager().save(state)
        wrapper.list_checkpoints = lambda: _manager().list_checkpoints()
        wrapper.cleanup_checkpoints = lambda: _manager().cleanup()
        return wrapper

    return decorator
def _checkpoint_generator(func: Callable, manager: CheckpointManager,
checkpoint_vars: Optional[List[str]],
*args, **kwargs):
"""Handle checkpointing for generator functions."""
generator = func(*args, **kwargs)
iteration = 0
try:
while True:
# Get next value
if iteration == 0 and 'resume_state' in kwargs:
# Skip to resume point
resume_iter = kwargs['resume_state'].get('iteration', 0)
for _ in range(resume_iter):
next(generator)
iteration = resume_iter
value = next(generator)
# Check if checkpoint needed
if manager.should_checkpoint(iteration):
# Get state
if isinstance(value, dict):
state = value
else:
state = {'iteration': iteration, 'value': value}
# Add checkpoint vars if specified
if checkpoint_vars:
frame = inspect.currentframe().f_back
for var in checkpoint_vars:
if var in frame.f_locals:
state[var] = frame.f_locals[var]
manager.save(state)
yield value
iteration += 1
except StopIteration:
pass
finally:
if manager.config.verbose:
stats = manager.get_stats()
print(f"\nCheckpoint stats: {stats.total_checkpoints} checkpoints, "
f"{stats.average_compression:.1f}x compression")
def _checkpoint_method_generator(method: Callable, instance: Any,
manager: CheckpointManager,
checkpoint_attrs: Optional[List[str]],
*args, **kwargs):
"""Handle checkpointing for generator methods."""
generator = method(instance, *args, **kwargs)
iteration = 0
try:
while True:
value = next(generator)
if manager.should_checkpoint(iteration):
state = _get_instance_state(instance, checkpoint_attrs)
state['iteration'] = iteration
manager.save(state)
yield value
iteration += 1
except StopIteration:
pass
def _get_instance_state(instance: Any, attrs: Optional[List[str]] = None) -> dict:
"""Extract state from instance."""
if attrs:
return {attr: getattr(instance, attr, None) for attr in attrs}
else:
# Auto-detect state (exclude private and callable)
state = {}
for attr in dir(instance):
if not attr.startswith('_') and hasattr(instance, attr):
value = getattr(instance, attr)
if not callable(value):
try:
# Test if pickleable
import pickle
pickle.dumps(value)
state[attr] = value
except:
pass
return state