273 lines
8.9 KiB
Python
273 lines
8.9 KiB
Python
"""
|
|
SpaceTimeArray: A memory-efficient array that automatically spills to disk.
|
|
"""
|
|
|
|
import os
|
|
import pickle
|
|
import tempfile
|
|
import weakref
|
|
from typing import Any, Iterator, Optional, Union, List
|
|
from collections.abc import MutableSequence
|
|
|
|
from sqrtspace_spacetime.config import config
|
|
from sqrtspace_spacetime.memory import monitor, MemoryPressureLevel
|
|
|
|
|
|
class SpaceTimeArray(MutableSequence):
|
|
"""
|
|
A list-like container that automatically manages memory usage by
|
|
spilling to disk when threshold is reached.
|
|
"""
|
|
|
|
_instances = weakref.WeakSet()
|
|
|
|
def __init__(self, threshold: Optional[Union[int, str]] = None, storage_path: Optional[str] = None):
|
|
"""
|
|
Initialize SpaceTimeArray.
|
|
|
|
Args:
|
|
threshold: Number of items to keep in memory (None or 'auto' for automatic)
|
|
storage_path: Path for external storage (None for temp)
|
|
"""
|
|
if threshold == 'auto' or threshold is None:
|
|
self.threshold = config.calculate_chunk_size(10000)
|
|
else:
|
|
self.threshold = int(threshold)
|
|
self.storage_path = storage_path or config.external_storage_path
|
|
|
|
self._hot_data: List[Any] = []
|
|
self._cold_indices: set = set()
|
|
self._cold_storage: Optional[str] = None
|
|
self._length = 0
|
|
self._cold_file_handle = None
|
|
|
|
# Register for memory pressure handling
|
|
SpaceTimeArray._instances.add(self)
|
|
|
|
def __len__(self) -> int:
|
|
return self._length
|
|
|
|
def __getitem__(self, index: Union[int, slice]) -> Any:
|
|
if isinstance(index, slice):
|
|
return [self[i] for i in range(*index.indices(len(self)))]
|
|
|
|
if index < 0:
|
|
index += self._length
|
|
|
|
if not 0 <= index < self._length:
|
|
raise IndexError("list index out of range")
|
|
|
|
# Check if in hot storage
|
|
if index not in self._cold_indices:
|
|
hot_index = index - len(self._cold_indices)
|
|
return self._hot_data[hot_index]
|
|
|
|
# Load from cold storage
|
|
return self._load_from_cold(index)
|
|
|
|
def __setitem__(self, index: Union[int, slice], value: Any) -> None:
|
|
if isinstance(index, slice):
|
|
for i, v in zip(range(*index.indices(len(self))), value):
|
|
self[i] = v
|
|
return
|
|
|
|
if index < 0:
|
|
index += self._length
|
|
|
|
if not 0 <= index < self._length:
|
|
raise IndexError("list assignment index out of range")
|
|
|
|
if index not in self._cold_indices:
|
|
hot_index = index - len(self._cold_indices)
|
|
self._hot_data[hot_index] = value
|
|
else:
|
|
# Update cold storage
|
|
self._update_cold(index, value)
|
|
|
|
def __delitem__(self, index: Union[int, slice]) -> None:
|
|
if isinstance(index, slice):
|
|
# Delete in reverse order to maintain indices
|
|
for i in reversed(range(*index.indices(len(self)))):
|
|
del self[i]
|
|
return
|
|
|
|
if index < 0:
|
|
index += self._length
|
|
|
|
if not 0 <= index < self._length:
|
|
raise IndexError("list index out of range")
|
|
|
|
# This is complex with cold storage, so we'll reload everything
|
|
all_data = list(self)
|
|
del all_data[index]
|
|
self.clear()
|
|
self.extend(all_data)
|
|
|
|
def insert(self, index: int, value: Any) -> None:
|
|
if index < 0:
|
|
index += self._length
|
|
index = max(0, min(index, self._length))
|
|
|
|
# Simple implementation: reload all, insert, save back
|
|
all_data = list(self)
|
|
all_data.insert(index, value)
|
|
self.clear()
|
|
self.extend(all_data)
|
|
|
|
def append(self, value: Any) -> None:
|
|
"""Append an item to the array."""
|
|
self._hot_data.append(value)
|
|
self._length += 1
|
|
|
|
# Check if we need to spill
|
|
if len(self._hot_data) > self.threshold:
|
|
self._check_and_spill()
|
|
|
|
def extend(self, iterable) -> None:
|
|
"""Extend array with items from iterable."""
|
|
for item in iterable:
|
|
self.append(item)
|
|
|
|
def clear(self) -> None:
|
|
"""Remove all items."""
|
|
self._hot_data.clear()
|
|
self._cold_indices.clear()
|
|
self._length = 0
|
|
|
|
if self._cold_storage and os.path.exists(self._cold_storage):
|
|
os.unlink(self._cold_storage)
|
|
self._cold_storage = None
|
|
|
|
def __iter__(self) -> Iterator[Any]:
|
|
"""Iterate over all items."""
|
|
# First yield cold items
|
|
for idx in sorted(self._cold_indices):
|
|
yield self._load_from_cold(idx)
|
|
|
|
# Then hot items
|
|
for item in self._hot_data:
|
|
yield item
|
|
|
|
def _check_and_spill(self) -> None:
|
|
"""Check memory pressure and spill to disk if needed."""
|
|
# Check memory pressure
|
|
pressure = monitor.check_memory_pressure()
|
|
|
|
if pressure >= MemoryPressureLevel.MEDIUM or len(self._hot_data) > self.threshold:
|
|
self._spill_to_disk()
|
|
|
|
def _spill_to_disk(self) -> None:
|
|
"""Spill oldest items to disk."""
|
|
if not self._cold_storage:
|
|
fd, self._cold_storage = tempfile.mkstemp(
|
|
suffix='.spacetime',
|
|
dir=self.storage_path
|
|
)
|
|
os.close(fd)
|
|
|
|
# Determine how many items to spill
|
|
spill_count = len(self._hot_data) // 2
|
|
|
|
# Load existing cold data
|
|
cold_data = {}
|
|
if os.path.exists(self._cold_storage):
|
|
with open(self._cold_storage, 'rb') as f:
|
|
try:
|
|
cold_data = pickle.load(f)
|
|
except EOFError:
|
|
cold_data = {}
|
|
|
|
# Move items to cold storage
|
|
current_cold_size = len(self._cold_indices)
|
|
for i in range(spill_count):
|
|
cold_data[current_cold_size + i] = self._hot_data[i]
|
|
self._cold_indices.add(current_cold_size + i)
|
|
|
|
# Remove from hot storage
|
|
self._hot_data = self._hot_data[spill_count:]
|
|
|
|
# Save cold data
|
|
with open(self._cold_storage, 'wb') as f:
|
|
pickle.dump(cold_data, f)
|
|
|
|
def _load_from_cold(self, index: int) -> Any:
|
|
"""Load an item from cold storage."""
|
|
if not self._cold_storage or not os.path.exists(self._cold_storage):
|
|
raise IndexError(f"Cold storage index {index} not found")
|
|
|
|
with open(self._cold_storage, 'rb') as f:
|
|
cold_data = pickle.load(f)
|
|
|
|
return cold_data.get(index)
|
|
|
|
def _update_cold(self, index: int, value: Any) -> None:
|
|
"""Update an item in cold storage."""
|
|
if not self._cold_storage:
|
|
return
|
|
|
|
with open(self._cold_storage, 'rb') as f:
|
|
cold_data = pickle.load(f)
|
|
|
|
cold_data[index] = value
|
|
|
|
with open(self._cold_storage, 'wb') as f:
|
|
pickle.dump(cold_data, f)
|
|
|
|
def memory_usage(self) -> int:
|
|
"""Estimate memory usage in bytes."""
|
|
# Rough estimate - actual usage may vary
|
|
return len(self._hot_data) * 50 # Assume 50 bytes per item average
|
|
|
|
def spill_to_disk(self, path: Optional[str] = None) -> None:
|
|
"""Force spill all data to disk."""
|
|
if path:
|
|
self.storage_path = path
|
|
|
|
while self._hot_data:
|
|
self._spill_to_disk()
|
|
|
|
def load_to_memory(self) -> None:
|
|
"""Load all data back to memory."""
|
|
if not self._cold_storage or not self._cold_indices:
|
|
return
|
|
|
|
# Load cold data
|
|
with open(self._cold_storage, 'rb') as f:
|
|
cold_data = pickle.load(f)
|
|
|
|
# Rebuild array in correct order
|
|
all_data = []
|
|
cold_count = 0
|
|
hot_count = 0
|
|
|
|
for i in range(self._length):
|
|
if i in self._cold_indices:
|
|
all_data.append(cold_data[i])
|
|
cold_count += 1
|
|
else:
|
|
all_data.append(self._hot_data[hot_count])
|
|
hot_count += 1
|
|
|
|
# Reset storage
|
|
self._hot_data = all_data
|
|
self._cold_indices.clear()
|
|
|
|
if os.path.exists(self._cold_storage):
|
|
os.unlink(self._cold_storage)
|
|
self._cold_storage = None
|
|
|
|
def __del__(self):
|
|
"""Clean up temporary files."""
|
|
if self._cold_storage and os.path.exists(self._cold_storage):
|
|
try:
|
|
os.unlink(self._cold_storage)
|
|
except:
|
|
pass
|
|
|
|
@classmethod
|
|
def handle_memory_pressure(cls, level: MemoryPressureLevel) -> None:
|
|
"""Class method to handle memory pressure for all instances."""
|
|
if level >= MemoryPressureLevel.HIGH:
|
|
for instance in cls._instances:
|
|
if instance._hot_data:
|
|
instance._spill_to_disk() |