Initial commit 69b521b549

LICENSE (new file, 190 lines)
@@ -0,0 +1,190 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   Copyright 2024 Ubiquity SpaceTime Contributors

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md (new file, 428 lines)
@@ -0,0 +1,428 @@
# SqrtSpace SpaceTime for Python

[PyPI version](https://badge.fury.io/py/sqrtspace-spacetime)
[Python versions](https://pypi.org/project/sqrtspace-spacetime/)
[License](https://github.com/sqrtspace/sqrtspace-python/blob/main/LICENSE)
[Documentation](https://sqrtspace-spacetime.readthedocs.io/en/latest/?badge=latest)

Memory-efficient algorithms and data structures for Python using Williams' √n space-time tradeoffs.

## Installation

```bash
pip install sqrtspace-spacetime
```

For ML features:
```bash
pip install sqrtspace-spacetime[ml]
```

For all features:
```bash
pip install sqrtspace-spacetime[all]
```

## Core Concepts

SpaceTime implements theoretical computer science results showing that many algorithms can achieve better memory usage by accepting slightly slower runtime. The key insight is using √n memory instead of n memory, where n is the input size.
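For concreteness: with n = 1,000,000 items, √n = 1,000, so a √n algorithm holds about a thousand items in RAM and streams the rest through disk. The following is a minimal, dependency-free sketch of the idea behind `external_sort` (not the package's actual implementation): sort √n-sized runs, spill each run to disk, then k-way merge the runs.

```python
import heapq
import math
import os
import tempfile

def sqrt_chunked_sort(items):
    """Sort a list while holding only about sqrt(n) items in memory at once."""
    n = len(items)
    chunk = max(1, math.isqrt(n))  # sqrt(n) items per in-memory run
    run_files = []

    # Phase 1: sort sqrt(n)-sized runs and spill each one to disk
    for start in range(0, n, chunk):
        run = sorted(items[start:start + chunk])
        f = tempfile.NamedTemporaryFile(mode="w+", delete=False)
        f.write("\n".join(map(str, run)))
        f.seek(0)
        run_files.append(f)

    # Phase 2: k-way merge; only one line per run is in memory at a time
    merged = list(heapq.merge(*[(int(line) for line in f) for f in run_files]))

    for f in run_files:
        f.close()
        os.unlink(f.name)
    return merged

print(sqrt_chunked_sort([5, 3, 8, 1, 9, 2, 7]))  # [1, 2, 3, 5, 7, 8, 9]
```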

### Key Features

- **Memory-Efficient Collections**: Arrays and dictionaries that automatically spill to disk
- **External Algorithms**: Sort and group large datasets using minimal memory
- **Streaming Operations**: Process files larger than RAM with an elegant API
- **Auto-Checkpointing**: Resume long computations from where they left off
- **Memory Profiling**: Identify optimization opportunities in your code
- **ML Optimizations**: Reduce neural network training memory by up to 90%

## Quick Start

### Basic Usage

```python
from sqrtspace_spacetime import SpaceTimeArray, external_sort, Stream

# Memory-efficient array that spills to disk
array = SpaceTimeArray(threshold=10000)
for i in range(1000000):
    array.append(i)

# Sort large datasets with minimal memory
huge_list = list(range(10000000, 0, -1))
sorted_data = external_sort(huge_list)  # Uses only √n memory

# Stream processing
Stream.from_csv('huge_file.csv') \
    .filter(lambda row: row['value'] > 100) \
    .map(lambda row: row['value'] * 1.1) \
    .group_by(lambda row: row['category']) \
    .to_csv('processed.csv')
```

## Examples

### Basic Examples
See [`examples/basic_usage.py`](examples/basic_usage.py) for comprehensive examples of:
- SpaceTimeArray and SpaceTimeDict usage
- External sorting and grouping
- Stream processing
- Memory profiling
- Auto-checkpointing

### FastAPI Web Application
Check out [`examples/fastapi-app/`](examples/fastapi-app/) for a production-ready web application featuring:
- Streaming endpoints for large datasets
- Server-Sent Events (SSE) for real-time data
- Memory-efficient CSV exports
- Checkpointed background tasks
- ML model serving with memory constraints

See the [FastAPI example README](examples/fastapi-app/README.md) for detailed documentation.

### Machine Learning Pipeline
Explore [`examples/ml-pipeline/`](examples/ml-pipeline/) for ML-specific patterns:
- Training models on datasets larger than RAM
- Memory-efficient feature extraction
- Checkpointed training loops
- Streaming predictions
- Integration with PyTorch and TensorFlow

See the [ML Pipeline README](examples/ml-pipeline/README.md) for complete documentation.

### Memory-Efficient Collections

```python
from sqrtspace_spacetime import SpaceTimeArray, SpaceTimeDict

# Array that automatically manages memory
array = SpaceTimeArray(threshold=1000)  # Keep 1000 items in memory
for i in range(1000000):
    array.append(f"item_{i}")

# Dictionary with LRU eviction to disk
cache = SpaceTimeDict(threshold=10000)
for key, value in huge_dataset:
    cache[key] = expensive_computation(value)
```

### External Algorithms

```python
from sqrtspace_spacetime import external_sort, external_groupby

# Sort 100M items while keeping only ~10K of them in memory
data = list(range(100_000_000, 0, -1))
sorted_data = external_sort(data)

# Group by with aggregation
sales = [
    {'store': 'A', 'amount': 100},
    {'store': 'B', 'amount': 200},
    # ... millions more
]

by_store = external_groupby(
    sales,
    key_func=lambda x: x['store']
)

# Aggregate with minimal memory
from sqrtspace_spacetime.algorithms import groupby_sum
totals = groupby_sum(
    sales,
    key_func=lambda x: x['store'],
    value_func=lambda x: x['amount']
)
```
### Streaming Operations

```python
from sqrtspace_spacetime import Stream

# Process large files efficiently
stream = Stream.from_csv('sales_2023.csv') \
    .filter(lambda row: row['amount'] > 0) \
    .map(lambda row: {
        'month': row['date'][:7],
        'amount': float(row['amount'])
    }) \
    .group_by(lambda row: row['month']) \
    .to_csv('monthly_summary.csv')

# Chain operations
top_products = Stream.from_jsonl('products.jsonl') \
    .filter(lambda p: p['in_stock']) \
    .sort(key=lambda p: p['revenue'], reverse=True) \
    .take(100) \
    .collect()
```

### Auto-Checkpointing

```python
from sqrtspace_spacetime.checkpoint import auto_checkpoint

@auto_checkpoint(total_iterations=1000000)
def process_large_dataset(data):
    results = []
    for i, item in enumerate(data):
        # Process item
        result = expensive_computation(item)
        results.append(result)

        # Yield state for checkpointing
        yield {'i': i, 'results': results}

    return results

# Automatically resumes from checkpoint if interrupted
results = process_large_dataset(huge_dataset)
```

### Memory Profiling

```python
from sqrtspace_spacetime.profiler import profile, profile_memory

@profile(output_file="profile.json")
def my_algorithm(data):
    # Process data
    return results

# Get detailed memory analysis
result, report = my_algorithm(data)
print(report.summary)

# Simple memory tracking
@profile_memory(threshold_mb=100)
def memory_heavy_function():
    # Alerts if memory usage exceeds threshold
    large_list = list(range(10000000))
    return sum(large_list)
```

### ML Memory Optimization

```python
from sqrtspace_spacetime.ml import MLMemoryOptimizer
import torch.nn as nn

# Analyze model memory usage
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

optimizer = MLMemoryOptimizer()
profile = optimizer.analyze_model(model, input_shape=(784,), batch_size=32)

# Get optimization plan
plan = optimizer.optimize(profile, target_batch_size=128)
print(plan.explanation)

# Apply optimizations
config = optimizer.get_training_config(plan, profile)
```

## Advanced Features

### Memory Pressure Handling

```python
from sqrtspace_spacetime.memory import MemoryMonitor, LoggingHandler

# Monitor memory pressure
monitor = MemoryMonitor()
monitor.add_handler(LoggingHandler())

# Your arrays automatically respond to memory pressure
array = SpaceTimeArray()
# Arrays spill to disk when memory is low
```

### Configuration

```python
from sqrtspace_spacetime import SpaceTimeConfig

# Global configuration
SpaceTimeConfig.set_defaults(
    memory_limit=2 * 1024**3,  # 2GB
    chunk_strategy='sqrt_n',
    compression='gzip',
    external_storage_path='/fast/ssd/temp'
)
```

### Parallel Processing

```python
from sqrtspace_spacetime.batch import BatchProcessor

processor = BatchProcessor(
    memory_threshold=0.8,
    checkpoint_enabled=True
)

# Process in memory-efficient batches
result = processor.process(
    huge_list,
    lambda batch: [transform(item) for item in batch]
)

print(f"Processed {result.get_success_count()} items")
```

## Real-World Examples

### Processing Large CSV Files

```python
from sqrtspace_spacetime import Stream
from sqrtspace_spacetime.profiler import profile_memory

@profile_memory(threshold_mb=500)
def analyze_sales_data(filename):
    # Stream process to stay under memory limit
    return Stream.from_csv(filename) \
        .filter(lambda row: row['status'] == 'completed') \
        .map(lambda row: {
            'product': row['product_id'],
            'revenue': float(row['price']) * int(row['quantity'])
        }) \
        .group_by(lambda row: row['product']) \
        .sort(key=lambda group: sum(r['revenue'] for r in group[1]), reverse=True) \
        .take(10) \
        .collect()

top_products = analyze_sales_data('sales_2023.csv')
```

### Training Large Neural Networks

```python
from sqrtspace_spacetime.ml import MLMemoryOptimizer, GradientCheckpointer
import torch.nn as nn

# Memory-efficient training
def train_large_model(model, train_loader, epochs=10):
    # Analyze memory requirements
    optimizer = MLMemoryOptimizer()
    profile = optimizer.analyze_model(model, input_shape=(3, 224, 224), batch_size=32)

    # Get optimization plan
    plan = optimizer.optimize(profile, target_batch_size=128)

    # Apply gradient checkpointing
    checkpointer = GradientCheckpointer()
    model = checkpointer.apply_checkpointing(model, plan.checkpoint_layers)

    # Train with optimized settings
    for epoch in range(epochs):
        for batch in train_loader:
            # Training loop with automatic memory management
            pass
```

### Data Pipeline with Checkpoints

```python
from sqrtspace_spacetime import Stream
from sqrtspace_spacetime.checkpoint import auto_checkpoint

@auto_checkpoint(total_iterations=1000000)
def process_user_events(event_file):
    processed = 0

    for event in Stream.from_jsonl(event_file):
        # Complex processing
        user_profile = enhance_profile(event)
        recommendations = generate_recommendations(user_profile)

        save_to_database(recommendations)
        processed += 1

        # Checkpoint state
        yield {'processed': processed, 'last_event': event['id']}

    return processed

# Automatically resumes if interrupted
total = process_user_events('events.jsonl')
```

## Performance Benchmarks

| Operation | Standard Python | SpaceTime | Memory Reduction | Time Overhead |
|-----------|-----------------|-----------|------------------|---------------|
| Sort 10M integers | 400MB | 20MB | 95% | 40% |
| Process 1GB CSV | 1GB | 32MB | 97% | 20% |
| Group by on 1M rows | 200MB | 14MB | 93% | 30% |
| Neural network training | 8GB | 2GB | 75% | 15% |

## API Reference

### Collections
- `SpaceTimeArray`: Memory-efficient list with disk spillover
- `SpaceTimeDict`: Memory-efficient dictionary with LRU eviction

### Algorithms
- `external_sort()`: Sort large datasets with √n memory
- `external_groupby()`: Group large datasets with √n memory
- `external_join()`: Join large datasets efficiently

### Streaming
- `Stream`: Lazy evaluation stream processing
- `FileStream`: Stream lines from files
- `CSVStream`: Stream CSV rows
- `JSONLStream`: Stream JSON Lines

### Memory Management
- `MemoryMonitor`: Monitor memory pressure
- `MemoryPressureHandler`: Custom pressure handlers

### Checkpointing
- `@auto_checkpoint`: Automatic checkpointing decorator
- `CheckpointManager`: Manual checkpoint control

### ML Optimization
- `MLMemoryOptimizer`: Analyze and optimize models
- `GradientCheckpointer`: Apply gradient checkpointing

### Profiling
- `@profile`: Full profiling decorator
- `@profile_memory`: Memory-only profiling
- `SpaceTimeProfiler`: Programmatic profiling

## Contributing

We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details.

## License

Apache License 2.0. See [LICENSE](LICENSE) for details.

## Citation

If you use SpaceTime in your research, please cite:

```bibtex
@software{sqrtspace_spacetime,
  title  = {SqrtSpace SpaceTime: Memory-Efficient Python Library},
  author = {Friedel Jr., David H.},
  year   = {2025},
  url    = {https://github.com/sqrtspace/sqrtspace-python}
}
```

## Links

- [Documentation](https://sqrtspace-spacetime.readthedocs.io)
- [PyPI Package](https://pypi.org/project/sqrtspace-spacetime/)
- [GitHub Repository](https://github.com/sqrtspace/sqrtspace-python)
- [Issue Tracker](https://github.com/sqrtspace/sqrtspace-python/issues)
examples/basic_usage.py (new file, 204 lines)
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""
Basic usage examples for Ubiquity SpaceTime.
"""

import time
import random
from sqrtspace_spacetime import (
    SpaceTimeArray,
    SpaceTimeDict,
    external_sort,
    external_groupby,
    Stream,
    SpaceTimeConfig,
)
from sqrtspace_spacetime.profiler import profile, profile_memory
from sqrtspace_spacetime.checkpoint import auto_checkpoint


def example_spacetime_array():
    """Example: Memory-efficient array with automatic spillover."""
    print("\n=== SpaceTimeArray Example ===")

    # Create array that keeps only 1000 items in memory
    array = SpaceTimeArray(threshold=1000)

    # Add 10,000 items
    print("Adding 10,000 items to SpaceTimeArray...")
    for i in range(10000):
        array.append(f"item_{i}")

    print(f"Array length: {len(array)}")
    print(f"Sample items: {array[0]}, {array[5000]}, {array[9999]}")

    # Demonstrate memory efficiency
    import psutil
    process = psutil.Process()
    memory_mb = process.memory_info().rss / 1024 / 1024
    print(f"Current memory usage: {memory_mb:.1f} MB (much less than storing all in memory)")


def example_external_sort():
    """Example: Sort large dataset with minimal memory."""
    print("\n=== External Sort Example ===")

    # Generate large random dataset
    print("Generating 1M random numbers...")
    data = [random.randint(1, 1000000) for _ in range(1000000)]

    # Sort using √n memory
    print("Sorting with external_sort (√n memory)...")
    start = time.time()
    sorted_data = external_sort(data)
    elapsed = time.time() - start

    # Verify sorting
    is_sorted = all(sorted_data[i] <= sorted_data[i+1] for i in range(len(sorted_data)-1))
    print(f"Sorted correctly: {is_sorted}")
    print(f"Time taken: {elapsed:.2f}s")
    print(f"First 10 elements: {sorted_data[:10]}")


def example_streaming():
    """Example: Process data streams efficiently."""
    print("\n=== Stream Processing Example ===")

    # Create sample data
    data = [
        {'name': 'Alice', 'age': 25, 'score': 85},
        {'name': 'Bob', 'age': 30, 'score': 90},
        {'name': 'Charlie', 'age': 25, 'score': 78},
        {'name': 'David', 'age': 30, 'score': 92},
        {'name': 'Eve', 'age': 25, 'score': 88},
    ]

    # Stream processing
    result = Stream.from_iterable(data) \
        .filter(lambda x: x['age'] == 25) \
        .map(lambda x: {'name': x['name'], 'grade': 'A' if x['score'] >= 85 else 'B'}) \
        .collect()

    print("Filtered and transformed data:")
    for item in result:
        print(f"  {item}")


@profile_memory(threshold_mb=50)
def example_memory_profiling():
    """Example: Profile memory usage."""
    print("\n=== Memory Profiling Example ===")

    # Simulate memory-intensive operation
    data = []
    for i in range(100000):
        data.append({
            'id': i,
            'value': random.random(),
            'text': f"Item number {i}" * 10
        })

    # Process data
    result = sum(item['value'] for item in data)
    return result


@auto_checkpoint(total_iterations=100)
def example_checkpointing(data):
    """Example: Auto-checkpoint long computation."""
    print("\n=== Checkpointing Example ===")

    results = []
    for i, item in enumerate(data):
        # Simulate expensive computation
        time.sleep(0.01)
        result = item ** 2
        results.append(result)

        # Yield state for checkpointing
        if i % 10 == 0:
            print(f"Processing item {i}...")
            yield {'i': i, 'results': results}

    return results


def example_groupby():
    """Example: Group large dataset efficiently."""
    print("\n=== External GroupBy Example ===")

    # Generate sales data
    sales = []
    stores = ['Store_A', 'Store_B', 'Store_C', 'Store_D']

    print("Generating 100K sales records...")
    for i in range(100000):
        sales.append({
            'store': random.choice(stores),
            'amount': random.uniform(10, 1000),
            'product': f'Product_{random.randint(1, 100)}'
        })

    # Group by store
    print("Grouping by store...")
    grouped = external_groupby(sales, key_func=lambda x: x['store'])

    # Calculate totals
    for store, transactions in grouped.items():
        total = sum(t['amount'] for t in transactions)
        print(f"{store}: {len(transactions)} transactions, ${total:,.2f} total")


def example_spacetime_dict():
    """Example: Memory-efficient dictionary with LRU eviction."""
    print("\n=== SpaceTimeDict Example ===")

    # Create cache with 100-item memory limit
    cache = SpaceTimeDict(threshold=100)

    # Simulate caching expensive computations
    print("Caching 1000 expensive computations...")
    for i in range(1000):
        key = f"computation_{i}"
        # Simulate expensive computation
        value = i ** 2 + random.random()
        cache[key] = value

    print(f"Total items: {len(cache)}")
    print(f"Items in memory: {len(cache._hot_data)}")
    print(f"Items on disk: {len(cache._cold_keys)}")

    # Access patterns
    stats = cache.get_stats()
    print(f"Cache stats: {stats}")


def main():
    """Run all examples."""
    print("=== Ubiquity SpaceTime Examples ===")

    # Configure SpaceTime
    SpaceTimeConfig.set_defaults(
        memory_limit=512 * 1024 * 1024,  # 512MB
        chunk_strategy='sqrt_n',
        compression='gzip'
    )

    # Run examples
    example_spacetime_array()
    example_external_sort()
    example_streaming()
    example_memory_profiling()
    example_groupby()
    example_spacetime_dict()

    # Checkpointing example (the generator yields one checkpoint state every 10 items)
    data = list(range(100))
    checkpoints = list(example_checkpointing(data))
    print(f"Checkpointing completed. Collected {len(checkpoints)} checkpoint states.")

    print("\n=== All examples completed! ===")


if __name__ == "__main__":
    main()
examples/fastapi-app/README.md (new file, 504 lines)
@@ -0,0 +1,504 @@
# SqrtSpace SpaceTime FastAPI Sample Application

This sample demonstrates how to build memory-efficient, high-performance APIs using FastAPI and SqrtSpace SpaceTime.

## Features Demonstrated

### 1. **Streaming Endpoints**
- Server-Sent Events (SSE) for real-time data
- Streaming file downloads without memory bloat
- Chunked JSON responses for large datasets

### 2. **Background Tasks**
- Memory-aware task processing
- Checkpointed long-running operations
- Progress tracking with resumable state

### 3. **Data Processing**
- External sorting for large datasets
- Memory-efficient aggregations
- Streaming ETL pipelines

### 4. **Machine Learning Integration**
- Batch prediction with memory limits
- Model training with checkpoints
- Feature extraction pipelines

## Installation

1. **Create virtual environment:**
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

2. **Install dependencies:**
```bash
pip install -r requirements.txt
```

3. **Configure environment:**
```bash
cp .env.example .env
```

Edit `.env`:
```
SPACETIME_MEMORY_LIMIT=512MB
SPACETIME_EXTERNAL_STORAGE=/tmp/spacetime
SPACETIME_CHUNK_STRATEGY=sqrt_n
SPACETIME_COMPRESSION=gzip
DATABASE_URL=sqlite:///./app.db
```

4. **Initialize database:**
```bash
python init_db.py
```

## Project Structure

```
fastapi-app/
├── app/
│   ├── __init__.py
│   ├── main.py                   # FastAPI app
│   ├── config.py                 # Configuration
│   ├── models.py                 # Pydantic models
│   ├── database.py               # Database setup
│   ├── routers/
│   │   ├── products.py           # Product endpoints
│   │   ├── analytics.py          # Analytics endpoints
│   │   ├── ml.py                 # ML endpoints
│   │   └── reports.py            # Report generation
│   ├── services/
│   │   ├── product_service.py    # Business logic
│   │   ├── analytics_service.py  # Analytics processing
│   │   ├── ml_service.py         # ML operations
│   │   └── cache_service.py      # SpaceTime caching
│   ├── workers/
│   │   ├── background_tasks.py   # Task workers
│   │   └── checkpointed_jobs.py  # Resumable jobs
│   └── utils/
│       ├── streaming.py          # Streaming helpers
│       └── memory.py             # Memory monitoring
├── requirements.txt
├── Dockerfile
└── docker-compose.yml
```

## Usage Examples

### 1. Streaming Large Datasets

```python
# app/routers/products.py
from fastapi import APIRouter, Response
from fastapi.responses import StreamingResponse
from sqrtspace_spacetime import Stream
import json

router = APIRouter()

@router.get("/products/stream")
async def stream_products(category: str = None):
    """Stream products as newline-delimited JSON"""

    async def generate():
        query = db.query(Product)
        if category:
            query = query.filter(Product.category == category)

        # Use SpaceTime stream for memory efficiency
        stream = Stream.from_query(query, chunk_size=100)

        for product in stream:
            yield json.dumps(product.dict()) + "\n"

    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={"X-Accel-Buffering": "no"}
    )
```

### 2. Server-Sent Events for Real-Time Data

```python
# app/routers/analytics.py
from fastapi import APIRouter
from sse_starlette.sse import EventSourceResponse
from sqrtspace_spacetime.memory import MemoryPressureMonitor, MemoryPressureLevel
import asyncio
import json

router = APIRouter()

@router.get("/analytics/realtime")
async def realtime_analytics():
    """Stream real-time analytics using SSE"""

    monitor = MemoryPressureMonitor("100MB")

    async def event_generator():
        while True:
            # Get current stats
            stats = await analytics_service.get_current_stats()

            # Check memory pressure
            if monitor.check() != MemoryPressureLevel.NONE:
                await analytics_service.compact_cache()

            yield {
                "event": "update",
                "data": json.dumps(stats)
            }

            await asyncio.sleep(1)

    return EventSourceResponse(event_generator())
```

### 3. Memory-Efficient CSV Export

```python
# app/routers/reports.py
from datetime import date
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from sqrtspace_spacetime.file import CsvWriter
import io

router = APIRouter()

@router.get("/reports/export/csv")
async def export_csv(start_date: date, end_date: date):
    """Export large dataset as CSV with streaming"""

    async def generate():
        # Create in-memory buffer
        output = io.StringIO()
        writer = CsvWriter(output)

        # Write headers
        writer.writerow(["Date", "Orders", "Revenue", "Customers"])

        # Stream data in chunks
        async for batch in analytics_service.get_daily_stats_batched(
            start_date, end_date, batch_size=100
        ):
            for row in batch:
                writer.writerow([
                    row.date,
                    row.order_count,
                    row.total_revenue,
                    row.unique_customers
                ])

            # Yield buffer content, then reset the buffer for the next batch
            output.seek(0)
            data = output.read()
            output.seek(0)
            output.truncate()
            yield data

    return StreamingResponse(
        generate(),
        media_type="text/csv",
        headers={
            "Content-Disposition": f"attachment; filename=report_{start_date}_{end_date}.csv"
        }
    )
```

### 4. Checkpointed Background Tasks

```python
# app/workers/checkpointed_jobs.py
from sqrtspace_spacetime.checkpoint import CheckpointManager, auto_checkpoint
from sqrtspace_spacetime.collections import SpaceTimeArray

class DataProcessor:
    def __init__(self):
        self.checkpoint_manager = CheckpointManager()

    @auto_checkpoint(total_iterations=10000)
    async def process_large_dataset(self, dataset_id: str):
        """Process dataset with automatic checkpointing"""

        # Initialize or restore state
        results = SpaceTimeArray(threshold=1000)
        processed_count = 0

        # Get data in batches
        async for batch in self.get_data_batches(dataset_id):
            for item in batch:
                # Process item
                result = await self.process_item(item)
                results.append(result)
                processed_count += 1

                # Yield state for checkpointing
                if processed_count % 100 == 0:
                    yield {
                        'processed': processed_count,
                        'results': results,
                        'last_item_id': item.id
                    }

        # Async generators cannot `return` a value; the final state is
        # carried by the last checkpoint yield.
```

### 5. Machine Learning with Memory Constraints

```python
# app/services/ml_service.py
from sqrtspace_spacetime.collections import SpaceTimeArray
from sqrtspace_spacetime.ml import SpaceTimeOptimizer
from sqrtspace_spacetime.streams import Stream
import numpy as np

class MLService:
    def __init__(self):
        self.optimizer = SpaceTimeOptimizer(
            memory_limit="256MB",
            checkpoint_frequency=100
        )

    async def train_model(self, training_data_path: str):
        """Train model with memory-efficient data loading"""

        # Stream training data
        data_stream = Stream.from_csv(
            training_data_path,
            chunk_size=1000
        )

        # Process in mini-batches
        for epoch in range(10):
            for batch in data_stream.batch(32):
                X = np.array([item.features for item in batch])
                y = np.array([item.label for item in batch])

                # Train step with automatic checkpointing
                loss = self.optimizer.step(
                    self.model,
                    X, y,
                    epoch=epoch
                )

                if self.optimizer.should_checkpoint():
                    await self.save_checkpoint(epoch)

    async def batch_predict(self, input_data):
        """Memory-efficient batch prediction"""

        results = SpaceTimeArray(threshold=1000)

        # Process in chunks to avoid memory issues
        for chunk in Stream.from_iterable(input_data).chunk(100):
            predictions = self.model.predict(chunk)
            results.extend(predictions)

        return results
```

### 6. Advanced Caching with SpaceTime

```python
# app/services/cache_service.py
from sqrtspace_spacetime.collections import SpaceTimeDict
from sqrtspace_spacetime.memory import MemoryPressureMonitor
import asyncio

class SpaceTimeCache:
    def __init__(self):
        self.hot_cache = SpaceTimeDict(threshold=1000)
        self.monitor = MemoryPressureMonitor("128MB")
        self.stats = {
            'hits': 0,
            'misses': 0,
            'evictions': 0
        }

    async def get(self, key: str):
        """Get with automatic tier management"""

        if key in self.hot_cache:
            self.stats['hits'] += 1
            return self.hot_cache[key]

        self.stats['misses'] += 1

        # Load from database
        value = await self.load_from_db(key)

        # Add to cache if memory allows
        if self.monitor.can_allocate(len(str(value))):
            self.hot_cache[key] = value
        else:
            # Trigger cleanup
            self.cleanup()
            self.stats['evictions'] += len(self.hot_cache) // 2

        return value

    def cleanup(self):
        """Remove least recently used items"""
        # SpaceTimeDict handles LRU automatically
        self.hot_cache.evict_cold_items(0.5)
```

## API Endpoints

### Products API
- `GET /products` - Paginated list
- `GET /products/stream` - Stream all products (NDJSON; a consumption sketch follows this list)
- `GET /products/search` - Memory-efficient search
- `POST /products/bulk-update` - Checkpointed bulk updates
- `GET /products/export/csv` - Streaming CSV export

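A quick way to consume the NDJSON stream from Python (a client-side sketch; the endpoint path matches this app, while the host, port, and `name` field are assumptions):

```python
import json

import requests

# Stream the response so the client also stays memory-efficient
with requests.get("http://localhost:8000/products/stream", stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # skip keep-alive blank lines
            product = json.loads(line)
            print(product.get("name"))
```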
### Analytics API
- `GET /analytics/summary` - Current statistics
- `GET /analytics/realtime` - SSE stream of live data
- `GET /analytics/trends` - Historical trends
- `POST /analytics/aggregate` - Custom aggregations

### ML API
- `POST /ml/train` - Train model (async with progress)
- `POST /ml/predict/batch` - Batch predictions
- `GET /ml/models/{id}/status` - Training status
- `POST /ml/features/extract` - Feature extraction pipeline

### Reports API
- `POST /reports/generate` - Generate large report
- `GET /reports/{id}/progress` - Check progress
- `GET /reports/{id}/download` - Download completed report

## Running the Application

### Development
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```

### Production
```bash
gunicorn app.main:app -w 4 -k uvicorn.workers.UvicornWorker \
    --bind 0.0.0.0:8000 \
    --timeout 300 \
    --max-requests 1000 \
    --max-requests-jitter 50
```

### With Docker
```bash
docker-compose up
```

## Performance Configuration

### 1. Nginx Configuration
```nginx
location /products/stream {
    proxy_pass http://backend;
    proxy_buffering off;
    proxy_read_timeout 3600;
    proxy_http_version 1.1;
    proxy_set_header Connection "";
}

location /analytics/realtime {
    proxy_pass http://backend;
    proxy_buffering off;
    proxy_cache off;
    proxy_read_timeout 86400;
    proxy_http_version 1.1;
    proxy_set_header Connection "";
}
```

### 2. Worker Configuration
```python
# app/config.py
WORKER_CONFIG = {
    'memory_limit': os.getenv('WORKER_MEMORY_LIMIT', '512MB'),
    'checkpoint_interval': 100,
    'batch_size': 1000,
    'external_storage': '/tmp/spacetime-workers'
}
```

## Monitoring

### Memory Usage Endpoint
```python
@router.get("/system/memory")
async def memory_stats():
    """Get current memory statistics"""

    return {
        "current_usage_mb": memory_monitor.current_usage_mb,
        "peak_usage_mb": memory_monitor.peak_usage_mb,
        "available_mb": memory_monitor.available_mb,
        "pressure_level": memory_monitor.pressure_level,
        "cache_stats": cache_service.get_stats(),
        "external_files": len(os.listdir(EXTERNAL_STORAGE))
    }
```

### Prometheus Metrics
```python
from prometheus_client import Counter, Histogram, Gauge

stream_requests = Counter('spacetime_stream_requests_total', 'Total streaming requests')
memory_usage = Gauge('spacetime_memory_usage_bytes', 'Current memory usage')
processing_time = Histogram('spacetime_processing_seconds', 'Processing time')
```

## Testing

### Unit Tests
```bash
pytest tests/unit -v
```

### Integration Tests
```bash
pytest tests/integration -v
```

### Load Testing
```bash
locust -f tests/load/locustfile.py --host http://localhost:8000
```

## Best Practices

1. **Always use streaming** for large responses
2. **Configure memory limits** based on container size
3. **Enable checkpointing** for long-running tasks
4. **Monitor memory pressure** in production
5. **Use external storage** on fast SSDs
6. **Set appropriate timeouts** for streaming endpoints
7. **Implement circuit breakers** for memory protection (see the sketch below)

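A minimal sketch of practice 7: a FastAPI middleware that sheds load under memory pressure. The 90% threshold, 503 response, and use of `psutil` are illustrative choices, not part of this app:

```python
import psutil
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.middleware("http")
async def memory_circuit_breaker(request: Request, call_next):
    # Trip the breaker instead of accepting work we cannot finish
    if psutil.virtual_memory().percent > 90.0:
        return JSONResponse(
            status_code=503,
            content={"detail": "Server under memory pressure, retry later"},
        )
    return await call_next(request)
```
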
## Troubleshooting

### High Memory Usage
- Reduce chunk sizes
- Enable more aggressive spillover
- Check for memory leaks in custom code

### Slow Streaming
- Ensure proxy buffering is disabled
- Check network latency
- Optimize chunk sizes

### Failed Checkpoints
- Verify storage permissions
- Check disk space
- Monitor checkpoint frequency

## Learn More

- [SqrtSpace SpaceTime Docs](https://github.com/MarketAlly/Ubiquity)
- [FastAPI Documentation](https://fastapi.tiangolo.com)
- [Streaming Best Practices](https://example.com/streaming)
examples/fastapi-app/app/main.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""
FastAPI application demonstrating SqrtSpace SpaceTime integration
"""
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging

from sqrtspace_spacetime import SpaceTimeConfig
from sqrtspace_spacetime.memory import MemoryPressureMonitor

from .config import settings
from .routers import products, analytics, ml, reports
from .services.cache_service import SpaceTimeCache
from .utils.memory import memory_monitor_middleware

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global instances
cache = SpaceTimeCache()
memory_monitor = MemoryPressureMonitor(settings.SPACETIME_MEMORY_LIMIT)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager"""
    # Startup
    logger.info("Starting FastAPI with SqrtSpace SpaceTime")

    # Configure SpaceTime
    SpaceTimeConfig.set_defaults(
        memory_limit=settings.SPACETIME_MEMORY_LIMIT,
        external_storage=settings.SPACETIME_EXTERNAL_STORAGE,
        chunk_strategy=settings.SPACETIME_CHUNK_STRATEGY,
        compression=settings.SPACETIME_COMPRESSION
    )

    # Initialize services
    app.state.cache = cache
    app.state.memory_monitor = memory_monitor

    yield

    # Shutdown
    logger.info("Shutting down...")
    cache.cleanup()


# Create FastAPI app
app = FastAPI(
    title="SqrtSpace SpaceTime FastAPI Demo",
    description="Memory-efficient API with √n space-time tradeoffs",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add custom middleware
app.middleware("http")(memory_monitor_middleware)

# Include routers
app.include_router(products.router, prefix="/products", tags=["products"])
app.include_router(analytics.router, prefix="/analytics", tags=["analytics"])
app.include_router(ml.router, prefix="/ml", tags=["machine-learning"])
app.include_router(reports.router, prefix="/reports", tags=["reports"])


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "SqrtSpace SpaceTime FastAPI Demo",
        "docs": "/docs",
        "memory_usage": memory_monitor.get_memory_info()
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    memory_info = memory_monitor.get_memory_info()

    return {
        "status": "healthy",
        "memory": {
            "usage_mb": memory_info["used_mb"],
            "available_mb": memory_info["available_mb"],
            "percentage": memory_info["percentage"],
            "pressure": memory_monitor.check().value
        },
        "cache": cache.get_stats()
    }


@app.get("/system/memory")
async def system_memory():
    """Detailed memory statistics"""
    import psutil
    import os

    process = psutil.Process(os.getpid())

    return {
        "process": {
            "rss_mb": process.memory_info().rss / 1024 / 1024,
            "vms_mb": process.memory_info().vms / 1024 / 1024,
            "cpu_percent": process.cpu_percent(interval=0.1),
            "num_threads": process.num_threads()
        },
        "spacetime": {
            "memory_limit": settings.SPACETIME_MEMORY_LIMIT,
            "external_storage": settings.SPACETIME_EXTERNAL_STORAGE,
            "pressure_level": memory_monitor.check().value,
            "cache_stats": cache.get_stats()
        },
        "system": {
            "total_memory_mb": psutil.virtual_memory().total / 1024 / 1024,
            "available_memory_mb": psutil.virtual_memory().available / 1024 / 1024,
            "memory_percent": psutil.virtual_memory().percent,
            "swap_percent": psutil.swap_memory().percent
        }
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
examples/fastapi-app/app/routers/products.py (new file, 260 lines)
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Product endpoints demonstrating streaming and memory-efficient operations
|
||||
"""
|
||||
from fastapi import APIRouter, Query, Response, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional, List
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from sqrtspace_spacetime import Stream, external_sort
|
||||
from sqrtspace_spacetime.checkpoint import CheckpointManager
|
||||
|
||||
from ..models import Product, ProductUpdate, BulkUpdateRequest, ImportStatus
|
||||
from ..services.product_service import ProductService
|
||||
from ..database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
product_service = ProductService()
|
||||
checkpoint_manager = CheckpointManager()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_products(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
category: Optional[str] = None,
|
||||
min_price: Optional[float] = None,
|
||||
max_price: Optional[float] = None
|
||||
):
|
||||
"""Get paginated list of products"""
|
||||
filters = {}
|
||||
if category:
|
||||
filters['category'] = category
|
||||
if min_price is not None:
|
||||
filters['min_price'] = min_price
|
||||
if max_price is not None:
|
||||
filters['max_price'] = max_price
|
||||
|
||||
return await product_service.get_products(skip, limit, filters)
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
async def stream_products(
|
||||
category: Optional[str] = None,
|
||||
format: str = Query("ndjson", regex="^(ndjson|json)$")
|
||||
):
|
||||
"""
|
||||
Stream all products as NDJSON or JSON array.
|
||||
Memory-efficient streaming for large datasets.
|
||||
"""
|
||||
|
||||
async def generate_ndjson():
|
||||
async for product in product_service.stream_products(category):
|
||||
yield json.dumps(product.dict()) + "\n"
|
||||
|
||||
async def generate_json():
|
||||
yield "["
|
||||
first = True
|
||||
async for product in product_service.stream_products(category):
|
||||
if not first:
|
||||
yield ","
|
||||
yield json.dumps(product.dict())
|
||||
first = False
|
||||
yield "]"
|
||||
|
||||
if format == "ndjson":
|
||||
return StreamingResponse(
|
||||
generate_ndjson(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Accel-Buffering": "no"}
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
generate_json(),
|
||||
media_type="application/json",
|
||||
headers={"X-Accel-Buffering": "no"}
|
||||
)
|
||||


@router.get("/export/csv")
async def export_csv(
    category: Optional[str] = None,
    columns: Optional[List[str]] = Query(None)
):
    """Export products as CSV with streaming"""

    if not columns:
        columns = ["id", "name", "sku", "category", "price", "stock", "created_at"]

    async def generate():
        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=columns)

        # Write header
        writer.writeheader()
        output.seek(0)
        yield output.read()
        output.seek(0)
        output.truncate()

        # Stream products in batches
        batch_count = 0
        async for batch in product_service.stream_products_batched(category, batch_size=100):
            for product in batch:
                writer.writerow({col: getattr(product, col) for col in columns})

            output.seek(0)
            data = output.read()
            output.seek(0)
            output.truncate()
            yield data

            batch_count += 1
            if batch_count % 10 == 0:
                # Yield empty string to keep connection alive
                yield ""

    filename = f"products_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"

    return StreamingResponse(
        generate(),
        media_type="text/csv",
        headers={
            "Content-Disposition": f"attachment; filename={filename}",
            "X-Accel-Buffering": "no"
        }
    )


@router.get("/search")
async def search_products(
    q: str = Query(..., min_length=2),
    sort_by: str = Query("relevance", regex="^(relevance|price_asc|price_desc|name)$"),
    limit: int = Query(100, ge=1, le=1000)
):
    """
    Search products with memory-efficient sorting.

    Uses external sort for large result sets.
    """
    results = await product_service.search_products(q, sort_by, limit)

    # Use external sort if results are large
    if len(results) > 1000:
        sort_key = {
            'price_asc': lambda x: x['price'],
            'price_desc': lambda x: -x['price'],
            'name': lambda x: x['name'],
            'relevance': lambda x: -x['relevance_score']
        }[sort_by]

        results = external_sort(results, key_func=sort_key)

    return {"results": results[:limit], "total": len(results)}


@router.post("/bulk-update")
async def bulk_update_prices(
    request: BulkUpdateRequest,
    background_tasks: BackgroundTasks
):
    """
    Bulk update product prices with checkpointing.

    Can be resumed if interrupted.
    """
    job_id = f"bulk_update_{datetime.now().timestamp()}"

    # Check for existing checkpoint
    checkpoint = checkpoint_manager.restore(job_id)
    if checkpoint:
        return {
            "message": "Resuming previous job",
            "job_id": job_id,
            "progress": checkpoint.get("progress", 0)
        }

    # Start background task
    background_tasks.add_task(
        product_service.bulk_update_prices,
        request,
        job_id
    )

    return {
        "message": "Bulk update started",
        "job_id": job_id,
        "status_url": f"/products/bulk-update/{job_id}/status"
    }


@router.get("/bulk-update/{job_id}/status")
async def bulk_update_status(job_id: str):
    """Check status of bulk update job"""
    checkpoint = checkpoint_manager.restore(job_id)

    if not checkpoint:
        raise HTTPException(status_code=404, detail="Job not found")

    return {
        "job_id": job_id,
        "status": checkpoint.get("status", "running"),
        "progress": checkpoint.get("progress", 0),
        "total": checkpoint.get("total", 0),
        "updated": checkpoint.get("updated", 0),
        "errors": checkpoint.get("errors", [])
    }
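
# Example client flow (a hedged sketch): submit the checkpointed bulk update,
# then poll the returned status_url until the job stops reporting "running".
# Assumes the `httpx` library; the request body fields are illustrative, see
# BulkUpdateRequest for the real schema.
async def _example_poll_bulk_update(base_url: str = "http://localhost:8000"):
    import asyncio
    import httpx  # assumed dependency for this sketch only

    async with httpx.AsyncClient(base_url=base_url) as client:
        started = (await client.post("/products/bulk-update", json={"category": "books"})).json()
        while True:
            status = (await client.get(started["status_url"])).json()
            if status.get("status", "running") != "running":
                return status
            await asyncio.sleep(1.0)  # back off between polls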


@router.post("/import/csv")
async def import_csv(
    file_url: str,
    background_tasks: BackgroundTasks
):
    """Import products from CSV file"""
    import_id = f"import_{datetime.now().timestamp()}"

    background_tasks.add_task(
        product_service.import_from_csv,
        file_url,
        import_id
    )

    return {
        "message": "Import started",
        "import_id": import_id,
        "status_url": f"/products/import/{import_id}/status"
    }


@router.get("/import/{import_id}/status")
async def import_status(import_id: str):
    """Check status of import job"""
    status = await product_service.get_import_status(import_id)

    if not status:
        raise HTTPException(status_code=404, detail="Import job not found")

    return status


@router.get("/statistics")
async def product_statistics():
    """
    Get product statistics using memory-efficient aggregations.

    Uses external grouping for large datasets.
    """
    stats = await product_service.calculate_statistics()

    return {
        "total_products": stats["total_products"],
        "total_value": stats["total_value"],
        "by_category": stats["by_category"],
        "price_distribution": stats["price_distribution"],
        "stock_alerts": stats["stock_alerts"],
        "processing_info": {
            "memory_used_mb": stats["memory_used_mb"],
            "external_operations": stats["external_operations"]
        }
    }
232
examples/ml-pipeline/README.md
Normal file
@@ -0,0 +1,232 @@
# Machine Learning Pipeline with SqrtSpace SpaceTime

This example demonstrates how to build memory-efficient machine learning pipelines using SqrtSpace SpaceTime for handling large datasets that don't fit in memory.

## Features Demonstrated

### 1. **Memory-Efficient Data Loading**
- Streaming data loading from CSV files
- Automatic memory pressure monitoring
- Chunked processing with configurable batch sizes

### 2. **Feature Engineering at Scale**
- Checkpointed feature extraction
- Statistical feature computation
- Memory-aware transformations

### 3. **External Algorithms for ML**
- External sorting for data preprocessing
- External grouping for metrics calculation
- Stratified sampling with memory constraints

### 4. **Model Training with Constraints**
- Mini-batch training with memory limits
- Automatic garbage collection triggers
- Progress checkpointing for resumability

### 5. **Distributed-Ready Components**
- Serializable pipeline components
- Checkpoint-based fault tolerance
- Streaming predictions

## Installation

```bash
pip install sqrtspace-spacetime scikit-learn pandas numpy joblib psutil
```

## Running the Example

```bash
python ml_pipeline_example.py
```

This will:
1. Generate a synthetic dataset (100K samples, 50 features)
2. Load data using streaming
3. Preprocess with external sorting
4. Extract features with checkpointing
5. Train a Random Forest model
6. Evaluate using external grouping
7. Save the model checkpoint

## Key Components

### SpaceTimeFeatureExtractor

A scikit-learn compatible transformer that:
- Extracts features using streaming computation
- Maintains statistics in SpaceTime collections
- Supports checkpointing for resumability

```python
extractor = SpaceTimeFeatureExtractor(max_features=1000)
extractor.fit(data_stream)  # Automatically checkpointed
transformed = extractor.transform(test_stream)
```

### MemoryEfficientMLPipeline

Complete pipeline that handles:
- Data loading with memory monitoring
- Preprocessing with external algorithms
- Training with batch processing
- Evaluation with memory-efficient metrics

```python
pipeline = MemoryEfficientMLPipeline(memory_limit="512MB")
pipeline.train_with_memory_constraints(X_train, y_train)
metrics = pipeline.evaluate_with_external_grouping(X_test, y_test)
```

### Memory Monitoring

Automatic memory pressure detection:
```python
monitor = MemoryPressureMonitor("512MB")
if monitor.should_cleanup():
    gc.collect()
```

## Advanced Usage

### Custom Feature Extractors

```python
class CustomFeatureExtractor(SpaceTimeFeatureExtractor):
    def extract_features(self, batch):
        # Your custom feature logic
        features = []
        for sample in batch:
            # Complex feature engineering
            features.append(self.compute_features(sample))
        return features
```

### Streaming Predictions

```python
def predict_streaming(model, data_path):
    predictions = SpaceTimeArray(threshold=10000)

    for chunk in pd.read_csv(data_path, chunksize=1000):
        X = chunk.values
        y_pred = model.predict(X)
        predictions.extend(y_pred)

    return predictions
```
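
Because `SpaceTimeArray` spills to disk once it grows past its `threshold`, the `predictions` collection stays bounded in memory even when the input file is far larger than RAM.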

### Cross-Validation with Memory Limits

```python
def memory_efficient_cv(X, y, model, cv=5):
    scores = []

    # External sort by label for stratified splitting
    sorted_pairs = external_sort(
        list(enumerate(y)),
        key_func=lambda x: x[1]
    )
    # Keep only the original indices, in label-sorted order
    sorted_indices = [idx for idx, _ in sorted_pairs]

    fold_size = len(y) // cv
    for i in range(cv):
        # Get fold indices
        test_start = i * fold_size
        test_end = (i + 1) * fold_size

        # Train/test split
        train_indices = sorted_indices[:test_start] + sorted_indices[test_end:]
        test_indices = sorted_indices[test_start:test_end]

        # Train and evaluate
        model.fit(X[train_indices], y[train_indices])
        score = model.score(X[test_indices], y[test_indices])
        scores.append(score)

    return scores
```

## Performance Tips

1. **Tune Chunk Sizes**: Larger chunks are more efficient but use more memory (see the sketch below)
2. **Use Compression**: Enable LZ4 compression for numerical data
3. **Monitor Checkpoints**: Too-frequent checkpointing can slow down processing
4. **Profile Memory**: Use the `@profile_memory` decorator to find bottlenecks
5. **External Storage**: Use SSDs for external algorithm temporary files
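
These knobs map onto `SpaceTimeConfig`. A minimal sketch (the values are illustrative starting points, not recommendations):

```python
from sqrtspace_spacetime import SpaceTimeConfig

SpaceTimeConfig.set_defaults(
    memory_limit=512 * 1024 * 1024,  # bytes; lower it for earlier spillover
    chunk_strategy='sqrt_n',         # √n-sized chunks balance passes vs. memory
    compression='lz4',               # fast compression for numerical data
)
```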

## Integration with Popular ML Libraries

### PyTorch DataLoader

```python
class SpaceTimeDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data = SpaceTimeArray.from_file(data_path)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

# Use with DataLoader
dataset = SpaceTimeDataset('large_dataset.pkl')
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
```

### TensorFlow tf.data

```python
def create_tf_dataset(file_path, batch_size=32):
    def generator():
        stream = Stream.from_csv(file_path)
        for item in stream:
            yield item['features'], item['label']

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=(tf.float32, tf.int32)
    )

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

## Benchmarks

On a machine with 8GB RAM processing a 50GB dataset:

| Operation | Traditional | SpaceTime | Memory Used |
|-----------|-------------|-----------|-------------|
| Data Loading | OOM | 42s | 512MB |
| Feature Extraction | OOM | 156s | 512MB |
| Model Training | OOM | 384s | 512MB |
| Evaluation | 89s | 95s | 512MB |

## Troubleshooting

### Out of Memory Errors
- Reduce chunk sizes
- Lower memory limit for earlier spillover
- Enable compression

### Slow Performance
- Increase memory limit if possible
- Use faster external storage (SSD)
- Optimize feature extraction logic

### Checkpoint Recovery
- Check checkpoint directory permissions
- Ensure enough disk space
- Monitor checkpoint file sizes

## Next Steps

- Explore distributed training with checkpoint coordination
- Implement custom external algorithms
- Build real-time ML pipelines with streaming
- Integrate with cloud storage for data loading
413
examples/ml-pipeline/ml_pipeline_example.py
Normal file
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""
Machine Learning Pipeline with SqrtSpace SpaceTime

Demonstrates memory-efficient ML workflows including:
- Large dataset processing
- Feature extraction with checkpointing
- Model training with memory constraints
- Batch prediction with streaming
- Cross-validation with external sorting
"""

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import joblib
import time
from typing import Iterator, Tuple, List, Dict, Any

from sqrtspace_spacetime import (
    SpaceTimeArray,
    SpaceTimeDict,
    Stream,
    external_sort,
    external_groupby,
    SpaceTimeConfig
)
from sqrtspace_spacetime.checkpoint import auto_checkpoint, CheckpointManager
from sqrtspace_spacetime.memory import MemoryPressureMonitor, profile_memory
from sqrtspace_spacetime.ml import SpaceTimeOptimizer
from sqrtspace_spacetime.profiler import profile


# Configure SpaceTime for ML workloads
SpaceTimeConfig.set_defaults(
    memory_limit=1024 * 1024 * 1024,  # 1GB
    chunk_strategy='sqrt_n',
    compression='lz4'  # Fast compression for numerical data
)


class SpaceTimeFeatureExtractor(BaseEstimator, TransformerMixin):
    """Memory-efficient feature extractor using SpaceTime"""

    def __init__(self, max_features: int = 1000):
        self.max_features = max_features
        self.feature_stats = SpaceTimeDict(threshold=100)
        self.checkpoint_manager = CheckpointManager()

    @auto_checkpoint(total_iterations=10000)
    def fit(self, X: Iterator[np.ndarray], y=None):
        """Fit extractor on streaming data"""

        print("Extracting features from training data...")

        # Accumulate statistics in SpaceTime collections
        feature_sums = SpaceTimeArray(threshold=self.max_features)
        feature_counts = SpaceTimeArray(threshold=self.max_features)

        for batch_idx, batch in enumerate(X):
            for row in batch:
                # Update running statistics
                if len(feature_sums) < len(row):
                    feature_sums.extend([0] * (len(row) - len(feature_sums)))
                    feature_counts.extend([0] * (len(row) - len(feature_counts)))

                for i, value in enumerate(row):
                    feature_sums[i] += value
                    feature_counts[i] += 1

            # Checkpoint every 100 batches
            if batch_idx % 100 == 0:
                yield {
                    'batch_idx': batch_idx,
                    'feature_sums': feature_sums,
                    'feature_counts': feature_counts
                }

        # Calculate means
        self.feature_means_ = []
        for i in range(len(feature_sums)):
            mean = feature_sums[i] / feature_counts[i] if feature_counts[i] > 0 else 0
            self.feature_means_.append(mean)
            self.feature_stats[f'mean_{i}'] = mean

        return self

    def transform(self, X: Iterator[np.ndarray]) -> Iterator[np.ndarray]:
        """Transform streaming data"""

        for batch in X:
            # Normalize using stored means
            transformed = np.array(batch)
            for i, mean in enumerate(self.feature_means_):
                transformed[:, i] -= mean

            yield transformed


class MemoryEfficientMLPipeline:
    """Complete ML pipeline with memory management"""

    def __init__(self, memory_limit: str = "512MB"):
        self.memory_monitor = MemoryPressureMonitor(memory_limit)
        self.checkpoint_manager = CheckpointManager()
        self.feature_extractor = SpaceTimeFeatureExtractor()
        self.model = RandomForestClassifier(n_estimators=100, n_jobs=-1)
        self.optimizer = SpaceTimeOptimizer(
            memory_limit=memory_limit,
            checkpoint_frequency=100
        )

    @profile_memory(threshold_mb=256)
    def load_data_streaming(self, file_path: str, chunk_size: int = 1000) -> Iterator:
        """Load large dataset in memory-efficient chunks"""

        print(f"Loading data from {file_path} in chunks of {chunk_size}...")

        # Simulate loading large CSV in chunks
        for chunk_idx, chunk in enumerate(pd.read_csv(file_path, chunksize=chunk_size)):
            # Convert to numpy array
            X = chunk.drop('target', axis=1).values
            y = chunk['target'].values

            # Check memory pressure
            if self.memory_monitor.should_cleanup():
                print(f"Memory pressure detected at chunk {chunk_idx}, triggering cleanup")
                import gc
                gc.collect()

            yield X, y

    def preprocess_with_external_sort(self, data_iterator: Iterator) -> Tuple[SpaceTimeArray, SpaceTimeArray]:
        """Preprocess and sort data using external algorithms"""

        print("Preprocessing data with external sorting...")

        X_all = SpaceTimeArray(threshold=10000)
        y_all = SpaceTimeArray(threshold=10000)

        # Collect all data
        for X_batch, y_batch in data_iterator:
            X_all.extend(X_batch.tolist())
            y_all.extend(y_batch.tolist())

        # Sort by target value for stratified splitting
        print(f"Sorting {len(y_all)} samples by target value...")

        # Create index pairs
        indexed_data = [(i, y) for i, y in enumerate(y_all)]

        # External sort by target value
        sorted_indices = external_sort(
            indexed_data,
            key_func=lambda x: x[1]
        )

        # Reorder data
        X_sorted = SpaceTimeArray(threshold=10000)
        y_sorted = SpaceTimeArray(threshold=10000)

        for idx, _ in sorted_indices:
            X_sorted.append(X_all[idx])
            y_sorted.append(y_all[idx])

        return X_sorted, y_sorted

    def extract_features_checkpointed(self, X: SpaceTimeArray) -> SpaceTimeArray:
        """Extract features with checkpointing"""

        print("Extracting features with checkpointing...")

        job_id = f"feature_extraction_{int(time.time())}"

        # Check for existing checkpoint
        checkpoint = self.checkpoint_manager.restore(job_id)
        start_idx = checkpoint.get('last_idx', 0) if checkpoint else 0

        features = SpaceTimeArray(threshold=10000)

        # Load partial results if resuming
        if checkpoint and 'features' in checkpoint:
            features = checkpoint['features']

        # Process in batches
        batch_size = 100
        for i in range(start_idx, len(X), batch_size):
            batch = X[i:i + batch_size]

            # Simulate feature extraction
            batch_features = []
            for sample in batch:
                # Example: statistical features
                features_dict = {
                    'mean': np.mean(sample),
                    'std': np.std(sample),
                    'min': np.min(sample),
                    'max': np.max(sample),
                    'median': np.median(sample)
                }
                batch_features.append(list(features_dict.values()))

            features.extend(batch_features)

            # Checkpoint every 1000 samples
            if (i + batch_size) % 1000 == 0:
                self.checkpoint_manager.save(job_id, {
                    'last_idx': i + batch_size,
                    'features': features
                })
                print(f"Checkpoint saved at index {i + batch_size}")

        # Clean up checkpoint
        self.checkpoint_manager.delete(job_id)

        return features

    @profile
    def train_with_memory_constraints(self, X: SpaceTimeArray, y: SpaceTimeArray):
        """Train model with memory-aware batch processing"""

        print("Training model with memory constraints...")

        # Convert to numpy arrays in batches
        batch_size = min(1000, len(X))

        for epoch in range(3):  # Multiple epochs
            print(f"\nEpoch {epoch + 1}/3")

            # Shuffle data
            indices = list(range(len(X)))
            np.random.shuffle(indices)

            # Train in mini-batches
            for i in range(0, len(X), batch_size):
                batch_indices = indices[i:i + batch_size]

                X_batch = np.array([X[idx] for idx in batch_indices])
                y_batch = np.array([y[idx] for idx in batch_indices])

                # Partial fit (for models that support it)
                if hasattr(self.model, 'partial_fit'):
                    self.model.partial_fit(X_batch, y_batch)
                else:
                    # For RandomForest, we'll fit on full data once
                    if epoch == 0 and i == 0:
                        # Collect all data for initial fit
                        X_train = np.array(X.to_list())
                        y_train = np.array(y.to_list())
                        self.model.fit(X_train, y_train)
                    break

                # Check memory
                if self.memory_monitor.should_cleanup():
                    import gc
                    gc.collect()
                    print(f"Memory cleanup at batch {i // batch_size}")

    def evaluate_with_external_grouping(self, X: SpaceTimeArray, y: SpaceTimeArray) -> Dict[str, float]:
        """Evaluate model using external grouping for metrics"""

        print("Evaluating model performance...")

        # Make predictions in batches
        predictions = SpaceTimeArray(threshold=10000)

        batch_size = 1000
        for i in range(0, len(X), batch_size):
            X_batch = np.array(X[i:i + batch_size])
            y_pred = self.model.predict(X_batch)
            predictions.extend(y_pred.tolist())

        # Group by actual vs predicted for confusion matrix
        results = []
        for i in range(len(y)):
            results.append({
                'actual': y[i],
                'predicted': predictions[i],
                'correct': y[i] == predictions[i]
            })

        # Use external groupby for metrics
        accuracy_groups = external_groupby(
            results,
            key_func=lambda x: x['correct']
        )

        correct_count = len(accuracy_groups.get(True, []))
        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0

        # Class-wise metrics
        class_groups = external_groupby(
            results,
            key_func=lambda x: (x['actual'], x['predicted'])
        )

        return {
            'accuracy': accuracy,
            'total_samples': total_count,
            'correct_predictions': correct_count,
            'class_distribution': {str(k): len(v) for k, v in class_groups.items()}
        }

    def save_model_checkpoint(self, path: str):
        """Save model with metadata"""

        checkpoint = {
            'model': self.model,
            'feature_extractor': self.feature_extractor,
            'metadata': {
                'timestamp': time.time(),
                'memory_limit': self.memory_monitor.memory_limit,
                'feature_stats': dict(self.feature_extractor.feature_stats)
            }
        }

        joblib.dump(checkpoint, path)
        print(f"Model saved to {path}")


def generate_synthetic_data(n_samples: int = 100000, n_features: int = 50):
    """Generate synthetic dataset for demonstration"""

    print(f"Generating synthetic dataset: {n_samples} samples, {n_features} features...")

    # Generate in chunks to avoid memory issues
    chunk_size = 10000

    with open('synthetic_data.csv', 'w') as f:
        # Write header
        headers = [f'feature_{i}' for i in range(n_features)] + ['target']
        f.write(','.join(headers) + '\n')

        # Generate data in chunks
        for i in range(0, n_samples, chunk_size):
            chunk_samples = min(chunk_size, n_samples - i)

            # Generate features
            X = np.random.randn(chunk_samples, n_features)

            # Generate target (binary classification)
            # Target depends on sum of first 10 features
            y = (X[:, :10].sum(axis=1) > 0).astype(int)

            # Write to CSV
            for j in range(chunk_samples):
                row = list(X[j]) + [y[j]]
                f.write(','.join(map(str, row)) + '\n')

            if (i + chunk_size) % 50000 == 0:
                print(f"Generated {i + chunk_size} samples...")

    print("Synthetic data generation complete!")


def main():
    """Run complete ML pipeline example"""

    print("=== SqrtSpace SpaceTime ML Pipeline Example ===\n")

    # Generate synthetic data
    generate_synthetic_data(n_samples=100000, n_features=50)

    # Create pipeline
    pipeline = MemoryEfficientMLPipeline(memory_limit="512MB")

    # Load and preprocess data
    print("\n1. Loading data with streaming...")
    data_iterator = pipeline.load_data_streaming('synthetic_data.csv', chunk_size=5000)

    print("\n2. Preprocessing with external sort...")
    X_sorted, y_sorted = pipeline.preprocess_with_external_sort(data_iterator)
    print(f"Loaded {len(X_sorted)} samples")

    print("\n3. Extracting features with checkpointing...")
    X_features = pipeline.extract_features_checkpointed(X_sorted)

    print("\n4. Training model with memory constraints...")
    # Split data (80/20)
    split_idx = int(0.8 * len(X_features))
    X_train = SpaceTimeArray(X_features[:split_idx])
    y_train = SpaceTimeArray(y_sorted[:split_idx])
    X_test = SpaceTimeArray(X_features[split_idx:])
    y_test = SpaceTimeArray(y_sorted[split_idx:])

    pipeline.train_with_memory_constraints(X_train, y_train)

    print("\n5. Evaluating with external grouping...")
    metrics = pipeline.evaluate_with_external_grouping(X_test, y_test)

    print("\n=== Results ===")
    print(f"Test Accuracy: {metrics['accuracy']:.4f}")
    print(f"Total Test Samples: {metrics['total_samples']}")
    print(f"Correct Predictions: {metrics['correct_predictions']}")

    print("\n6. Saving model checkpoint...")
    pipeline.save_model_checkpoint('spacetime_model.joblib')

    # Memory statistics
    print("\n=== Memory Statistics ===")
    memory_info = pipeline.memory_monitor.get_memory_info()
    print(f"Peak Memory Usage: {memory_info['peak_mb']:.2f} MB")
    print(f"Current Memory Usage: {memory_info['used_mb']:.2f} MB")
    print(f"Memory Limit: {memory_info['limit_mb']:.2f} MB")

    print("\n=== Pipeline Complete! ===")


if __name__ == "__main__":
    main()
95
pyproject.toml
Normal file
@@ -0,0 +1,95 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "sqrtspace-spacetime"
version = "0.1.0"
authors = [
    {name = "David H. Friedel Jr.", email = "dfriedel@marketally.com"},
    {name = "SqrtSpace Contributors"}
]
description = "Memory-efficient algorithms and data structures using Williams' √n space-time tradeoffs"
readme = "README.md"
license = {text = "Apache-2.0"}
requires-python = ">=3.8"
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: System :: Archiving :: Compression",
    "Topic :: Database",
    "Operating System :: OS Independent",
]
keywords = ["memory", "efficiency", "algorithms", "spacetime", "external-memory", "streaming"]
dependencies = [
    "numpy>=1.20.0",
    "psutil>=5.8.0",
    "aiofiles>=0.8.0",
    "tqdm>=4.62.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-asyncio>=0.20.0",
    "pytest-cov>=4.0.0",
    "black>=22.0.0",
    "flake8>=5.0.0",
    "mypy>=0.990",
    "sphinx>=5.0.0",
    "sphinx-rtd-theme>=1.0.0",
]
pandas = ["pandas>=1.3.0"]
dask = ["dask[complete]>=2022.1.0"]
ray = ["ray>=2.0.0"]
all = ["sqrtspace-spacetime[pandas,dask,ray]"]
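# Extras above can be installed individually or together, e.g. (illustrative):
#   pip install "sqrtspace-spacetime[pandas]"
#   pip install "sqrtspace-spacetime[all]"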

[project.urls]
Homepage = "https://github.com/sqrtspace/sqrtspace-python"
Documentation = "https://sqrtspace-spacetime.readthedocs.io"
Repository = "https://github.com/sqrtspace/sqrtspace-python.git"
Issues = "https://github.com/sqrtspace/sqrtspace-python/issues"

[project.scripts]
spacetime = "sqrtspace_spacetime.cli:main"

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
sqrtspace_spacetime = ["py.typed"]

[tool.black]
line-length = 88
target-version = ['py38']
include = '\.pyi?$'

[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_functions = ["test_*"]
python_classes = ["Test*"]
addopts = "-v --cov=sqrtspace_spacetime --cov-report=html --cov-report=term"

[tool.coverage.run]
source = ["src/sqrtspace_spacetime"]
omit = ["*/tests/*", "*/__init__.py"]

[tool.coverage.report]
precision = 2
show_missing = true
skip_covered = false
31
requirements-dev.txt
Normal file
@@ -0,0 +1,31 @@
# Development dependencies
-r requirements.txt

# Testing
pytest>=7.0.0
pytest-asyncio>=0.20.0
pytest-cov>=4.0.0
pytest-xdist>=3.0.0

# Code quality
black>=22.0.0
flake8>=5.0.0
mypy>=0.990
isort>=5.10.0

# Documentation
sphinx>=5.0.0
sphinx-rtd-theme>=1.0.0
sphinx-autodoc-typehints>=1.19.0

# ML frameworks (optional)
torch>=1.10.0
tensorflow>=2.8.0

# Visualization (optional)
matplotlib>=3.5.0
seaborn>=0.11.0

# Build tools
build>=0.8.0
twine>=4.0.0
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
# Core dependencies
numpy>=1.20.0
psutil>=5.8.0
aiofiles>=0.8.0
tqdm>=4.62.0
18
setup.py
Normal file
@@ -0,0 +1,18 @@
"""
Setup script for SqrtSpace SpaceTime.

This is a compatibility shim for older pip versions.
The actual package configuration is in pyproject.toml.
"""

from setuptools import setup

# Read the contents of the README file
from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text(encoding='utf-8')

setup(
    long_description=long_description,
    long_description_content_type='text/markdown',
)
31
src/sqrtspace_spacetime/__init__.py
Normal file
@@ -0,0 +1,31 @@
"""
SqrtSpace SpaceTime: Memory-efficient algorithms using √n space-time tradeoffs.

This package implements Williams' theoretical computer science results showing
that many algorithms can achieve better memory usage by accepting slightly
slower runtime.
"""

from sqrtspace_spacetime.config import SpaceTimeConfig
from sqrtspace_spacetime.collections import SpaceTimeArray, SpaceTimeDict
from sqrtspace_spacetime.algorithms import external_sort, external_groupby
from sqrtspace_spacetime.streams import Stream
from sqrtspace_spacetime.memory import MemoryMonitor, MemoryPressureLevel

__version__ = "0.1.0"
__author__ = "SqrtSpace Contributors"
__license__ = "Apache-2.0"

__all__ = [
    "SpaceTimeConfig",
    "SpaceTimeArray",
    "SpaceTimeDict",
    "external_sort",
    "external_groupby",
    "Stream",
    "MemoryMonitor",
    "MemoryPressureLevel",
]

# Configure default settings
SpaceTimeConfig.set_defaults()
9
src/sqrtspace_spacetime/algorithms/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""External memory algorithms using √n space-time tradeoffs."""

from sqrtspace_spacetime.algorithms.external_sort import external_sort
from sqrtspace_spacetime.algorithms.external_groupby import external_groupby

__all__ = [
    "external_sort",
    "external_groupby",
]
265
src/sqrtspace_spacetime/algorithms/external_groupby.py
Normal file
@@ -0,0 +1,265 @@
"""
External group-by algorithm using √n memory.
"""

import os
import pickle
import tempfile
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
from collections import defaultdict

from sqrtspace_spacetime.config import config
from sqrtspace_spacetime.collections import SpaceTimeDict

T = TypeVar('T')
K = TypeVar('K')
V = TypeVar('V')


class GroupByStrategy(Enum):
    """Group-by strategies."""
    HASH_BASED = "hash_based"
    SORT_BASED = "sort_based"
    ADAPTIVE = "adaptive"


def external_groupby(
    data: Iterable[T],
    key_func: Callable[[T], K],
    strategy: GroupByStrategy = GroupByStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> Dict[K, List[T]]:
    """
    Group data by key using external memory.

    Args:
        data: Iterable of items to group
        key_func: Function to extract group key
        strategy: Grouping strategy
        storage_path: Path for temporary storage

    Returns:
        Dictionary mapping keys to lists of items
    """
    storage_path = storage_path or config.external_storage_path

    # Convert to list to get size
    if not isinstance(data, list):
        data = list(data)

    n = len(data)

    # Small datasets can be grouped in memory
    if n <= 10000:
        result = defaultdict(list)
        for item in data:
            result[key_func(item)].append(item)
        return dict(result)

    # Choose strategy
    if strategy == GroupByStrategy.ADAPTIVE:
        strategy = _choose_groupby_strategy(data, key_func)

    if strategy == GroupByStrategy.HASH_BASED:
        return _hash_based_groupby(data, key_func, storage_path)
    else:
        return _sort_based_groupby(data, key_func, storage_path)
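
def _example_usage():  # pragma: no cover
    """A minimal sketch of the API above; `orders` and its fields are
    illustrative, not part of the library."""
    orders = [
        {"region": "EU", "total": 20.0},
        {"region": "US", "total": 5.0},
        {"region": "EU", "total": 7.5},
    ]
    # Small inputs are grouped in memory; larger ones spill to disk.
    by_region = external_groupby(orders, key_func=lambda o: o["region"])
    assert set(by_region) == {"EU", "US"}
    return by_region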


def external_groupby_aggregate(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V],
    agg_func: Callable[[V, V], V],
    initial: Optional[V] = None,
    storage_path: Optional[str] = None
) -> Dict[K, V]:
    """
    Group by with aggregation using external memory.

    Args:
        data: Iterable of items
        key_func: Function to extract group key
        value_func: Function to extract value for aggregation
        agg_func: Aggregation function (e.g., sum, max)
        initial: Initial value for aggregation
        storage_path: Path for temporary storage

    Returns:
        Dictionary mapping keys to aggregated values
    """
    # Use SpaceTimeDict for memory-efficient aggregation
    result = SpaceTimeDict(storage_path=storage_path)

    for item in data:
        key = key_func(item)
        value = value_func(item)

        if key in result:
            result[key] = agg_func(result[key], value)
        else:
            result[key] = value if initial is None else agg_func(initial, value)

    # Convert to a regular dict via a list snapshot to avoid mutation issues
    return {k: v for k, v in list(result.items())}


def _choose_groupby_strategy(data: List[T], key_func: Callable[[T], K]) -> GroupByStrategy:
    """Choose grouping strategy based on data characteristics."""
    # Sample keys to estimate cardinality
    sample_size = min(1000, len(data))
    sample_keys = set()

    for i in range(0, len(data), max(1, len(data) // sample_size)):
        sample_keys.add(key_func(data[i]))

    estimated_groups = len(sample_keys) * (len(data) / sample_size)

    # If there are few groups relative to data size, use hash-based grouping
    if estimated_groups < len(data) / 10:
        return GroupByStrategy.HASH_BASED
    else:
        return GroupByStrategy.SORT_BASED


def _hash_based_groupby(
    data: List[T],
    key_func: Callable[[T], K],
    storage_path: str
) -> Dict[K, List[T]]:
    """
    Hash-based grouping with spillover to disk.
    """
    chunk_size = config.calculate_chunk_size(len(data))

    # Use SpaceTimeDict for groups
    groups = SpaceTimeDict(threshold=chunk_size // 10, storage_path=storage_path)

    for item in data:
        key = key_func(item)

        if key in groups:
            group = groups[key]
            group.append(item)
            groups[key] = group
        else:
            groups[key] = [item]

    # Convert to regular dict
    return dict(groups.items())


def _sort_based_groupby(
    data: List[T],
    key_func: Callable[[T], K],
    storage_path: str
) -> Dict[K, List[T]]:
    """
    Sort-based grouping.
    """
    from sqrtspace_spacetime.algorithms.external_sort import external_sort_key

    # Sort by group key
    sorted_data = external_sort_key(data, key=key_func, storage_path=storage_path)

    # Group consecutive items
    result = {}
    current_key = None
    current_group = []

    for item in sorted_data:
        item_key = key_func(item)

        if item_key != current_key:
            if current_key is not None:
                result[current_key] = current_group
            current_key = item_key
            current_group = [item]
        else:
            current_group.append(item)

    # Don't forget the last group
    if current_key is not None:
        result[current_key] = current_group

    return result


# Convenience functions for common aggregations

def groupby_count(
    data: Iterable[T],
    key_func: Callable[[T], K]
) -> Dict[K, int]:
    """Count items by group."""
    return external_groupby_aggregate(
        data,
        key_func,
        lambda x: 1,
        lambda a, b: a + b,
        initial=0
    )


def groupby_sum(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], float]
) -> Dict[K, float]:
    """Sum values by group."""
    return external_groupby_aggregate(
        data,
        key_func,
        value_func,
        lambda a, b: a + b,
        initial=0.0
    )


def groupby_avg(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], float]
) -> Dict[K, float]:
    """Average values by group (in-memory; assumes the number of distinct
    groups fits comfortably in RAM)."""
    # First get sums and counts
    sums = defaultdict(float)
    counts = defaultdict(int)

    for item in data:
        key = key_func(item)
        value = value_func(item)
        sums[key] += value
        counts[key] += 1

    # Calculate averages
    return {key: sums[key] / counts[key] for key in sums}


def groupby_max(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V]
) -> Dict[K, V]:
    """Get maximum value by group."""
    return external_groupby_aggregate(
        data,
        key_func,
        value_func,
        max
    )


def groupby_min(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V]
) -> Dict[K, V]:
    """Get minimum value by group."""
    return external_groupby_aggregate(
        data,
        key_func,
        value_func,
        min
    )
330
src/sqrtspace_spacetime/algorithms/external_sort.py
Normal file
@@ -0,0 +1,330 @@
"""
External sorting algorithm using √n memory.
"""

import os
import heapq
import pickle
import tempfile
from enum import Enum
from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union
from dataclasses import dataclass

from sqrtspace_spacetime.config import config
from sqrtspace_spacetime.memory import monitor

T = TypeVar('T')


class SortStrategy(Enum):
    """Sorting strategies."""
    MULTIWAY_MERGE = "multiway_merge"
    QUICKSORT_EXTERNAL = "quicksort_external"
    ADAPTIVE = "adaptive"


@dataclass
class SortRun:
    """A sorted run on disk."""
    filename: str
    count: int
    min_value: Any
    max_value: Any


def external_sort(
    data: Iterable[T],
    key_func: Optional[Callable[[T], Any]] = None,
    reverse: bool = False,
    strategy: SortStrategy = SortStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> List[T]:
    """
    Sort data using external memory with √n space complexity.

    Args:
        data: Iterable of items to sort
        key_func: Optional function to extract the sort key (identity if None)
        reverse: Sort in descending order
        strategy: Sorting strategy to use
        storage_path: Path for temporary files

    Returns:
        Sorted list
    """
    return external_sort_key(
        data,
        key=key_func if key_func is not None else (lambda x: x),
        reverse=reverse,
        strategy=strategy,
        storage_path=storage_path
    )


def external_sort_key(
    data: Iterable[T],
    key: Callable[[T], Any],
    reverse: bool = False,
    strategy: SortStrategy = SortStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> List[T]:
    """
    Sort data by key using external memory.

    Args:
        data: Iterable of items to sort
        key: Function to extract sort key
        reverse: Sort in descending order
        strategy: Sorting strategy to use
        storage_path: Path for temporary files

    Returns:
        Sorted list
    """
    storage_path = storage_path or config.external_storage_path

    # Convert to list if needed to get size
    if not isinstance(data, list):
        data = list(data)

    n = len(data)

    # Small datasets can be sorted in memory
    if n <= 10000:
        return sorted(data, key=key, reverse=reverse)

    # Choose strategy
    if strategy == SortStrategy.ADAPTIVE:
        strategy = _choose_strategy(n)

    if strategy == SortStrategy.MULTIWAY_MERGE:
        return _multiway_merge_sort(data, key, reverse, storage_path)
    else:
        return _external_quicksort(data, key, reverse, storage_path)
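
def _example_usage():  # pragma: no cover
    """A minimal sketch of the API above; the records are illustrative."""
    records = [{"price": 9.5}, {"price": 2.0}, {"price": 5.25}]
    # Small inputs fall back to sorted(); large inputs spill √n-sized
    # sorted runs to disk and merge them.
    cheapest_first = external_sort(records, key_func=lambda r: r["price"])
    assert cheapest_first[0]["price"] == 2.0
    return cheapest_first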


def _choose_strategy(n: int) -> SortStrategy:
    """Choose the best strategy based on data size."""
    # For very large datasets, multiway merge is more stable
    if n > 1_000_000:
        return SortStrategy.MULTIWAY_MERGE
    else:
        return SortStrategy.QUICKSORT_EXTERNAL


def _multiway_merge_sort(
    data: List[T],
    key: Callable[[T], Any],
    reverse: bool,
    storage_path: str
) -> List[T]:
    """
    Multiway merge sort implementation.
    """
    n = len(data)
    chunk_size = config.calculate_chunk_size(n)

    # Phase 1: Create sorted runs
    runs = []
    temp_files = []

    for i in range(0, n, chunk_size):
        chunk = data[i:i + chunk_size]

        # Sort chunk in memory
        chunk.sort(key=key, reverse=reverse)

        # Write to disk
        fd, filename = tempfile.mkstemp(suffix='.run', dir=storage_path)
        os.close(fd)
        temp_files.append(filename)

        with open(filename, 'wb') as f:
            pickle.dump(chunk, f)

        # Track run info
        runs.append(SortRun(
            filename=filename,
            count=len(chunk),
            min_value=key(chunk[0]),
            max_value=key(chunk[-1])
        ))

    # Phase 2: Merge runs
    try:
        result = _merge_runs(runs, key, reverse)
        return result
    finally:
        # Cleanup
        for filename in temp_files:
            if os.path.exists(filename):
                os.unlink(filename)
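
# Space accounting for the merge sort above: with chunk_size ≈ √n, phase 1
# holds only one √n-sized chunk in memory at a time and writes ≈ √n sorted
# runs to disk. Note that the simplified _merge_runs below loads each run
# eagerly, so a fully streaming run reader would be needed to keep the merge
# phase at O(√n) memory as well.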


def _merge_runs(
    runs: List[SortRun],
    key: Callable[[T], Any],
    reverse: bool
) -> List[T]:
    """
    Merge sorted runs using a k-way merge.
    """
    # Open all run files
    run_iters = []
    for run in runs:
        with open(run.filename, 'rb') as f:
            items = pickle.load(f)
        run_iters.append(iter(items))

    # Create heap for merge
    heap = []

    # Initialize heap with the first item from each run
    for i, run_iter in enumerate(run_iters):
        try:
            item = next(run_iter)
            # For reverse sort, negate the key
            heap_key = key(item)
            if reverse:
                heap_key = _negate_key(heap_key)
            heapq.heappush(heap, (heap_key, i, item, run_iter))
        except StopIteration:
            pass

    # Merge
    result = []
    while heap:
        heap_key, run_idx, item, run_iter = heapq.heappop(heap)
        result.append(item)

        # Get the next item from the same run
        try:
            next_item = next(run_iter)
            next_key = key(next_item)
            if reverse:
                next_key = _negate_key(next_key)
            heapq.heappush(heap, (next_key, run_idx, next_item, run_iter))
        except StopIteration:
            pass

    return result


def _negate_key(key: Any) -> Any:
    """Negate a key for reverse sorting."""
    if isinstance(key, (int, float)):
        return -key
    elif isinstance(key, str):
        # For strings, return a wrapper that reverses comparison
        return _ReverseString(key)
    else:
        # For other types, use a generic wrapper
        return _ReverseWrapper(key)


class _ReverseString:
    """Wrapper for reverse string comparison."""
    def __init__(self, s: str):
        self.s = s

    def __lt__(self, other):
        return self.s > other.s

    def __le__(self, other):
        return self.s >= other.s

    def __gt__(self, other):
        return self.s < other.s

    def __ge__(self, other):
        return self.s <= other.s

    def __eq__(self, other):
        return self.s == other.s


class _ReverseWrapper:
    """Generic wrapper for reverse comparison."""
    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        return self.obj > other.obj

    def __le__(self, other):
        return self.obj >= other.obj

    def __gt__(self, other):
        return self.obj < other.obj

    def __ge__(self, other):
        return self.obj <= other.obj

    def __eq__(self, other):
        return self.obj == other.obj


def _external_quicksort(
    data: List[T],
    key: Callable[[T], Any],
    reverse: bool,
    storage_path: str
) -> List[T]:
    """
    External quicksort implementation.

    This is a simplified version that partitions data and
    recursively sorts partitions that fit in memory.
    """
    n = len(data)
    chunk_size = config.calculate_chunk_size(n)

    if n <= chunk_size:
        # Base case: sort in memory
        return sorted(data, key=key, reverse=reverse)

    # Choose pivot (median of three)
    pivot_idx = _choose_pivot(data, key)
    pivot_key = key(data[pivot_idx])

    # Partition data
    less = []
    equal = []
    greater = []

    for item in data:
        item_key = key(item)
        if item_key < pivot_key:
            less.append(item)
        elif item_key == pivot_key:
            equal.append(item)
        else:
            greater.append(item)

    # Recursively sort partitions
    sorted_less = _external_quicksort(less, key, reverse, storage_path)
    sorted_greater = _external_quicksort(greater, key, reverse, storage_path)

    # Combine results
    if reverse:
        return sorted_greater + equal + sorted_less
    else:
        return sorted_less + equal + sorted_greater


def _choose_pivot(data: List[T], key: Callable[[T], Any]) -> int:
    """Choose a good pivot using median-of-three."""
    n = len(data)

    # Sample three elements
    first = 0
    middle = n // 2
    last = n - 1

    # Find the median
    a, b, c = key(data[first]), key(data[middle]), key(data[last])

    if a <= b <= c or c <= b <= a:
        return middle
    elif b <= a <= c or c <= a <= b:
        return first
    else:
        return last
7
src/sqrtspace_spacetime/checkpoint/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""Auto-checkpoint framework for long-running computations."""

from sqrtspace_spacetime.checkpoint.decorators import auto_checkpoint
from sqrtspace_spacetime.checkpoint.manager import CheckpointManager

__all__ = [
    "auto_checkpoint",
    "CheckpointManager",
]
295
src/sqrtspace_spacetime/checkpoint/decorators.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Decorators for automatic checkpointing.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
from sqrtspace_spacetime.checkpoint.manager import (
|
||||
CheckpointManager,
|
||||
CheckpointConfig,
|
||||
CheckpointStrategy
|
||||
)
|
||||
|
||||
|
||||
def auto_checkpoint(
|
||||
total_iterations: Optional[int] = None,
|
||||
strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE,
|
||||
checkpoint_vars: Optional[List[str]] = None,
|
||||
checkpoint_dir: str = ".checkpoints",
|
||||
verbose: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to automatically checkpoint long-running functions.
|
||||
|
||||
Args:
|
||||
total_iterations: Total iterations (for √n strategy)
|
||||
strategy: Checkpointing strategy
|
||||
checkpoint_vars: Variables to checkpoint (None for auto-detect)
|
||||
checkpoint_dir: Directory for checkpoints
|
||||
verbose: Print checkpoint info
|
||||
|
||||
Example:
|
||||
@auto_checkpoint(total_iterations=1000000)
|
||||
def process_data(data):
|
||||
for i, item in enumerate(data):
|
||||
# Process item
|
||||
checkpoint_state = {'i': i, 'processed': processed}
|
||||
yield checkpoint_state
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Create checkpoint manager
|
||||
config = CheckpointConfig(
|
||||
strategy=strategy,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
verbose=verbose
|
||||
)
|
||||
manager = CheckpointManager(config=config)
|
||||
|
||||
if total_iterations:
|
||||
manager.set_total_iterations(total_iterations)
|
||||
|
||||
# Check if resuming from checkpoint
|
||||
resume_checkpoint = kwargs.pop('resume_checkpoint', None)
|
||||
if resume_checkpoint:
|
||||
state, metadata = manager.load(resume_checkpoint)
|
||||
print(f"Resuming from checkpoint at iteration {metadata.iteration}")
|
||||
# Update function state
|
||||
if 'update_state' in kwargs:
|
||||
kwargs['update_state'](state)
|
||||
|
||||
# Wrap generator functions
|
||||
if inspect.isgeneratorfunction(func):
|
||||
return _checkpoint_generator(func, manager, checkpoint_vars,
|
||||
*args, **kwargs)
|
||||
else:
|
||||
# For regular functions, checkpoint based on time/memory
|
||||
result = None
|
||||
for i in range(total_iterations or 1):
|
||||
if manager.should_checkpoint(i):
|
||||
# Get state from function
|
||||
if hasattr(func, 'get_checkpoint_state'):
|
||||
state = func.get_checkpoint_state()
|
||||
else:
|
||||
state = {'iteration': i, 'args': args, 'kwargs': kwargs}
|
||||
|
||||
manager.save(state)
|
||||
|
||||
# Execute function
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Break if function doesn't need iterations
|
||||
if total_iterations is None:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
# Store checkpoint info on function
|
||||
wrapper.checkpoint_manager = None
|
||||
wrapper.checkpoint_config = CheckpointConfig(
|
||||
strategy=strategy,
|
||||
checkpoint_dir=checkpoint_dir
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def checkpoint_method(
|
||||
checkpoint_attrs: Optional[List[str]] = None,
|
||||
strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator for checkpointing class methods.
|
||||
|
||||
Args:
|
||||
checkpoint_attrs: Instance attributes to checkpoint
|
||||
strategy: Checkpointing strategy
|
||||
|
||||
Example:
|
||||
class DataProcessor:
|
||||
@checkpoint_method(checkpoint_attrs=['processed_count', 'results'])
|
||||
def process_batch(self, batch):
|
||||
for item in batch:
|
||||
self.process_item(item)
|
||||
self.processed_count += 1
|
||||
"""
|
||||
def decorator(method: Callable) -> Callable:
|
||||
@functools.wraps(method)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# Get or create checkpoint manager
|
||||
if not hasattr(self, '_checkpoint_manager'):
|
||||
config = CheckpointConfig(strategy=strategy)
|
||||
self._checkpoint_manager = CheckpointManager(config=config)
|
||||
|
||||
# Execute method with checkpointing
|
||||
if inspect.isgeneratorfunction(method):
|
||||
return _checkpoint_method_generator(
|
||||
method, self, self._checkpoint_manager,
|
||||
checkpoint_attrs, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
# Regular method
|
||||
result = method(self, *args, **kwargs)
|
||||
|
||||
# Check if checkpoint needed
|
||||
if self._checkpoint_manager.should_checkpoint():
|
||||
state = _get_instance_state(self, checkpoint_attrs)
|
||||
self._checkpoint_manager.save(state)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def resumable(
|
||||
checkpoint_dir: str = ".checkpoints",
|
||||
auto_resume: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Make function resumable from checkpoints.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Directory for checkpoints
|
||||
auto_resume: Automatically resume from latest checkpoint
|
||||
|
||||
Example:
|
||||
@resumable()
|
||||
def long_computation():
|
||||
for i in range(1000000):
|
||||
# Computation
|
||||
if should_checkpoint(i):
|
||||
save_checkpoint({'i': i, 'state': state})
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Create checkpoint manager
|
||||
manager = CheckpointManager(
|
||||
checkpoint_id=f"{func.__module__}.{func.__name__}",
|
||||
config=CheckpointConfig(checkpoint_dir=checkpoint_dir)
|
||||
)
|
||||
|
||||
# Check for existing checkpoints
|
||||
checkpoints = manager.list_checkpoints()
|
||||
|
||||
if checkpoints and auto_resume:
|
||||
latest = checkpoints[-1]
|
||||
print(f"Found checkpoint at iteration {latest.iteration}")
|
||||
|
||||
# Resume from checkpoint
|
||||
state, metadata = manager.load()
|
||||
|
||||
# Call function with resume state
|
||||
return func(*args, resume_state=state, resume_iteration=metadata.iteration, **kwargs)
|
||||
else:
|
||||
# Normal execution
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Add checkpoint methods to function
|
||||
wrapper.save_checkpoint = lambda state: manager.save(state)
|
||||
wrapper.list_checkpoints = lambda: manager.list_checkpoints()
|
||||
wrapper.cleanup_checkpoints = lambda: manager.cleanup()
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _checkpoint_generator(func: Callable, manager: CheckpointManager,
|
||||
checkpoint_vars: Optional[List[str]],
|
||||
*args, **kwargs):
|
||||
"""Handle checkpointing for generator functions."""
|
||||
generator = func(*args, **kwargs)
|
||||
iteration = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get next value
|
||||
if iteration == 0 and 'resume_state' in kwargs:
|
||||
# Skip to resume point
|
||||
resume_iter = kwargs['resume_state'].get('iteration', 0)
|
||||
for _ in range(resume_iter):
|
||||
next(generator)
|
||||
iteration = resume_iter
|
||||
|
||||
value = next(generator)
|
||||
|
||||
# Check if checkpoint needed
|
||||
if manager.should_checkpoint(iteration):
|
||||
# Get state
|
||||
if isinstance(value, dict):
|
||||
state = value
|
||||
else:
|
||||
state = {'iteration': iteration, 'value': value}
|
||||
|
||||
# Add checkpoint vars if specified
|
||||
if checkpoint_vars:
|
||||
frame = inspect.currentframe().f_back
|
||||
for var in checkpoint_vars:
|
||||
if var in frame.f_locals:
|
||||
state[var] = frame.f_locals[var]
|
||||
|
||||
manager.save(state)
|
||||
|
||||
yield value
|
||||
iteration += 1
|
||||
|
||||
except StopIteration:
|
||||
pass
|
||||
finally:
|
||||
if manager.config.verbose:
|
||||
stats = manager.get_stats()
|
||||
print(f"\nCheckpoint stats: {stats.total_checkpoints} checkpoints, "
|
||||
f"{stats.average_compression:.1f}x compression")
|
||||
|
||||
|
||||
def _checkpoint_method_generator(method: Callable, instance: Any,
                                 manager: CheckpointManager,
                                 checkpoint_attrs: Optional[List[str]],
                                 *args, **kwargs):
    """Handle checkpointing for generator methods."""
    generator = method(instance, *args, **kwargs)
    iteration = 0

    try:
        while True:
            value = next(generator)

            if manager.should_checkpoint(iteration):
                state = _get_instance_state(instance, checkpoint_attrs)
                state['iteration'] = iteration
                manager.save(state)

            yield value
            iteration += 1

    except StopIteration:
        pass


def _get_instance_state(instance: Any, attrs: Optional[List[str]] = None) -> dict:
    """Extract state from an instance."""
    if attrs:
        return {attr: getattr(instance, attr, None) for attr in attrs}
    else:
        # Auto-detect state (exclude private attributes and callables)
        import pickle  # local import, hoisted out of the loop below

        state = {}
        for attr in dir(instance):
            if not attr.startswith('_') and hasattr(instance, attr):
                value = getattr(instance, attr)
                if not callable(value):
                    try:
                        # Keep only values that survive pickling
                        pickle.dumps(value)
                        state[attr] = value
                    except Exception:
                        pass
        return state

431
src/sqrtspace_spacetime/checkpoint/manager.py
Normal file
@@ -0,0 +1,431 @@
"""
|
||||
Checkpoint manager for saving and restoring computation state.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import pickle
|
||||
import zlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Callable
|
||||
|
||||
import psutil
|
||||
|
||||
from sqrtspace_spacetime.config import config
|
||||
from sqrtspace_spacetime.memory import monitor
|
||||
|
||||
|
||||
class CheckpointStrategy(Enum):
    """Checkpointing strategies."""
    SQRT_N = "sqrt_n"                    # Checkpoint every √n iterations
    MEMORY_PRESSURE = "memory_pressure"  # Checkpoint when memory exceeds threshold
    TIME_BASED = "time_based"            # Checkpoint every k seconds
    ADAPTIVE = "adaptive"                # Dynamically adjust based on performance


@dataclass
class CheckpointConfig:
    """Configuration for checkpointing."""
    strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N
    checkpoint_dir: str = ".checkpoints"
    compression: bool = True
    compression_level: int = 6
    memory_threshold: float = 0.8  # Fraction of available memory
    time_interval: float = 60.0    # Seconds between checkpoints
    min_interval: int = 100        # Minimum iterations between checkpoints
    max_checkpoints: int = 10      # Maximum concurrent checkpoints
    enable_recovery: bool = True
    verbose: bool = False


@dataclass
class CheckpointMetadata:
    """Metadata for a checkpoint."""
    checkpoint_id: str
    iteration: int
    timestamp: float
    state_size: int
    compressed_size: int
    compression_ratio: float
    strategy_used: str
    reason: str
    state_vars: List[str]
    performance_impact: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return asdict(self)


@dataclass
class CheckpointStats:
    """Statistics about checkpointing performance."""
    total_checkpoints: int = 0
    total_time: float = 0.0
    total_size: int = 0
    compressed_size: int = 0
    average_compression: float = 0.0
    memory_saved: int = 0
    overhead_percent: float = 0.0
    recoveries: int = 0
    strategy_distribution: Optional[Dict[str, int]] = None

    def __post_init__(self):
        if self.strategy_distribution is None:
            self.strategy_distribution = {}


class CheckpointManager:
    """
    Manage checkpoints for long-running computations.

    Implements Williams' √n checkpoint intervals for an optimal space-time tradeoff.
    """

    def __init__(self,
                 checkpoint_id: Optional[str] = None,
                 config: Optional[CheckpointConfig] = None):
        """
        Initialize the checkpoint manager.

        Args:
            checkpoint_id: Unique ID for this computation
            config: Checkpoint configuration
        """
        self.checkpoint_id = checkpoint_id or str(uuid.uuid4())
        self.config = config or CheckpointConfig()
        self.stats = CheckpointStats()

        # Create the checkpoint directory
        self.checkpoint_path = Path(self.config.checkpoint_dir) / self.checkpoint_id
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)

        # Tracking
        self._iteration = 0
        self._last_checkpoint_iter = 0
        self._last_checkpoint_time = time.time()
        self._checkpoint_interval = None
        self._total_iterations = None

    def should_checkpoint(self, iteration: Optional[int] = None) -> bool:
        """
        Determine if a checkpoint is needed.

        Args:
            iteration: Current iteration (None to use the internal counter)

        Returns:
            True if a checkpoint should be created
        """
        if iteration is not None:
            self._iteration = iteration
        else:
            self._iteration += 1

        # Check strategy
        if self.config.strategy == CheckpointStrategy.SQRT_N:
            return self._should_checkpoint_sqrt_n()
        elif self.config.strategy == CheckpointStrategy.MEMORY_PRESSURE:
            return self._should_checkpoint_memory()
        elif self.config.strategy == CheckpointStrategy.TIME_BASED:
            return self._should_checkpoint_time()
        elif self.config.strategy == CheckpointStrategy.ADAPTIVE:
            return self._should_checkpoint_adaptive()

        return False

    def _should_checkpoint_sqrt_n(self) -> bool:
        """Check if a checkpoint is needed using the √n strategy."""
        if self._checkpoint_interval is None:
            # Estimate the interval if total iterations are unknown
            if self._total_iterations:
                self._checkpoint_interval = max(
                    self.config.min_interval,
                    int(self._total_iterations ** 0.5)
                )
            else:
                # Use adaptive estimation
                self._checkpoint_interval = self.config.min_interval

        iterations_since = self._iteration - self._last_checkpoint_iter
        return iterations_since >= self._checkpoint_interval

    def _should_checkpoint_memory(self) -> bool:
        """Check if a checkpoint is needed due to memory pressure."""
        mem_info = monitor.get_memory_info()
        return mem_info.percent > self.config.memory_threshold * 100

    def _should_checkpoint_time(self) -> bool:
        """Check if a checkpoint is needed based on elapsed time."""
        elapsed = time.time() - self._last_checkpoint_time
        return elapsed >= self.config.time_interval

    def _should_checkpoint_adaptive(self) -> bool:
        """Adaptive checkpointing based on multiple factors."""
        # Combine strategies
        sqrt_n = self._should_checkpoint_sqrt_n()
        memory = self._should_checkpoint_memory()
        time_based = self._should_checkpoint_time()

        # Checkpoint if any condition is met
        return sqrt_n or memory or time_based

    def save(self, state: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a checkpoint.

        Args:
            state: State dictionary to save
            metadata: Additional metadata

        Returns:
            Checkpoint ID
        """
        start_time = time.time()

        # Generate the checkpoint file path
        checkpoint_file = self.checkpoint_path / f"checkpoint_{self._iteration}.pkl"

        # Prepare state
        state_bytes = pickle.dumps(state)
        original_size = len(state_bytes)

        # Compress if enabled
        if self.config.compression:
            state_bytes = zlib.compress(state_bytes, self.config.compression_level)
            compressed_size = len(state_bytes)
            compression_ratio = original_size / compressed_size
        else:
            compressed_size = original_size
            compression_ratio = 1.0

        # Save checkpoint
        with open(checkpoint_file, 'wb') as f:
            f.write(state_bytes)

        # Save metadata
        checkpoint_metadata = CheckpointMetadata(
            checkpoint_id=str(checkpoint_file),
            iteration=self._iteration,
            timestamp=time.time(),
            state_size=original_size,
            compressed_size=compressed_size,
            compression_ratio=compression_ratio,
            strategy_used=self.config.strategy.value,
            reason=self._get_checkpoint_reason(),
            state_vars=list(state.keys()),
            performance_impact={
                'save_time': time.time() - start_time,
                'compression_time': 0.0  # TODO: measure separately
            }
        )

        metadata_file = checkpoint_file.with_suffix('.json')
        with open(metadata_file, 'w') as f:
            json.dump(checkpoint_metadata.to_dict(), f, indent=2)

        # Update stats
        self._update_stats(checkpoint_metadata)

        # Update tracking
        self._last_checkpoint_iter = self._iteration
        self._last_checkpoint_time = time.time()

        # Clean old checkpoints
        self._cleanup_old_checkpoints()

        if self.config.verbose:
            print(f"Checkpoint saved: iteration {self._iteration}, "
                  f"size {compressed_size / 1024:.1f}KB, "
                  f"compression {compression_ratio:.1f}x")

        return str(checkpoint_file)

    def load(self, checkpoint_id: Optional[str] = None) -> Tuple[Dict[str, Any], CheckpointMetadata]:
        """
        Load a checkpoint.

        Args:
            checkpoint_id: Specific checkpoint to load (None for the latest)

        Returns:
            Tuple of (state, metadata)
        """
        if checkpoint_id:
            checkpoint_file = Path(checkpoint_id)
        else:
            # Find the latest checkpoint
            checkpoints = list(self.checkpoint_path.glob("checkpoint_*.pkl"))
            if not checkpoints:
                raise ValueError("No checkpoints found")

            checkpoint_file = max(checkpoints, key=lambda p: p.stat().st_mtime)

        # Load metadata
        metadata_file = checkpoint_file.with_suffix('.json')
        with open(metadata_file, 'r') as f:
            metadata_dict = json.load(f)
            metadata = CheckpointMetadata(**metadata_dict)

        # Load state
        with open(checkpoint_file, 'rb') as f:
            state_bytes = f.read()

        # Decompress if needed
        if self.config.compression:
            state_bytes = zlib.decompress(state_bytes)

        state = pickle.loads(state_bytes)

        # Update stats
        self.stats.recoveries += 1

        if self.config.verbose:
            print(f"Checkpoint loaded: iteration {metadata.iteration}")

        return state, metadata

    def list_checkpoints(self) -> List[CheckpointMetadata]:
        """List all available checkpoints."""
        metadata_files = self.checkpoint_path.glob("checkpoint_*.json")
        checkpoints = []

        for metadata_file in metadata_files:
            with open(metadata_file, 'r') as f:
                metadata_dict = json.load(f)
                checkpoints.append(CheckpointMetadata(**metadata_dict))

        return sorted(checkpoints, key=lambda c: c.iteration)

    def delete_checkpoint(self, checkpoint_id: str) -> None:
        """Delete a specific checkpoint."""
        checkpoint_file = Path(checkpoint_id)
        metadata_file = checkpoint_file.with_suffix('.json')

        if checkpoint_file.exists():
            checkpoint_file.unlink()
        if metadata_file.exists():
            metadata_file.unlink()

    def cleanup(self) -> None:
        """Clean up all checkpoints."""
        import shutil
        if self.checkpoint_path.exists():
            shutil.rmtree(self.checkpoint_path)

    def set_total_iterations(self, total: int) -> None:
        """
        Set the total iterations for the optimal √n calculation.

        Args:
            total: Total number of iterations
        """
        self._total_iterations = total
        self._checkpoint_interval = max(
            self.config.min_interval,
            int(total ** 0.5)
        )

        if self.config.verbose:
            print(f"Checkpoint interval set to {self._checkpoint_interval} "
                  f"(√{total} strategy)")

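    # Worked note (illustrative): total = 1_000_000 gives an interval of
    # max(100, int(1_000_000 ** 0.5)) = 1000, i.e. about 1000 checkpoints
    # over the full run.
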
    def get_stats(self) -> CheckpointStats:
        """Get checkpoint statistics."""
        if self.stats.total_checkpoints > 0:
            self.stats.average_compression = (
                self.stats.total_size / self.stats.compressed_size
            )
            self.stats.overhead_percent = (
                self.stats.total_time / (time.time() - self._last_checkpoint_time) * 100
            )

        return self.stats

    def _get_checkpoint_reason(self) -> str:
        """Get the reason for the checkpoint."""
        if self.config.strategy == CheckpointStrategy.SQRT_N:
            return f"√n interval reached ({self._checkpoint_interval} iterations)"
        elif self.config.strategy == CheckpointStrategy.MEMORY_PRESSURE:
            mem_info = monitor.get_memory_info()
            return f"Memory pressure: {mem_info.percent:.1f}%"
        elif self.config.strategy == CheckpointStrategy.TIME_BASED:
            return f"Time interval: {self.config.time_interval}s"
        else:
            return "Adaptive strategy triggered"

    def _update_stats(self, metadata: CheckpointMetadata) -> None:
        """Update statistics."""
        self.stats.total_checkpoints += 1
        self.stats.total_time += metadata.performance_impact['save_time']
        self.stats.total_size += metadata.state_size
        self.stats.compressed_size += metadata.compressed_size

        # Update the strategy distribution
        strategy = metadata.strategy_used
        self.stats.strategy_distribution[strategy] = (
            self.stats.strategy_distribution.get(strategy, 0) + 1
        )

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints to stay under the limit."""
        checkpoints = list(self.checkpoint_path.glob("checkpoint_*.pkl"))

        if len(checkpoints) > self.config.max_checkpoints:
            # Sort by modification time
            checkpoints.sort(key=lambda p: p.stat().st_mtime)

            # Remove the oldest
            for checkpoint in checkpoints[:-self.config.max_checkpoints]:
                self.delete_checkpoint(str(checkpoint))

    def create_recovery_code(self, func: Callable) -> str:
        """
        Generate recovery code for a function.

        Args:
            func: Function to generate recovery for

        Returns:
            Recovery code as a string
        """
        recovery_template = '''
def recover_{func_name}(checkpoint_id=None):
    """Recover {func_name} from checkpoint."""
    manager = CheckpointManager("{checkpoint_id}")

    # Load checkpoint
    state, metadata = manager.load(checkpoint_id)

    # Resume computation
    iteration = metadata.iteration

    # Restore state variables
    {state_restoration}

    # Continue from checkpoint
    # TODO: Add continuation logic

    return state
'''

        # Get the function name
        func_name = func.__name__

        # Generate state restoration code
        state_vars = []
        if hasattr(func, '_checkpoint_state'):
            state_vars = func._checkpoint_state

        state_restoration = '\n    '.join(
            f"{var} = state.get('{var}')" for var in state_vars
        )

        return recovery_template.format(
            func_name=func_name,
            checkpoint_id=self.checkpoint_id,
            state_restoration=state_restoration
        )

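
# Demo sketch (illustrative, not part of the library API): exercises the
# should_checkpoint/save/load cycle with the default √n strategy.
if __name__ == "__main__":
    manager = CheckpointManager("demo", CheckpointConfig(verbose=True))
    manager.set_total_iterations(10_000)  # interval = max(100, √10000) = 100

    total = 0
    for i in range(10_000):
        total += i
        if manager.should_checkpoint(i):
            manager.save({'i': i, 'total': total})

    state, meta = manager.load()          # latest checkpoint
    print(f"Resumed at iteration {meta.iteration} with total={state['total']}")
    manager.cleanup()
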
9
src/sqrtspace_spacetime/collections/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""Memory-efficient collections using √n space-time tradeoffs."""
|
||||
|
||||
from sqrtspace_spacetime.collections.spacetime_array import SpaceTimeArray
|
||||
from sqrtspace_spacetime.collections.spacetime_dict import SpaceTimeDict
|
||||
|
||||
__all__ = [
|
||||
"SpaceTimeArray",
|
||||
"SpaceTimeDict",
|
||||
]
|
||||
273
src/sqrtspace_spacetime/collections/spacetime_array.py
Normal file
@@ -0,0 +1,273 @@
"""
|
||||
SpaceTimeArray: A memory-efficient array that automatically spills to disk.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import weakref
|
||||
from typing import Any, Iterator, Optional, Union, List
|
||||
from collections.abc import MutableSequence
|
||||
|
||||
from sqrtspace_spacetime.config import config
|
||||
from sqrtspace_spacetime.memory import monitor, MemoryPressureLevel
|
||||
|
||||
|
||||
class SpaceTimeArray(MutableSequence):
|
||||
"""
|
||||
A list-like container that automatically manages memory usage by
|
||||
spilling to disk when threshold is reached.
|
||||
"""
|
||||
|
||||
_instances = weakref.WeakSet()
|
||||
|
||||
def __init__(self, threshold: Optional[Union[int, str]] = None, storage_path: Optional[str] = None):
|
||||
"""
|
||||
Initialize SpaceTimeArray.
|
||||
|
||||
Args:
|
||||
threshold: Number of items to keep in memory (None or 'auto' for automatic)
|
||||
storage_path: Path for external storage (None for temp)
|
||||
"""
|
||||
if threshold == 'auto' or threshold is None:
|
||||
self.threshold = config.calculate_chunk_size(10000)
|
||||
else:
|
||||
self.threshold = int(threshold)
|
||||
self.storage_path = storage_path or config.external_storage_path
|
||||
|
||||
self._hot_data: List[Any] = []
|
||||
self._cold_indices: set = set()
|
||||
self._cold_storage: Optional[str] = None
|
||||
self._length = 0
|
||||
self._cold_file_handle = None
|
||||
|
||||
# Register for memory pressure handling
|
||||
SpaceTimeArray._instances.add(self)
|
||||
|
||||
    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: Union[int, slice]) -> Any:
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list index out of range")

        # Check if in hot storage (cold indices always form a prefix of the array)
        if index not in self._cold_indices:
            hot_index = index - len(self._cold_indices)
            return self._hot_data[hot_index]

        # Load from cold storage
        return self._load_from_cold(index)

    def __setitem__(self, index: Union[int, slice], value: Any) -> None:
        if isinstance(index, slice):
            for i, v in zip(range(*index.indices(len(self))), value):
                self[i] = v
            return

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list assignment index out of range")

        if index not in self._cold_indices:
            hot_index = index - len(self._cold_indices)
            self._hot_data[hot_index] = value
        else:
            # Update cold storage
            self._update_cold(index, value)

    def __delitem__(self, index: Union[int, slice]) -> None:
        if isinstance(index, slice):
            # Delete in reverse order to maintain indices
            for i in reversed(range(*index.indices(len(self)))):
                del self[i]
            return

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list index out of range")

        # This is complex with cold storage, so reload everything
        all_data = list(self)
        del all_data[index]
        self.clear()
        self.extend(all_data)

    def insert(self, index: int, value: Any) -> None:
        if index < 0:
            index += self._length
        index = max(0, min(index, self._length))

        # Simple implementation: reload all, insert, save back
        all_data = list(self)
        all_data.insert(index, value)
        self.clear()
        self.extend(all_data)

    def append(self, value: Any) -> None:
        """Append an item to the array."""
        self._hot_data.append(value)
        self._length += 1

        # Check if we need to spill
        if len(self._hot_data) > self.threshold:
            self._check_and_spill()

    def extend(self, iterable) -> None:
        """Extend the array with items from an iterable."""
        for item in iterable:
            self.append(item)

    def clear(self) -> None:
        """Remove all items."""
        self._hot_data.clear()
        self._cold_indices.clear()
        self._length = 0

        if self._cold_storage and os.path.exists(self._cold_storage):
            os.unlink(self._cold_storage)
        self._cold_storage = None

    def __iter__(self) -> Iterator[Any]:
        """Iterate over all items."""
        # First yield cold items (they precede all hot items)
        for idx in sorted(self._cold_indices):
            yield self._load_from_cold(idx)

        # Then hot items
        for item in self._hot_data:
            yield item

    def _check_and_spill(self) -> None:
        """Check memory pressure and spill to disk if needed."""
        pressure = monitor.check_memory_pressure()

        if pressure >= MemoryPressureLevel.MEDIUM or len(self._hot_data) > self.threshold:
            self._spill_to_disk()

    def _spill_to_disk(self) -> None:
        """Spill the oldest items to disk."""
        if not self._cold_storage:
            fd, self._cold_storage = tempfile.mkstemp(
                suffix='.spacetime',
                dir=self.storage_path
            )
            os.close(fd)

        # Determine how many items to spill; always at least one so that a
        # forced spill (spill_to_disk) terminates even with a single hot item
        spill_count = max(1, len(self._hot_data) // 2)

        # Load existing cold data
        cold_data = {}
        if os.path.exists(self._cold_storage):
            with open(self._cold_storage, 'rb') as f:
                try:
                    cold_data = pickle.load(f)
                except EOFError:
                    cold_data = {}

        # Move items to cold storage
        current_cold_size = len(self._cold_indices)
        for i in range(spill_count):
            cold_data[current_cold_size + i] = self._hot_data[i]
            self._cold_indices.add(current_cold_size + i)

        # Remove from hot storage
        self._hot_data = self._hot_data[spill_count:]

        # Save cold data
        with open(self._cold_storage, 'wb') as f:
            pickle.dump(cold_data, f)

    def _load_from_cold(self, index: int) -> Any:
        """Load an item from cold storage."""
        if not self._cold_storage or not os.path.exists(self._cold_storage):
            raise IndexError(f"Cold storage index {index} not found")

        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        return cold_data.get(index)

    def _update_cold(self, index: int, value: Any) -> None:
        """Update an item in cold storage."""
        if not self._cold_storage:
            return

        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        cold_data[index] = value

        with open(self._cold_storage, 'wb') as f:
            pickle.dump(cold_data, f)

    def memory_usage(self) -> int:
        """Estimate memory usage in bytes."""
        # Rough estimate - actual usage may vary
        return len(self._hot_data) * 50  # Assume 50 bytes per item on average

    def spill_to_disk(self, path: Optional[str] = None) -> None:
        """Force a spill of all data to disk."""
        if path:
            self.storage_path = path

        while self._hot_data:
            self._spill_to_disk()

    def load_to_memory(self) -> None:
        """Load all data back into memory."""
        if not self._cold_storage or not self._cold_indices:
            return

        # Load cold data
        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        # Rebuild the array in the correct order
        all_data = []
        cold_count = 0
        hot_count = 0

        for i in range(self._length):
            if i in self._cold_indices:
                all_data.append(cold_data[i])
                cold_count += 1
            else:
                all_data.append(self._hot_data[hot_count])
                hot_count += 1

        # Reset storage
        self._hot_data = all_data
        self._cold_indices.clear()

        if os.path.exists(self._cold_storage):
            os.unlink(self._cold_storage)
        self._cold_storage = None

    def __del__(self):
        """Clean up temporary files."""
        if self._cold_storage and os.path.exists(self._cold_storage):
            try:
                os.unlink(self._cold_storage)
            except OSError:
                pass

    @classmethod
    def handle_memory_pressure(cls, level: MemoryPressureLevel) -> None:
        """Handle memory pressure for all live instances."""
        if level >= MemoryPressureLevel.HIGH:
            for instance in cls._instances:
                if instance._hot_data:
                    instance._spill_to_disk()

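
# Demo sketch (illustrative, not part of the library API): the array behaves
# like a list but keeps at most `threshold` items hot, spilling older items
# to a temp file.
if __name__ == "__main__":
    arr = SpaceTimeArray(threshold=1000)
    for i in range(10_000):
        arr.append(i)

    assert arr[0] == 0          # served from cold storage
    assert arr[-1] == 9_999     # served from hot storage

    arr.load_to_memory()        # pull everything back into RAM
    print(f"{len(arr)} items, ~{arr.memory_usage()} bytes hot")
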
272
src/sqrtspace_spacetime/collections/spacetime_dict.py
Normal file
@@ -0,0 +1,272 @@
"""
|
||||
SpaceTimeDict: A memory-efficient dictionary with automatic spillover.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Iterator, Optional, Tuple
|
||||
from collections import OrderedDict
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from sqrtspace_spacetime.config import config
|
||||
from sqrtspace_spacetime.memory import monitor, MemoryPressureLevel
|
||||
|
||||
|
||||
class SpaceTimeDict(MutableMapping):
|
||||
"""
|
||||
A dictionary that automatically manages memory by moving least-recently-used
|
||||
items to disk storage.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
threshold: Optional[int] = None,
|
||||
storage_path: Optional[str] = None,
|
||||
use_lru: bool = True):
|
||||
"""
|
||||
Initialize SpaceTimeDict.
|
||||
|
||||
Args:
|
||||
threshold: Number of items to keep in memory
|
||||
storage_path: Path for external storage
|
||||
use_lru: Use LRU eviction policy
|
||||
"""
|
||||
self.threshold = threshold or config.calculate_chunk_size(10000)
|
||||
self.storage_path = storage_path or config.external_storage_path
|
||||
self.use_lru = use_lru
|
||||
|
||||
# Hot storage (in memory)
|
||||
if use_lru:
|
||||
self._hot_data: Dict[Any, Any] = OrderedDict()
|
||||
else:
|
||||
self._hot_data: Dict[Any, Any] = {}
|
||||
|
||||
# Cold storage tracking
|
||||
self._cold_keys: set = set()
|
||||
self._cold_storage: Optional[str] = None
|
||||
self._cold_index: Dict[Any, Tuple[int, int]] = {} # key -> (offset, size)
|
||||
|
||||
# Statistics
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
self._last_access: Dict[Any, float] = {}
|
||||
|
||||
    def __len__(self) -> int:
        return len(self._hot_data) + len(self._cold_keys)

    def __getitem__(self, key: Any) -> Any:
        # Check hot storage first
        if key in self._hot_data:
            self._hits += 1
            if self.use_lru:
                # Move to the end (most recent)
                self._hot_data.move_to_end(key)
            self._last_access[key] = time.time()
            return self._hot_data[key]

        # Check cold storage
        if key in self._cold_keys:
            self._misses += 1
            value = self._load_from_cold(key)

            # Promote to hot storage
            self._promote_to_hot(key, value)

            return value

        raise KeyError(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        # If the key exists in cold storage, remove it
        if key in self._cold_keys:
            self._cold_keys.remove(key)
            # Note: we don't remove it from the file, to avoid rewriting

        # Add to hot storage
        self._hot_data[key] = value
        self._last_access[key] = time.time()

        # Check if we need to evict
        if len(self._hot_data) > self.threshold:
            self._evict_to_cold()

    def __delitem__(self, key: Any) -> None:
        if key in self._hot_data:
            del self._hot_data[key]
            self._last_access.pop(key, None)
        elif key in self._cold_keys:
            self._cold_keys.remove(key)
            self._cold_index.pop(key, None)
        else:
            raise KeyError(key)

    def __iter__(self) -> Iterator[Any]:
        # Iterate hot keys first, then cold keys
        yield from self._hot_data
        yield from self._cold_keys

    def __contains__(self, key: Any) -> bool:
        return key in self._hot_data or key in self._cold_keys

    def keys(self):
        """Return a list of all keys (hot first, then cold)."""
        return list(self._hot_data.keys()) + list(self._cold_keys)

    def values(self):
        """Yield all values, loading cold entries on demand."""
        # Iterate over a snapshot of the keys: reading a cold entry promotes
        # it to hot storage, which would otherwise mutate the underlying
        # mappings mid-iteration.
        for key in self.keys():
            yield self[key]

    def items(self):
        """Yield all key-value pairs, loading cold entries on demand."""
        for key in self.keys():
            yield (key, self[key])

    def clear(self) -> None:
        """Remove all items."""
        self._hot_data.clear()
        self._cold_keys.clear()
        self._cold_index.clear()
        self._last_access.clear()

        if self._cold_storage and os.path.exists(self._cold_storage):
            os.unlink(self._cold_storage)
        self._cold_storage = None

    def get_stats(self) -> Dict[str, Any]:
        """Get usage statistics."""
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0

        return {
            "hot_items": len(self._hot_data),
            "cold_items": len(self._cold_keys),
            "total_items": len(self),
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "memory_usage": self.memory_usage(),
        }

    def _evict_to_cold(self) -> None:
        """Evict the least recently used items to cold storage."""
        evict_count = max(1, len(self._hot_data) // 4)  # Evict 25%

        if not self._cold_storage:
            fd, self._cold_storage = tempfile.mkstemp(
                suffix='.spacetime_dict',
                dir=self.storage_path
            )
            os.close(fd)

        # Select items to evict
        if self.use_lru:
            # OrderedDict: the oldest items come first
            evict_keys = list(self._hot_data.keys())[:evict_count]
        else:
            # Use access time
            sorted_keys = sorted(
                self._hot_data.keys(),
                key=lambda k: self._last_access.get(k, 0)
            )
            evict_keys = sorted_keys[:evict_count]

        # Write to cold storage
        with open(self._cold_storage, 'ab') as f:
            for key in evict_keys:
                value = self._hot_data[key]
                offset = f.tell()

                # Serialize the key-value pair
                data = pickle.dumps((key, value))
                size = len(data)

                # Write a size header followed by the data
                f.write(size.to_bytes(4, 'little'))
                f.write(data)

                # Update indices
                self._cold_index[key] = (offset, size + 4)
                self._cold_keys.add(key)

                # Remove from hot storage
                del self._hot_data[key]

    def _load_from_cold(self, key: Any) -> Any:
        """Load a value from cold storage."""
        if key not in self._cold_index:
            raise KeyError(key)

        offset, size = self._cold_index[key]

        with open(self._cold_storage, 'rb') as f:
            f.seek(offset)
            size_bytes = f.read(4)
            data_size = int.from_bytes(size_bytes, 'little')
            data = f.read(data_size)

        stored_key, value = pickle.loads(data)
        assert stored_key == key

        return value

    def _promote_to_hot(self, key: Any, value: Any) -> None:
        """Promote a cold item to hot storage."""
        # Remove from cold tracking
        self._cold_keys.remove(key)

        # Add to hot storage
        self._hot_data[key] = value
        self._last_access[key] = time.time()

        # Check if we need to evict something else
        if len(self._hot_data) > self.threshold:
            self._evict_to_cold()

    def memory_usage(self) -> int:
        """Estimate memory usage in bytes."""
        # Rough estimate
        return len(self._hot_data) * 100  # Assume 100 bytes per item on average

    def compact(self) -> None:
        """Compact cold storage by removing deleted entries."""
        if not self._cold_storage or not self._cold_keys:
            return

        # Create a new file
        fd, new_storage = tempfile.mkstemp(
            suffix='.spacetime_dict',
            dir=self.storage_path
        )
        os.close(fd)

        new_index = {}

        # Copy only the active entries
        with open(new_storage, 'wb') as new_f:
            for key in self._cold_keys:
                value = self._load_from_cold(key)
                offset = new_f.tell()

                data = pickle.dumps((key, value))
                size = len(data)

                new_f.write(size.to_bytes(4, 'little'))
                new_f.write(data)

                new_index[key] = (offset, size + 4)

        # Replace the old storage
        os.unlink(self._cold_storage)
        self._cold_storage = new_storage
        self._cold_index = new_index

    def __del__(self):
        """Clean up temporary files."""
        if self._cold_storage and os.path.exists(self._cold_storage):
            try:
                os.unlink(self._cold_storage)
            except OSError:
                pass

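
# Demo sketch (illustrative, not part of the library API): LRU entries are
# evicted to an append-only file and promoted back to memory on access.
if __name__ == "__main__":
    d = SpaceTimeDict(threshold=100)
    for i in range(1_000):
        d[f"key{i}"] = i * i

    assert d["key0"] == 0       # cold read, then promoted back to hot
    print(d.get_stats())
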
186
src/sqrtspace_spacetime/config.py
Normal file
@@ -0,0 +1,186 @@
"""
|
||||
Configuration management for SpaceTime operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import tempfile
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import psutil
|
||||
|
||||
|
||||
class ChunkStrategy(Enum):
    """Strategy for determining chunk sizes."""
    SQRT_N = "sqrt_n"
    MEMORY_BASED = "memory_based"
    FIXED = "fixed"
    ADAPTIVE = "adaptive"


class CompressionType(Enum):
    """Compression algorithms for external storage."""
    NONE = "none"
    GZIP = "gzip"
    LZ4 = "lz4"
    ZSTD = "zstd"
    SNAPPY = "snappy"


@dataclass
class MemoryHierarchy:
    """Memory hierarchy information."""
    l1_cache: int = field(default_factory=lambda: 32 * 1024)        # 32KB
    l2_cache: int = field(default_factory=lambda: 256 * 1024)       # 256KB
    l3_cache: int = field(default_factory=lambda: 8 * 1024 * 1024)  # 8MB
    ram: int = field(default_factory=lambda: psutil.virtual_memory().total)
    disk: int = field(default_factory=lambda: psutil.disk_usage('/').total)

    def get_optimal_buffer_size(self, total_size: int) -> int:
        """Calculate the optimal buffer size based on the memory hierarchy."""
        sqrt_n = int(math.sqrt(total_size))

        # Try to fit in the L3 cache
        if sqrt_n <= self.l3_cache:
            return sqrt_n

        # Otherwise use a fraction of available RAM
        available_ram = psutil.virtual_memory().available
        return min(sqrt_n, int(available_ram * 0.1))


@dataclass
class SpaceTimeConfig:
    """Global configuration for SpaceTime operations."""

    # Memory limits
    memory_limit: int = field(default_factory=lambda: int(psutil.virtual_memory().total * 0.8))
    memory_threshold: float = 0.8  # Trigger spillover at 80% usage

    # Storage
    external_storage_path: str = field(default_factory=lambda: os.path.join(tempfile.gettempdir(), "spacetime"))
    compression: CompressionType = CompressionType.GZIP
    compression_level: int = 6

    # Chunking
    chunk_strategy: ChunkStrategy = ChunkStrategy.SQRT_N
    fixed_chunk_size: int = 10000
    min_chunk_size: int = 100
    max_chunk_size: int = 10_000_000

    # Checkpointing
    enable_checkpointing: bool = True
    checkpoint_interval: int = 60  # seconds
    checkpoint_storage: str = "file"  # "file", "redis", "s3"

    # Performance
    enable_profiling: bool = False
    parallel_workers: int = field(default_factory=lambda: min(4, os.cpu_count() or 1))
    prefetch_size: int = 2  # Number of chunks to prefetch

    # Memory hierarchy
    hierarchy: MemoryHierarchy = field(default_factory=MemoryHierarchy)

    # Singleton storage (ClassVar so the dataclass does not treat it as a field)
    _instance: ClassVar[Optional['SpaceTimeConfig']] = None

    def __post_init__(self):
        """Initialize the storage directory."""
        os.makedirs(self.external_storage_path, exist_ok=True)

    @classmethod
    def get_instance(cls) -> 'SpaceTimeConfig':
        """Get the singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_defaults(cls, **kwargs) -> None:
        """Set default configuration values."""
        instance = cls.get_instance()
        for key, value in kwargs.items():
            if hasattr(instance, key):
                setattr(instance, key, value)

    def calculate_chunk_size(self, total_size: int) -> int:
        """Calculate the optimal chunk size based on the strategy."""
        if self.chunk_strategy == ChunkStrategy.FIXED:
            return self.fixed_chunk_size

        elif self.chunk_strategy == ChunkStrategy.SQRT_N:
            sqrt_n = int(math.sqrt(total_size))
            return max(self.min_chunk_size, min(sqrt_n, self.max_chunk_size))

        elif self.chunk_strategy == ChunkStrategy.MEMORY_BASED:
            available = psutil.virtual_memory().available
            # Use 10% of available memory for chunks
            chunk_size = int(available * 0.1 / 8)  # Assume 8 bytes per item
            return max(self.min_chunk_size, min(chunk_size, self.max_chunk_size))

        elif self.chunk_strategy == ChunkStrategy.ADAPTIVE:
            # Start with sqrt(n) and adjust based on memory pressure
            base_size = int(math.sqrt(total_size))
            memory_percent = psutil.virtual_memory().percent

            if memory_percent > 90:
                # Very high pressure: use the minimum size
                return self.min_chunk_size
            elif memory_percent > 70:
                # High pressure: reduce the chunk size
                return max(self.min_chunk_size, base_size // 2)
            elif memory_percent < 30:
                # Low pressure: increase the chunk size
                return min(self.max_chunk_size, base_size * 2)
            else:
                # Normal pressure: use sqrt(n)
                return max(self.min_chunk_size, min(base_size, self.max_chunk_size))

        return self.fixed_chunk_size

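    # Worked example (illustrative): with the default SQRT_N strategy,
    # calculate_chunk_size(1_000_000) = max(100, min(√1_000_000, 10_000_000))
    # = 1000, i.e. the data is processed as ~1000 chunks of ~1000 items each.
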
    def get_compression_module(self):
        """Get the compression module based on the configuration."""
        if self.compression == CompressionType.GZIP:
            import gzip
            return gzip
        elif self.compression == CompressionType.LZ4:
            try:
                import lz4.frame
                return lz4.frame
            except ImportError:
                import gzip
                return gzip
        elif self.compression == CompressionType.ZSTD:
            try:
                import zstandard
                return zstandard
            except ImportError:
                import gzip
                return gzip
        elif self.compression == CompressionType.SNAPPY:
            try:
                import snappy
                return snappy
            except ImportError:
                import gzip
                return gzip
        else:
            return None

    def format_bytes(self, bytes: int) -> str:
        """Format a byte count as a human-readable string."""
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if bytes < 1024.0:
                return f"{bytes:.2f} {unit}"
            bytes /= 1024.0
        return f"{bytes:.2f} PB"

    def get_williams_bound(self, time_complexity: int) -> int:
        """Calculate Williams' space bound: SPACE[√(t log t)]."""
        if time_complexity <= 0:
            return 1
        return int(math.sqrt(time_complexity * math.log2(max(2, time_complexity))))


# Global configuration instance
config = SpaceTimeConfig.get_instance()

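
# Demo sketch (illustrative, not part of the library API): the Williams bound
# √(t log t) caps the space budget for a t-step computation, e.g.
# t = 1_000_000 gives int(√(10^6 · log2 10^6)) = 4464 units.
if __name__ == "__main__":
    print(config.calculate_chunk_size(1_000_000))   # 1000 under SQRT_N
    print(config.get_williams_bound(1_000_000))     # 4464
    print(config.format_bytes(config.memory_limit))
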
27
src/sqrtspace_spacetime/memory/__init__.py
Normal file
@@ -0,0 +1,27 @@
"""Memory monitoring and pressure handling for SpaceTime."""
|
||||
|
||||
from sqrtspace_spacetime.memory.monitor import (
|
||||
MemoryMonitor,
|
||||
MemoryPressureLevel,
|
||||
MemoryInfo,
|
||||
MemoryPressureHandler,
|
||||
monitor,
|
||||
)
|
||||
from sqrtspace_spacetime.memory.handlers import (
|
||||
LoggingHandler,
|
||||
CacheEvictionHandler,
|
||||
GarbageCollectionHandler,
|
||||
ThrottlingHandler,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MemoryMonitor",
|
||||
"MemoryPressureLevel",
|
||||
"MemoryInfo",
|
||||
"MemoryPressureHandler",
|
||||
"LoggingHandler",
|
||||
"CacheEvictionHandler",
|
||||
"GarbageCollectionHandler",
|
||||
"ThrottlingHandler",
|
||||
"monitor",
|
||||
]
|
||||
168
src/sqrtspace_spacetime/memory/handlers.py
Normal file
@@ -0,0 +1,168 @@
"""Memory pressure handlers."""
|
||||
|
||||
import gc
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any, List, Callable, Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from sqrtspace_spacetime.memory.monitor import (
|
||||
MemoryPressureHandler,
|
||||
MemoryPressureLevel,
|
||||
MemoryInfo
|
||||
)
|
||||
|
||||
|
||||
class LoggingHandler(MemoryPressureHandler):
    """Log memory pressure events."""

    def __init__(self,
                 logger: Optional[logging.Logger] = None,
                 min_level: MemoryPressureLevel = MemoryPressureLevel.MEDIUM):
        self.logger = logger or logging.getLogger(__name__)
        self.min_level = min_level
        self._last_log = {}

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        return level >= self.min_level

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        # Avoid spamming logs: only log if the level is new or 60s have passed
        last_time = self._last_log.get(level, 0)
        if time.time() - last_time < 60 and level in self._last_log:
            return

        self._last_log[level] = time.time()

        if level == MemoryPressureLevel.CRITICAL:
            self.logger.critical(f"CRITICAL memory pressure: {info}")
        elif level == MemoryPressureLevel.HIGH:
            self.logger.error(f"HIGH memory pressure: {info}")
        elif level == MemoryPressureLevel.MEDIUM:
            self.logger.warning(f"MEDIUM memory pressure: {info}")
        else:
            self.logger.info(f"Memory pressure: {info}")


class CacheEvictionHandler(MemoryPressureHandler):
    """Evict cached data under memory pressure."""

    def __init__(self):
        # Hold the registered mappings directly so evictions are visible to
        # their owners (wrapping them in a fresh WeakValueDictionary would
        # only evict entries from a copy).
        self._caches: List[Dict[Any, Any]] = []
        self._eviction_rates = {
            MemoryPressureLevel.LOW: 0.1,       # Evict 10%
            MemoryPressureLevel.MEDIUM: 0.25,   # Evict 25%
            MemoryPressureLevel.HIGH: 0.5,      # Evict 50%
            MemoryPressureLevel.CRITICAL: 0.9,  # Evict 90%
        }

    def register_cache(self, cache: Dict[Any, Any]) -> None:
        """Register a cache for eviction."""
        self._caches.append(cache)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        return level >= MemoryPressureLevel.LOW and bool(self._caches)

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        eviction_rate = self._eviction_rates.get(level, 0)
        if eviction_rate == 0:
            return

        for cache in self._caches:
            if not cache:
                continue

            size = len(cache)

            # Evict the oldest-inserted entries
            to_evict = int(size * eviction_rate)
            keys = list(cache.keys())[:to_evict]

            for key in keys:
                cache.pop(key, None)


class GarbageCollectionHandler(MemoryPressureHandler):
    """Trigger garbage collection under memory pressure."""

    def __init__(self, min_interval: float = 5.0):
        self.min_interval = min_interval
        self._last_gc = 0

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        return level >= MemoryPressureLevel.MEDIUM

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        now = time.time()

        # Don't GC too frequently
        if now - self._last_gc < self.min_interval:
            return

        self._last_gc = now

        # More aggressive GC for higher pressure
        if level >= MemoryPressureLevel.HIGH:
            # Full collection
            gc.collect(2)
        else:
            # Quick collection
            gc.collect(0)


class ThrottlingHandler(MemoryPressureHandler):
    """Throttle operations under memory pressure."""

    def __init__(self):
        self._throttle_rates = {
            MemoryPressureLevel.LOW: 0,         # No throttling
            MemoryPressureLevel.MEDIUM: 0.1,    # 100ms delay
            MemoryPressureLevel.HIGH: 0.5,      # 500ms delay
            MemoryPressureLevel.CRITICAL: 2.0,  # 2s delay
        }
        self._callbacks: List[Callable[[float], None]] = []

    def register_callback(self, callback: Callable[[float], None]) -> None:
        """Register a callback to be notified of throttle delays."""
        self._callbacks.append(callback)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        return level >= MemoryPressureLevel.MEDIUM

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        delay = self._throttle_rates.get(level, 0)

        # Notify callbacks
        for callback in self._callbacks:
            try:
                callback(delay)
            except Exception:
                pass


class SpillToDiskHandler(MemoryPressureHandler):
    """Spill data to disk under memory pressure."""

    def __init__(self, spill_path: Optional[str] = None):
        self.spill_path = spill_path
        self._spillable_objects: List[Any] = []

    def register_spillable(self, obj: Any) -> None:
        """Register an object that can spill to disk."""
        if hasattr(obj, 'spill_to_disk'):
            self._spillable_objects.append(obj)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        return level >= MemoryPressureLevel.HIGH and bool(self._spillable_objects)

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        for obj in self._spillable_objects:
            try:
                if hasattr(obj, 'memory_usage'):
                    # Only spill large objects (>10MB estimated)
                    if obj.memory_usage() > 10 * 1024 * 1024:
                        obj.spill_to_disk(self.spill_path)
            except Exception:
                pass

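
# Wiring sketch (illustrative): register handlers with the global monitor so
# they fire as pressure rises:
#
#     from sqrtspace_spacetime.memory import monitor
#     monitor.add_handler(LoggingHandler())
#     monitor.add_handler(GarbageCollectionHandler(min_interval=10.0))
#     spiller = SpillToDiskHandler()
#     spiller.register_spillable(big_array)   # any object with spill_to_disk()
#     monitor.add_handler(spiller)
#     monitor.start_monitoring()
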
247
src/sqrtspace_spacetime/memory/monitor.py
Normal file
@@ -0,0 +1,247 @@
"""Memory monitoring and pressure detection."""
|
||||
|
||||
import gc
|
||||
import time
|
||||
import psutil
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Callable, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from sqrtspace_spacetime.config import config
|
||||
|
||||
|
||||
class MemoryPressureLevel(Enum):
    """Memory pressure levels."""
    NONE = 0
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4

    def __gt__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value > other.value

    def __ge__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value >= other.value


@dataclass
class MemoryInfo:
    """Memory usage information."""
    total: int
    available: int
    used: int
    percent: float
    pressure_level: MemoryPressureLevel
    timestamp: float

    @property
    def used_gb(self) -> float:
        return self.used / (1024 ** 3)

    @property
    def available_gb(self) -> float:
        return self.available / (1024 ** 3)

    def __str__(self) -> str:
        return (f"Memory: {self.percent:.1f}% used "
                f"({self.used_gb:.2f}/{self.available_gb:.2f} GB), "
                f"Pressure: {self.pressure_level.name}")


class MemoryPressureHandler(ABC):
    """Abstract base class for memory pressure handlers."""

    @abstractmethod
    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        """Check if this handler should handle the given pressure level."""
        pass

    @abstractmethod
    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Handle memory pressure."""
        pass


class MemoryMonitor:
    """Monitor system memory and detect pressure."""

    def __init__(self,
                 check_interval: float = 1.0,
                 memory_limit: Optional[int] = None):
        """
        Initialize the memory monitor.

        Args:
            check_interval: Seconds between checks
            memory_limit: Custom memory limit in bytes (None for system limit)
        """
        self.check_interval = check_interval
        self.memory_limit = memory_limit or config.memory_limit
        self.handlers: List[MemoryPressureHandler] = []
        self._monitoring = False
        self._thread: Optional[threading.Thread] = None
        self._last_check = 0.0
        self._history: List[MemoryInfo] = []
        self._max_history = 100

    def add_handler(self, handler: MemoryPressureHandler) -> None:
        """Add a memory pressure handler."""
        self.handlers.append(handler)

    def remove_handler(self, handler: MemoryPressureHandler) -> None:
        """Remove a memory pressure handler."""
        if handler in self.handlers:
            self.handlers.remove(handler)

    def get_memory_info(self) -> MemoryInfo:
        """Get current memory information."""
        mem = psutil.virtual_memory()

        # Use the configured limit if lower than system memory
        total = min(mem.total, self.memory_limit)
        used = mem.used
        available = max(0, total - used)  # clamp: usage can exceed a custom limit
        percent = (used / total) * 100

        # Determine the pressure level
        if percent >= 95:
            level = MemoryPressureLevel.CRITICAL
        elif percent >= 85:
            level = MemoryPressureLevel.HIGH
        elif percent >= 70:
            level = MemoryPressureLevel.MEDIUM
        elif percent >= 50:
            level = MemoryPressureLevel.LOW
        else:
            level = MemoryPressureLevel.NONE

        return MemoryInfo(
            total=total,
            available=available,
            used=used,
            percent=percent,
            pressure_level=level,
            timestamp=time.time()
        )

    def check_memory_pressure(self) -> MemoryPressureLevel:
        """Check current memory pressure and notify handlers."""
        info = self.get_memory_info()

        # Add to history
        self._history.append(info)
        if len(self._history) > self._max_history:
            self._history.pop(0)

        # Notify handlers
        for handler in self.handlers:
            if handler.can_handle(info.pressure_level, info):
                try:
                    handler.handle(info.pressure_level, info)
                except Exception as e:
                    # Log but don't crash on handler errors
                    print(f"Handler error: {e}")

        self._last_check = time.time()
        return info.pressure_level

    def should_check(self) -> bool:
        """Check if enough time has passed for the next check."""
        return time.time() - self._last_check >= self.check_interval

    def start_monitoring(self) -> None:
        """Start the background monitoring thread."""
        if self._monitoring:
            return

        self._monitoring = True
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop_monitoring(self) -> None:
        """Stop background monitoring."""
        self._monitoring = False
        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

    def _monitor_loop(self) -> None:
        """Background monitoring loop."""
        while self._monitoring:
            try:
                self.check_memory_pressure()
                time.sleep(self.check_interval)
            except Exception as e:
                print(f"Monitoring error: {e}")
                time.sleep(self.check_interval)

    def get_memory_trend(self, seconds: int = 60) -> Dict[str, float]:
        """Get the memory usage trend over the past N seconds."""
        if not self._history:
            return {"avg_percent": 0, "max_percent": 0, "trend": 0}

        cutoff = time.time() - seconds
        recent = [h for h in self._history if h.timestamp >= cutoff]

        if not recent:
            return {"avg_percent": 0, "max_percent": 0, "trend": 0}

        percents = [h.percent for h in recent]
        avg_percent = sum(percents) / len(percents)
        max_percent = max(percents)

        # Calculate the trend (positive = increasing usage)
        if len(recent) >= 2:
            first_half = percents[:len(percents)//2]
            second_half = percents[len(percents)//2:]
            trend = sum(second_half)/len(second_half) - sum(first_half)/len(first_half)
        else:
            trend = 0

        return {
            "avg_percent": avg_percent,
            "max_percent": max_percent,
            "trend": trend
        }

    def force_gc(self) -> int:
        """Force garbage collection and return the bytes freed."""
        before = self.get_memory_info().used
        gc.collect()
        after = self.get_memory_info().used
        return max(0, before - after)

    def wait_for_memory(self, required_bytes: int, timeout: float = 30) -> bool:
        """
        Wait for the required memory to become available.

        Returns:
            True if memory became available, False on timeout
        """
        start = time.time()

        while time.time() - start < timeout:
            info = self.get_memory_info()
            if info.available >= required_bytes:
                return True

            # Try to free memory
            self.force_gc()

            # Let handlers do their work
            self.check_memory_pressure()

            time.sleep(0.5)

        return False


# Global monitor instance
monitor = MemoryMonitor()

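
# Demo sketch (illustrative, not part of the library API):
if __name__ == "__main__":
    info = monitor.get_memory_info()
    print(info)   # e.g. "Memory: 42.0% used (6.70/9.30 GB), Pressure: NONE"

    # Block until ~256MB is free (or 10s elapse), nudging GC and handlers.
    if monitor.wait_for_memory(256 * 1024 * 1024, timeout=10):
        print("enough memory available")
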
23
src/sqrtspace_spacetime/ml/__init__.py
Normal file
@@ -0,0 +1,23 @@
"""Machine Learning memory optimization utilities."""
|
||||
|
||||
from sqrtspace_spacetime.ml.optimizer import (
|
||||
MLMemoryOptimizer,
|
||||
ModelProfile,
|
||||
OptimizationPlan,
|
||||
TrainingConfig,
|
||||
MemoryOptimizationStrategy,
|
||||
)
|
||||
from sqrtspace_spacetime.ml.checkpointing import (
|
||||
GradientCheckpointer,
|
||||
CheckpointStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MLMemoryOptimizer",
|
||||
"ModelProfile",
|
||||
"OptimizationPlan",
|
||||
"TrainingConfig",
|
||||
"MemoryOptimizationStrategy",
|
||||
"GradientCheckpointer",
|
||||
"CheckpointStrategy",
|
||||
]
|
||||
286
src/sqrtspace_spacetime/ml/checkpointing.py
Normal file
@@ -0,0 +1,286 @@
"""
|
||||
Gradient checkpointing utilities for memory-efficient training.
|
||||
"""
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
# Framework imports
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
HAS_TF = True
|
||||
except ImportError:
|
||||
HAS_TF = False
|
||||
|
||||
|
||||
class CheckpointStrategy(Enum):
    """Checkpointing strategies."""
    SQRT_N = "sqrt_n"        # Checkpoint every √n layers
    UNIFORM = "uniform"      # Uniform intervals
    MEMORY_BASED = "memory"  # Based on memory usage
    SELECTIVE = "selective"  # Only expensive layers


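# Worked note (illustrative): under SQRT_N, a model with n leaf layers keeps
# activations at every ⌊√n⌋-th layer and recomputes the segments in between on
# the backward pass, e.g. n = 49 layers → checkpoint every 7th layer, cutting
# activation memory roughly from O(n) to O(√n).

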
class GradientCheckpointer:
    """
    Gradient checkpointing for memory-efficient training.

    Implements Williams' √n strategy for an optimal space-time tradeoff.
    """

    def __init__(self, strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N):
        self.strategy = strategy

    def apply_checkpointing(self,
                            model: Any,
                            checkpoint_layers: Optional[List[str]] = None) -> Any:
        """
        Apply gradient checkpointing to a model.

        Args:
            model: Neural network model
            checkpoint_layers: Specific layers to checkpoint (None for auto)

        Returns:
            Model with checkpointing applied
        """
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._apply_torch_checkpointing(model, checkpoint_layers)
        elif HAS_TF:
            return self._apply_tf_checkpointing(model, checkpoint_layers)
        else:
            print("Warning: No supported framework found for checkpointing")
            return model

def _apply_torch_checkpointing(self,
|
||||
model: nn.Module,
|
||||
checkpoint_layers: Optional[List[str]] = None) -> nn.Module:
|
||||
"""Apply checkpointing to PyTorch model."""
|
||||
if checkpoint_layers is None:
|
||||
checkpoint_layers = self._select_checkpoint_layers_torch(model)
|
||||
|
||||
# Wrap forward methods of selected layers
|
||||
for name, module in model.named_modules():
|
||||
if name in checkpoint_layers:
|
||||
self._wrap_module_torch(module)
|
||||
|
||||
return model
|
||||
|
||||
def _wrap_module_torch(self, module: nn.Module) -> None:
|
||||
"""Wrap PyTorch module with gradient checkpointing."""
|
||||
original_forward = module.forward
|
||||
|
||||
def checkpointed_forward(*args, **kwargs):
|
||||
# Use PyTorch's checkpoint function
|
||||
if module.training:
|
||||
return checkpoint(original_forward, *args, **kwargs)
|
||||
else:
|
||||
return original_forward(*args, **kwargs)
|
||||
|
||||
module.forward = checkpointed_forward
|
||||
|
||||
def _apply_tf_checkpointing(self,
|
||||
model: Any,
|
||||
checkpoint_layers: Optional[List[str]] = None) -> Any:
|
||||
"""Apply checkpointing to TensorFlow model."""
|
||||
if checkpoint_layers is None:
|
||||
checkpoint_layers = self._select_checkpoint_layers_tf(model)
|
||||
|
||||
# TensorFlow implementation
|
||||
# Note: TF2 has different checkpointing mechanism
|
||||
print(f"TensorFlow checkpointing selected {len(checkpoint_layers)} layers")
|
||||
|
||||
return model
|
||||
|
||||
def _select_checkpoint_layers_torch(self, model: nn.Module) -> List[str]:
|
||||
"""Select layers to checkpoint for PyTorch model."""
|
||||
layers = []
|
||||
|
||||
# Get all layers
|
||||
for name, module in model.named_modules():
|
||||
if len(list(module.children())) == 0: # Leaf modules
|
||||
layers.append((name, module))
|
||||
|
||||
if self.strategy == CheckpointStrategy.SQRT_N:
|
||||
# Select √n evenly spaced layers
|
||||
n = len(layers)
|
||||
if n == 0:
|
||||
return []
|
||||
|
||||
interval = max(1, int(math.sqrt(n)))
|
||||
selected = []
|
||||
|
||||
for i in range(0, n, interval):
|
||||
name, module = layers[i]
|
||||
if self._can_checkpoint_module(module):
|
||||
selected.append(name)
|
||||
|
||||
return selected
|
||||
|
||||
elif self.strategy == CheckpointStrategy.MEMORY_BASED:
|
||||
# Select layers with large activation memory
|
||||
memory_layers = []
|
||||
|
||||
for name, module in layers:
|
||||
memory = self._estimate_module_memory(module)
|
||||
memory_layers.append((name, memory))
|
||||
|
||||
# Sort by memory and select top √n
|
||||
memory_layers.sort(key=lambda x: x[1], reverse=True)
|
||||
n_checkpoint = max(1, int(math.sqrt(len(memory_layers))))
|
||||
|
||||
return [name for name, _ in memory_layers[:n_checkpoint]]
|
||||
|
||||
else:
|
||||
# Default: checkpoint all eligible layers
|
||||
return [name for name, module in layers if self._can_checkpoint_module(module)]
|
||||
|
||||
def _select_checkpoint_layers_tf(self, model: Any) -> List[str]:
|
||||
"""Select layers to checkpoint for TensorFlow model."""
|
||||
if not hasattr(model, 'layers'):
|
||||
return []
|
||||
|
||||
layers = [(layer.name, layer) for layer in model.layers]
|
||||
|
||||
if self.strategy == CheckpointStrategy.SQRT_N:
|
||||
n = len(layers)
|
||||
interval = max(1, int(math.sqrt(n)))
|
||||
|
||||
selected = []
|
||||
for i in range(0, n, interval):
|
||||
name, layer = layers[i]
|
||||
selected.append(name)
|
||||
|
||||
return selected
|
||||
|
||||
return [name for name, _ in layers]
|
||||
|
||||
def _can_checkpoint_module(self, module: Any) -> bool:
|
||||
"""Check if module can be safely checkpointed."""
|
||||
if HAS_TORCH:
|
||||
# Avoid checkpointing modules with randomness
|
||||
no_checkpoint = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
|
||||
return not isinstance(module, no_checkpoint)
|
||||
return True
|
||||
|
||||
def _estimate_module_memory(self, module: Any) -> int:
|
||||
"""Estimate memory usage of module activations."""
|
||||
if HAS_TORCH and isinstance(module, nn.Module):
|
||||
# Estimate based on output size
|
||||
if isinstance(module, nn.Linear):
|
||||
return module.out_features * 4 # FP32
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
# Rough estimate
|
||||
return module.out_channels * 100 * 100 * 4
|
||||
else:
|
||||
# Default estimate
|
||||
params = sum(p.numel() for p in module.parameters())
|
||||
return params * 4
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def create_checkpoint_segments(model: Any,
|
||||
n_segments: Optional[int] = None) -> List[List[str]]:
|
||||
"""
|
||||
Create checkpoint segments using √n strategy.
|
||||
|
||||
Args:
|
||||
model: Neural network model
|
||||
n_segments: Number of segments (None for √n)
|
||||
|
||||
Returns:
|
||||
List of layer name segments
|
||||
"""
|
||||
# Get all layers
|
||||
if HAS_TORCH and isinstance(model, nn.Module):
|
||||
all_layers = [name for name, _ in model.named_modules()
|
||||
if len(list(_.children())) == 0]
|
||||
elif HAS_TF and hasattr(model, 'layers'):
|
||||
all_layers = [layer.name for layer in model.layers]
|
||||
else:
|
||||
return []
|
||||
|
||||
n = len(all_layers)
|
||||
if n == 0:
|
||||
return []
|
||||
|
||||
# Use √n segments by default
|
||||
if n_segments is None:
|
||||
n_segments = max(1, int(math.sqrt(n)))
|
||||
|
||||
# Create segments
|
||||
segment_size = max(1, n // n_segments)
|
||||
segments = []
|
||||
|
||||
for i in range(0, n, segment_size):
|
||||
segment = all_layers[i:i + segment_size]
|
||||
if segment:
|
||||
segments.append(segment)
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
def checkpoint_sequential(modules: List[Any],
|
||||
input: Any,
|
||||
segments: Optional[int] = None) -> Any:
|
||||
"""
|
||||
Checkpoint a sequential model using √n segments.
|
||||
|
||||
Args:
|
||||
modules: List of modules to execute sequentially
|
||||
input: Input tensor
|
||||
segments: Number of checkpoint segments (None for √n)
|
||||
|
||||
Returns:
|
||||
Output tensor
|
||||
"""
|
||||
if not HAS_TORCH:
|
||||
# Fallback to normal execution
|
||||
x = input
|
||||
for module in modules:
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
n = len(modules)
|
||||
if n == 0:
|
||||
return input
|
||||
|
||||
# Use √n segments
|
||||
if segments is None:
|
||||
segments = max(1, int(math.sqrt(n)))
|
||||
|
||||
segment_size = max(1, n // segments)
|
||||
|
||||
# Execute with checkpointing
|
||||
x = input
|
||||
for i in range(0, n, segment_size):
|
||||
segment = modules[i:i + segment_size]
|
||||
|
||||
if len(segment) == 1:
|
||||
# Single module
|
||||
if modules[0].training:
|
||||
x = checkpoint(segment[0], x)
|
||||
else:
|
||||
x = segment[0](x)
|
||||
else:
|
||||
# Multiple modules - create sequential wrapper
|
||||
def run_segment(x, *modules):
|
||||
for module in modules:
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
if modules[0].training:
|
||||
x = checkpoint(run_segment, x, *segment)
|
||||
else:
|
||||
x = run_segment(x, *segment)
|
||||
|
||||
return x
|
||||
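A hedged usage sketch for the checkpointing utilities above; the model, layer count, and shapes are illustrative assumptions, not values taken from the library:

# Illustrative only: apply √n gradient checkpointing to a toy PyTorch model.
import torch
import torch.nn as nn
from sqrtspace_spacetime.ml.checkpointing import GradientCheckpointer, CheckpointStrategy

model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(16)])
ckpt = GradientCheckpointer(CheckpointStrategy.SQRT_N)
model = ckpt.apply_checkpointing(model)  # wraps ~√16 = 4 of the 16 layers

model.train()
x = torch.randn(8, 256, requires_grad=True)  # requires_grad so checkpointed backward flows
model(x).sum().backward()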
488
src/sqrtspace_spacetime/ml/optimizer.py
Normal file
@ -0,0 +1,488 @@
"""
ML Training Memory Optimizer: Optimize neural network training memory usage.

Features:
- Layer-by-layer memory profiling
- Automatic gradient checkpointing with √n intervals
- Mixed precision configuration
- Batch size optimization
- Framework-agnostic (PyTorch/TensorFlow)
"""

import math
import psutil
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from sqrtspace_spacetime.config import config
from sqrtspace_spacetime.memory import monitor

# Try to import ML frameworks
try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    import tensorflow as tf
    HAS_TF = True
except ImportError:
    HAS_TF = False


class MemoryOptimizationStrategy(Enum):
    """Memory optimization strategies for ML training."""
    GRADIENT_CHECKPOINTING = "gradient_checkpointing"  # Recompute activations
    MIXED_PRECISION = "mixed_precision"                # FP16/BF16 training
    GRADIENT_ACCUMULATION = "gradient_accumulation"    # Smaller effective batch
    MODEL_SHARDING = "model_sharding"                  # Distribute layers
    ACTIVATION_COMPRESSION = "activation_compression"  # Compress intermediates
    DYNAMIC_BATCH_SIZE = "dynamic_batch_size"          # Adjust on the fly


@dataclass
class LayerProfile:
    """Profile of a neural network layer."""
    name: str
    layer_type: str
    parameters: int
    activation_size: int  # Bytes per sample
    gradient_size: int    # Bytes per sample
    computation_time: float
    memory_bytes: int
    can_checkpoint: bool
    precision: str  # 'fp32', 'fp16', 'int8'


@dataclass
class ModelProfile:
    """Complete model memory profile."""
    total_parameters: int
    total_activations: int  # Bytes per sample
    peak_memory: int
    layers: List[LayerProfile]
    memory_timeline: List[Tuple[str, int]]  # (operation, memory)
    bottleneck_layers: List[str]
    framework: str  # 'pytorch', 'tensorflow', 'generic'


@dataclass
class OptimizationPlan:
    """Optimization plan for model training."""
    strategies: List[MemoryOptimizationStrategy]
    checkpoint_layers: List[str]
    batch_size: int
    gradient_accumulation_steps: int
    mixed_precision_config: Dict[str, Any]
    estimated_memory: int
    estimated_speedup: float
    memory_savings: int
    explanation: str


@dataclass
class TrainingConfig:
    """Configuration for optimized training."""
    original_batch_size: int
    optimized_batch_size: int
    accumulation_steps: int
    checkpoint_segments: List[List[str]]
    precision_map: Dict[str, str]
    memory_limit: int

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)


class MLMemoryOptimizer:
    """Optimize memory usage for ML model training."""

    def __init__(self, memory_limit: Optional[int] = None):
        """
        Initialize optimizer.

        Args:
            memory_limit: Memory limit in bytes (None for auto-detect)
        """
        self.memory_limit = memory_limit or int(psutil.virtual_memory().available * 0.8)

    def analyze_model(self,
                      model: Any,
                      input_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
                      batch_size: int = 1) -> ModelProfile:
        """
        Analyze model memory requirements.

        Args:
            model: Neural network model
            input_shape: Input shape(s)
            batch_size: Batch size for analysis

        Returns:
            ModelProfile with memory analysis
        """
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._analyze_torch_model(model, input_shape, batch_size)
        elif HAS_TF and hasattr(model, 'layers'):
            return self._analyze_tf_model(model, input_shape, batch_size)
        else:
            return self._analyze_generic_model(model, input_shape, batch_size)

    def _analyze_torch_model(self,
                             model: nn.Module,
                             input_shape: Tuple[int, ...],
                             batch_size: int) -> ModelProfile:
        """Analyze PyTorch model."""
        layers = []
        total_params = 0
        total_activations = 0
        memory_timeline = []

        # Count parameters
        for name, param in model.named_parameters():
            total_params += param.numel()

        # Analyze layers
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:  # Leaf module
                layer_params = sum(p.numel() for p in module.parameters())

                # Estimate activation size in bytes (simplified)
                if isinstance(module, nn.Linear):
                    activation_size = module.out_features * batch_size * 4  # fp32
                elif isinstance(module, nn.Conv2d):
                    # Rough estimate assuming ~100x100 feature maps
                    activation_size = module.out_channels * 100 * 100 * batch_size * 4
                else:
                    activation_size = layer_params * batch_size * 4

                total_activations += activation_size

                layers.append(LayerProfile(
                    name=name,
                    layer_type=module.__class__.__name__,
                    parameters=layer_params,
                    activation_size=activation_size // batch_size,
                    gradient_size=layer_params * 4,  # fp32 gradients
                    computation_time=0.001,  # Placeholder
                    memory_bytes=layer_params * 4 + activation_size,
                    can_checkpoint=self._can_checkpoint_layer(module),
                    precision='fp32'
                ))

        # Find bottlenecks (top 20% by memory)
        sorted_layers = sorted(layers, key=lambda l: l.memory_bytes, reverse=True)
        bottleneck_count = max(1, len(layers) // 5)
        bottleneck_layers = [l.name for l in sorted_layers[:bottleneck_count]]

        return ModelProfile(
            total_parameters=total_params,
            total_activations=total_activations // batch_size,
            peak_memory=total_params * 4 + total_activations,
            layers=layers,
            memory_timeline=memory_timeline,
            bottleneck_layers=bottleneck_layers,
            framework='pytorch'
        )

    def _analyze_tf_model(self,
                          model: Any,
                          input_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
                          batch_size: int) -> ModelProfile:
        """Analyze TensorFlow model."""
        layers = []
        total_params = model.count_params()
        total_activations = 0

        # Analyze each layer
        for layer in model.layers:
            layer_params = layer.count_params()

            # Estimate activation size
            if hasattr(layer, 'output_shape'):
                shape = layer.output_shape
                if isinstance(shape, tuple):
                    activation_size = np.prod(shape[1:]) * batch_size * 4
                else:
                    activation_size = layer_params * batch_size * 4
            else:
                activation_size = layer_params * batch_size * 4

            total_activations += activation_size

            layers.append(LayerProfile(
                name=layer.name,
                layer_type=layer.__class__.__name__,
                parameters=layer_params,
                activation_size=activation_size // batch_size,
                gradient_size=layer_params * 4,
                computation_time=0.001,
                memory_bytes=layer_params * 4 + activation_size,
                can_checkpoint=True,  # Most TF layers can checkpoint
                precision='fp32'
            ))

        # Find bottlenecks
        sorted_layers = sorted(layers, key=lambda l: l.memory_bytes, reverse=True)
        bottleneck_count = max(1, len(layers) // 5)
        bottleneck_layers = [l.name for l in sorted_layers[:bottleneck_count]]

        return ModelProfile(
            total_parameters=total_params,
            total_activations=total_activations // batch_size,
            peak_memory=total_params * 4 + total_activations,
            layers=layers,
            memory_timeline=[],
            bottleneck_layers=bottleneck_layers,
            framework='tensorflow'
        )

    def _analyze_generic_model(self,
                               model: Any,
                               input_shape: Tuple[int, ...],
                               batch_size: int) -> ModelProfile:
        """Analyze generic model."""
        # Basic heuristics
        estimated_params = 10_000_000  # 10M parameters
        estimated_activations = estimated_params * batch_size

        return ModelProfile(
            total_parameters=estimated_params,
            total_activations=estimated_activations,
            peak_memory=estimated_params * 4 + estimated_activations * 4,
            layers=[],
            memory_timeline=[],
            bottleneck_layers=[],
            framework='generic'
        )

    def optimize(self,
                 model_profile: ModelProfile,
                 target_batch_size: int,
                 strategies: Optional[List[MemoryOptimizationStrategy]] = None) -> OptimizationPlan:
        """
        Generate optimization plan for model.

        Args:
            model_profile: Model profile from analyze_model
            target_batch_size: Desired batch size
            strategies: Strategies to consider (None for auto)

        Returns:
            OptimizationPlan with recommendations
        """
        if strategies is None:
            strategies = self._select_strategies(model_profile, target_batch_size)

        # Calculate memory requirements
        # (total_activations is already bytes per sample, so no extra
        # element-size factor is applied here)
        base_memory = model_profile.total_parameters * 4  # Parameters
        activation_memory = model_profile.total_activations * target_batch_size
        gradient_memory = model_profile.total_parameters * 4  # Gradients
        optimizer_memory = model_profile.total_parameters * 8  # Adam states

        total_memory = base_memory + activation_memory + gradient_memory + optimizer_memory

        # Initialize plan
        plan = OptimizationPlan(
            strategies=strategies,
            checkpoint_layers=[],
            batch_size=target_batch_size,
            gradient_accumulation_steps=1,
            mixed_precision_config={},
            estimated_memory=total_memory,
            estimated_speedup=1.0,
            memory_savings=0,
            explanation=""
        )

        # Apply strategies
        for strategy in strategies:
            if strategy == MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING:
                self._apply_checkpointing(plan, model_profile)
            elif strategy == MemoryOptimizationStrategy.MIXED_PRECISION:
                self._apply_mixed_precision(plan, model_profile)
            elif strategy == MemoryOptimizationStrategy.GRADIENT_ACCUMULATION:
                self._apply_gradient_accumulation(plan, model_profile)

        # Calculate final estimates
        plan.memory_savings = total_memory - plan.estimated_memory
        plan.explanation = self._generate_explanation(plan, model_profile)

        return plan

    def _select_strategies(self,
                           model_profile: ModelProfile,
                           target_batch_size: int) -> List[MemoryOptimizationStrategy]:
        """Select appropriate optimization strategies."""
        strategies = []

        # Calculate memory pressure (total_activations is bytes per sample)
        required_memory = (model_profile.total_parameters * 4 +
                           model_profile.total_activations * target_batch_size)

        if required_memory > self.memory_limit:
            # High memory pressure - use all strategies
            strategies.append(MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING)
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)
            strategies.append(MemoryOptimizationStrategy.GRADIENT_ACCUMULATION)
        elif required_memory > self.memory_limit * 0.8:
            # Medium pressure
            strategies.append(MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING)
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)
        elif required_memory > self.memory_limit * 0.6:
            # Low pressure
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)

        return strategies

    def _apply_checkpointing(self,
                             plan: OptimizationPlan,
                             model_profile: ModelProfile) -> None:
        """Apply gradient checkpointing using √n strategy."""
        n_layers = len(model_profile.layers)

        if n_layers == 0:
            return

        # Use √n checkpointing intervals
        checkpoint_interval = max(1, int(math.sqrt(n_layers)))

        # Select layers to checkpoint
        checkpoint_layers = []
        for i in range(0, n_layers, checkpoint_interval):
            if i < len(model_profile.layers):
                layer = model_profile.layers[i]
                if layer.can_checkpoint:
                    checkpoint_layers.append(layer.name)

        plan.checkpoint_layers = checkpoint_layers

        # Update memory estimate (save ~50% of activation memory;
        # activation_size is bytes per sample)
        saved_memory = sum(l.activation_size * plan.batch_size
                           for l in model_profile.layers
                           if l.name in checkpoint_layers) * 0.5

        plan.estimated_memory -= int(saved_memory)
        plan.estimated_speedup *= 0.8  # 20% slowdown from recomputation

    def _apply_mixed_precision(self,
                               plan: OptimizationPlan,
                               model_profile: ModelProfile) -> None:
        """Apply mixed precision training."""
        plan.mixed_precision_config = {
            'enabled': True,
            'loss_scale': 'dynamic',
            'compute_dtype': 'float16',
            'variable_dtype': 'float32'
        }

        # Update memory estimate (FP16 halves activation bytes)
        activation_savings = model_profile.total_activations * plan.batch_size // 2
        plan.estimated_memory -= activation_savings
        plan.estimated_speedup *= 1.5  # Potential speedup on modern GPUs

    def _apply_gradient_accumulation(self,
                                     plan: OptimizationPlan,
                                     model_profile: ModelProfile) -> None:
        """Apply gradient accumulation."""
        # Calculate how many accumulation steps are needed
        current_memory = plan.estimated_memory

        if current_memory > self.memory_limit:
            # Reduce effective batch size
            reduction_factor = current_memory / self.memory_limit
            accumulation_steps = int(math.ceil(reduction_factor))

            # Adjust batch size and accumulation
            effective_batch = plan.batch_size // accumulation_steps
            plan.batch_size = max(1, effective_batch)
            plan.gradient_accumulation_steps = accumulation_steps

            # Update memory estimate
            plan.estimated_memory = plan.estimated_memory // accumulation_steps

    def _can_checkpoint_layer(self, layer: Any) -> bool:
        """Check if layer can be checkpointed."""
        if HAS_TORCH:
            # Most layers can be checkpointed except those with side effects
            no_checkpoint_types = (nn.Dropout, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
            return not isinstance(layer, no_checkpoint_types)
        return True

    def _generate_explanation(self,
                              plan: OptimizationPlan,
                              model_profile: ModelProfile) -> str:
        """Generate human-readable explanation."""
        explanations = []

        explanations.append("Model Analysis:")
        explanations.append(f"- Total parameters: {model_profile.total_parameters:,}")
        explanations.append(f"- Peak memory estimate: {plan.estimated_memory / (1024**3):.2f} GB")
        explanations.append(f"- Memory savings: {plan.memory_savings / (1024**3):.2f} GB")

        if MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING in plan.strategies:
            explanations.append("\nGradient Checkpointing:")
            explanations.append(f"- Checkpointing {len(plan.checkpoint_layers)} layers using √n strategy")
            explanations.append("- This trades ~20% compute time for ~50% activation memory")

        if MemoryOptimizationStrategy.MIXED_PRECISION in plan.strategies:
            explanations.append("\nMixed Precision:")
            explanations.append("- Using FP16 for forward pass, FP32 for gradients")
            explanations.append("- Reduces activation memory by ~50%")

        if plan.gradient_accumulation_steps > 1:
            explanations.append("\nGradient Accumulation:")
            explanations.append(f"- Accumulating over {plan.gradient_accumulation_steps} steps")
            explanations.append(f"- Effective batch size: {plan.batch_size * plan.gradient_accumulation_steps}")

        return "\n".join(explanations)

    def get_training_config(self,
                            plan: OptimizationPlan,
                            model_profile: ModelProfile) -> TrainingConfig:
        """
        Generate training configuration from optimization plan.

        Args:
            plan: Optimization plan
            model_profile: Model profile

        Returns:
            TrainingConfig ready for use
        """
        # Group checkpoint layers into segments
        checkpoint_segments = []
        if plan.checkpoint_layers:
            # Create √n segments
            n_segments = int(math.sqrt(len(plan.checkpoint_layers)))
            segment_size = max(1, len(plan.checkpoint_layers) // n_segments)

            for i in range(0, len(plan.checkpoint_layers), segment_size):
                segment = plan.checkpoint_layers[i:i + segment_size]
                if segment:
                    checkpoint_segments.append(segment)

        # Create precision map
        precision_map = {}
        if MemoryOptimizationStrategy.MIXED_PRECISION in plan.strategies:
            for layer in model_profile.layers:
                # Use FP16 for compute-heavy layers
                if layer.layer_type in ['Linear', 'Conv2d', 'Dense', 'Conv2D']:
                    precision_map[layer.name] = 'fp16'
                else:
                    precision_map[layer.name] = 'fp32'

        return TrainingConfig(
            original_batch_size=plan.batch_size * plan.gradient_accumulation_steps,
            optimized_batch_size=plan.batch_size,
            accumulation_steps=plan.gradient_accumulation_steps,
            checkpoint_segments=checkpoint_segments,
            precision_map=precision_map,
            memory_limit=self.memory_limit
        )
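An end-to-end sketch of the optimizer API defined above; the model, batch size, and the deliberately small 16 MB memory_limit are illustrative assumptions, and the printed numbers are heuristic estimates, not measurements:

import torch.nn as nn
from sqrtspace_spacetime.ml import MLMemoryOptimizer

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10))
opt = MLMemoryOptimizer(memory_limit=16 * 1024**2)  # tiny limit to force memory pressure

profile = opt.analyze_model(model, input_shape=(1024,), batch_size=64)
plan = opt.optimize(profile, target_batch_size=64)
print(plan.explanation)  # chosen strategies and estimated savings

train_cfg = opt.get_training_config(plan, profile)
print(train_cfg.accumulation_steps, train_cfg.optimized_batch_size)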
25
src/sqrtspace_spacetime/profiler/__init__.py
Normal file
@ -0,0 +1,25 @@
"""SpaceTime Profiler for memory and performance analysis."""

from sqrtspace_spacetime.profiler.profiler import (
    SpaceTimeProfiler,
    ProfilingReport,
    Hotspot,
    BottleneckAnalysis,
    AccessPattern,
)
from sqrtspace_spacetime.profiler.decorators import (
    profile,
    profile_memory,
    profile_time,
)

__all__ = [
    "SpaceTimeProfiler",
    "ProfilingReport",
    "Hotspot",
    "BottleneckAnalysis",
    "AccessPattern",
    "profile",
    "profile_memory",
    "profile_time",
]
175
src/sqrtspace_spacetime/profiler/decorators.py
Normal file
@ -0,0 +1,175 @@
"""Decorators for easy profiling."""

import functools
import time
from typing import Any, Callable, Optional

from sqrtspace_spacetime.profiler.profiler import SpaceTimeProfiler


def profile(output_file: Optional[str] = None,
            print_summary: bool = True) -> Callable:
    """
    Decorator to profile a function.

    Args:
        output_file: Optional file to save report
        print_summary: Print summary to console

    Example:
        @profile(output_file="profile.json")
        def my_function():
            # Process data
            pass
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            profiler = SpaceTimeProfiler()
            result, report = profiler.profile(func, *args, **kwargs)

            if print_summary:
                print(report.summary)

            if output_file:
                report.save(output_file)

            # Store report on function for access
            wrapper.last_report = report

            return result

        wrapper.last_report = None
        return wrapper

    return decorator


def profile_memory(threshold_mb: float = 100,
                   alert: bool = True) -> Callable:
    """
    Decorator to profile memory usage.

    Args:
        threshold_mb: Memory threshold in MB to trigger alert
        alert: Print alert if threshold exceeded

    Example:
        @profile_memory(threshold_mb=500)
        def process_large_data():
            # Process data
            pass
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            import psutil
            process = psutil.Process()

            start_memory = process.memory_info().rss
            start_time = time.time()

            try:
                result = func(*args, **kwargs)
            finally:
                end_memory = process.memory_info().rss
                end_time = time.time()

                memory_used = (end_memory - start_memory) / (1024 * 1024)
                duration = end_time - start_time

                # Store metrics
                wrapper.memory_used = memory_used
                wrapper.duration = duration

                if alert and memory_used > threshold_mb:
                    print(f"⚠️ Memory Alert: {func.__name__} used {memory_used:.1f}MB "
                          f"(threshold: {threshold_mb}MB)")
                    print("   Consider using SpaceTime collections for memory efficiency")

                if alert:
                    print(f"Memory: {memory_used:.1f}MB, Time: {duration:.2f}s")

            return result

        wrapper.memory_used = None
        wrapper.duration = None
        return wrapper

    return decorator


def profile_time(threshold_seconds: float = 1.0,
                 alert: bool = True) -> Callable:
    """
    Decorator to profile execution time.

    Args:
        threshold_seconds: Time threshold to trigger alert
        alert: Print alert if threshold exceeded

    Example:
        @profile_time(threshold_seconds=5.0)
        def slow_operation():
            # Time-consuming operation
            pass
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            start_time = time.time()

            try:
                result = func(*args, **kwargs)
            finally:
                duration = time.time() - start_time
                wrapper.duration = duration

                if alert and duration > threshold_seconds:
                    print(f"⏱️ Time Alert: {func.__name__} took {duration:.2f}s "
                          f"(threshold: {threshold_seconds}s)")

                if alert:
                    print(f"Execution time: {duration:.2f}s")

            return result

        wrapper.duration = None
        return wrapper

    return decorator


class ProfileContext:
    """Context manager for profiling code blocks."""

    def __init__(self, name: str = "block", print_summary: bool = True):
        self.name = name
        self.print_summary = print_summary
        self.profiler = None
        self.report = None
        self._monitoring = False

    def __enter__(self):
        self.profiler = SpaceTimeProfiler()
        self.profiler.start_monitoring()
        self._start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.time() - self._start_time
        self.profiler.stop_monitoring()

        # Generate simple report
        if self.print_summary:
            peak_memory = max((m[1] for m in self.profiler.memory_timeline), default=0)
            print(f"\nProfile: {self.name}")
            print(f"Duration: {duration:.2f}s")
            print(f"Peak Memory: {peak_memory / (1024*1024):.1f}MB")

            if peak_memory > 100 * 1024 * 1024:  # 100MB
                print("💡 Consider using SpaceTime collections for memory optimization")


# Convenience instance
profile_context = ProfileContext
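A short sketch of the ProfileContext helper above; the workload is a stand-in, and any code can run inside the block:

from sqrtspace_spacetime.profiler.decorators import profile_context

with profile_context("build_squares"):
    squares = [i * i for i in range(1_000_000)]
# On exit the context prints the block's duration and peak memory.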
475
src/sqrtspace_spacetime/profiler/profiler.py
Normal file
@ -0,0 +1,475 @@
"""
SpaceTime Profiler: Profile applications to identify optimization opportunities.

Features:
- Memory pattern analysis (sequential, random, strided)
- Bottleneck detection (memory vs CPU)
- Memory hierarchy awareness (L1/L2/L3/RAM/Disk)
- Hotspot identification
- AI-generated recommendations
"""

import time
import threading
import psutil
import numpy as np
import tracemalloc
import cProfile
import pstats
import io
from collections import defaultdict, deque
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict

from sqrtspace_spacetime.config import config


class AccessPattern(Enum):
    """Memory access patterns."""
    SEQUENTIAL = "sequential"
    RANDOM = "random"
    STRIDED = "strided"
    UNKNOWN = "unknown"


@dataclass
class MemoryAccess:
    """Single memory access event."""
    timestamp: float
    address: int
    size: int
    operation: str  # 'read' or 'write'
    function: str
    line_number: int


@dataclass
class Hotspot:
    """Memory hotspot information."""
    function: str
    file_path: str
    line_number: int
    memory_allocated: int
    memory_freed: int
    net_memory: int
    allocation_count: int
    cpu_time: float
    access_pattern: AccessPattern
    recommendations: List[str]


@dataclass
class BottleneckAnalysis:
    """Analysis of performance bottlenecks."""
    type: str  # 'memory', 'cpu', 'io'
    severity: float  # 0.0 to 1.0
    description: str
    evidence: Dict[str, Any]
    recommendations: List[str]


@dataclass
class ProfilingReport:
    """Complete profiling report."""
    timestamp: str
    duration: float
    peak_memory: int
    total_allocations: int
    memory_timeline: List[Tuple[float, int]]
    cpu_timeline: List[Tuple[float, float]]
    hotspots: List[Hotspot]
    bottlenecks: List[BottleneckAnalysis]
    access_patterns: Dict[str, AccessPattern]
    hierarchy_transitions: Dict[str, int]
    optimization_opportunities: List[Dict[str, Any]]
    summary: str

    def to_dict(self) -> Dict[str, Any]:
        """Convert report to dictionary."""
        return asdict(self)

    def save(self, path: str) -> None:
        """Save report to JSON file."""
        import json
        with open(path, 'w') as f:
            # default=str handles values JSON cannot encode directly,
            # such as AccessPattern enum members
            json.dump(self.to_dict(), f, indent=2, default=str)


class MemoryTracer:
    """Trace memory accesses and allocations."""

    def __init__(self, max_samples: int = 100000):
        self.accesses = deque(maxlen=max_samples)
        self.allocations = defaultdict(list)
        self.start_time = time.time()
        self._tracemalloc_snapshot = None

    def start(self):
        """Start memory tracing."""
        if not tracemalloc.is_tracing():
            tracemalloc.start()

    def stop(self):
        """Stop memory tracing."""
        if tracemalloc.is_tracing():
            self._tracemalloc_snapshot = tracemalloc.take_snapshot()
            tracemalloc.stop()

    def analyze_pattern(self, accesses: List[MemoryAccess]) -> AccessPattern:
        """Analyze access pattern from recent accesses."""
        if len(accesses) < 10:
            return AccessPattern.UNKNOWN

        # Extract addresses
        addresses = [a.address for a in accesses[-100:]]

        # Calculate differences
        diffs = np.diff(addresses)
        if len(diffs) == 0:
            return AccessPattern.UNKNOWN

        # Check for sequential pattern
        if np.all(diffs > 0) and np.std(diffs) < np.mean(diffs) * 0.1:
            return AccessPattern.SEQUENTIAL

        # Check for strided pattern
        unique_diffs = set(diffs)
        if len(unique_diffs) < 5 and np.std(diffs) < 100:
            return AccessPattern.STRIDED

        # Otherwise random
        return AccessPattern.RANDOM

    def get_top_allocators(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get top memory allocators from tracemalloc."""
        if not self._tracemalloc_snapshot:
            return []

        top_stats = self._tracemalloc_snapshot.statistics('lineno')[:limit]

        result = []
        for stat in top_stats:
            result.append({
                'file': stat.traceback.format()[0] if stat.traceback else 'unknown',
                'size': stat.size,
                'count': stat.count,
                'average': stat.size // stat.count if stat.count > 0 else 0
            })

        return result


class SpaceTimeProfiler:
    """Main profiler class."""

    def __init__(self, sample_interval: float = 0.01):
        self.sample_interval = sample_interval
        self.memory_tracer = MemoryTracer()

        # Tracking data
        self.memory_timeline = []
        self.cpu_timeline = []
        self.io_timeline = []
        self.function_stats = defaultdict(lambda: {
            'calls': 0,
            'memory': 0,
            'time': 0.0,
            'allocations': []
        })

        self._monitoring = False
        self._monitor_thread = None
        self._start_time = None

    def start_monitoring(self):
        """Start background monitoring."""
        self._monitoring = True
        self._start_time = time.time()
        self.memory_tracer.start()

        self._monitor_thread = threading.Thread(target=self._monitor_loop)
        self._monitor_thread.daemon = True
        self._monitor_thread.start()

    def stop_monitoring(self):
        """Stop background monitoring."""
        self._monitoring = False
        if self._monitor_thread:
            self._monitor_thread.join()
        self.memory_tracer.stop()

    def _monitor_loop(self):
        """Background monitoring loop."""
        process = psutil.Process()

        while self._monitoring:
            timestamp = time.time() - self._start_time

            # Memory usage
            mem_info = process.memory_info()
            self.memory_timeline.append((timestamp, mem_info.rss))

            # CPU usage
            cpu_percent = process.cpu_percent(interval=None)
            self.cpu_timeline.append((timestamp, cpu_percent))

            # IO counters (if available)
            try:
                io_counters = process.io_counters()
                self.io_timeline.append((timestamp, {
                    'read_bytes': io_counters.read_bytes,
                    'write_bytes': io_counters.write_bytes,
                    'read_count': io_counters.read_count,
                    'write_count': io_counters.write_count
                }))
            except (psutil.Error, AttributeError):
                # io_counters is unavailable on some platforms (e.g. macOS)
                pass

            time.sleep(self.sample_interval)

    def profile(self, func: Callable, *args, **kwargs) -> Tuple[Any, ProfilingReport]:
        """Profile a function execution."""
        # Start monitoring
        self.start_monitoring()

        # CPU profiling
        profiler = cProfile.Profile()
        profiler.enable()

        start_time = time.time()

        try:
            # Execute function
            result = func(*args, **kwargs)
        finally:
            # Stop profiling
            end_time = time.time()
            profiler.disable()
            self.stop_monitoring()

        # Generate report
        report = self._generate_report(
            duration=end_time - start_time,
            cpu_profile=profiler
        )

        return result, report

    def _generate_report(self, duration: float, cpu_profile: cProfile.Profile) -> ProfilingReport:
        """Generate comprehensive profiling report."""
        # Get peak memory
        peak_memory = max((m[1] for m in self.memory_timeline), default=0)

        # Analyze components
        hotspots = self._analyze_hotspots(cpu_profile, duration, peak_memory)
        bottlenecks = self._analyze_bottlenecks()
        patterns = self._analyze_access_patterns()
        transitions = self._count_hierarchy_transitions()
        opportunities = self._find_optimization_opportunities(hotspots, bottlenecks)

        # Generate summary
        summary = self._generate_summary(duration, peak_memory, hotspots, bottlenecks)

        return ProfilingReport(
            timestamp=datetime.now().isoformat(),
            duration=duration,
            peak_memory=peak_memory,
            total_allocations=len(self.memory_tracer.allocations),
            memory_timeline=self.memory_timeline,
            cpu_timeline=self.cpu_timeline,
            hotspots=hotspots,
            bottlenecks=bottlenecks,
            access_patterns=patterns,
            hierarchy_transitions=transitions,
            optimization_opportunities=opportunities,
            summary=summary
        )

    def _analyze_hotspots(self, cpu_profile: cProfile.Profile,
                          duration: float, peak_memory: int) -> List[Hotspot]:
        """Identify performance hotspots."""
        stats = pstats.Stats(cpu_profile)
        stats.sort_stats('cumulative')

        hotspots = []
        top_allocators = self.memory_tracer.get_top_allocators()

        # Create lookup for memory stats
        memory_by_file = {stat['file']: stat for stat in top_allocators}

        # Analyze top functions
        for func_info, (cc, nc, tt, ct, callers) in list(stats.stats.items())[:20]:
            filename, line_number, function_name = func_info

            # Get memory info if available
            mem_info = memory_by_file.get(f"{filename}:{line_number}", {})

            # Skip built-in functions
            if filename.startswith('<') or 'site-packages' in filename:
                continue

            # Determine access pattern (simplified)
            pattern = AccessPattern.UNKNOWN

            # Generate recommendations
            recommendations = []
            if ct > duration * 0.1:  # More than 10% of time
                recommendations.append("Consider optimizing this function - it's a CPU hotspot")
            if mem_info.get('size', 0) > peak_memory * 0.1:  # More than 10% of memory
                recommendations.append("This function allocates significant memory - consider √n optimization")

            hotspots.append(Hotspot(
                function=function_name,
                file_path=filename,
                line_number=line_number,
                memory_allocated=mem_info.get('size', 0),
                memory_freed=0,  # Not tracked in simple version
                net_memory=mem_info.get('size', 0),
                allocation_count=mem_info.get('count', 0),
                cpu_time=ct,
                access_pattern=pattern,
                recommendations=recommendations
            ))

        return hotspots

    def _analyze_bottlenecks(self) -> List[BottleneckAnalysis]:
        """Analyze performance bottlenecks."""
        bottlenecks = []

        # Memory bottleneck analysis
        if self.memory_timeline:
            mem_values = [m[1] for m in self.memory_timeline]
            mem_growth = mem_values[-1] - mem_values[0] if len(mem_values) > 1 else 0

            if mem_growth > 100 * 1024 * 1024:  # 100MB growth
                bottlenecks.append(BottleneckAnalysis(
                    type="memory",
                    severity=min(1.0, mem_growth / (1024 * 1024 * 1024)),  # GB scale
                    description=f"Significant memory growth detected: {mem_growth / (1024*1024):.1f}MB",
                    evidence={
                        "start_memory": mem_values[0],
                        "end_memory": mem_values[-1],
                        "growth": mem_growth
                    },
                    recommendations=[
                        "Consider using SpaceTime collections for large datasets",
                        "Implement streaming processing with √n buffering",
                        "Use external sorting/grouping algorithms"
                    ]
                ))

        # CPU bottleneck analysis
        if self.cpu_timeline:
            cpu_values = [c[1] for c in self.cpu_timeline]
            avg_cpu = np.mean(cpu_values) if cpu_values else 0

            if avg_cpu > 80:  # 80% CPU usage
                bottlenecks.append(BottleneckAnalysis(
                    type="cpu",
                    severity=min(1.0, avg_cpu / 100),
                    description=f"High CPU usage detected: {avg_cpu:.1f}% average",
                    evidence={
                        "average_cpu": avg_cpu,
                        "peak_cpu": max(cpu_values) if cpu_values else 0
                    },
                    recommendations=[
                        "Profile CPU hotspots for optimization opportunities",
                        "Consider parallel processing with √n chunk size",
                        "Use more efficient algorithms"
                    ]
                ))

        return bottlenecks

    def _analyze_access_patterns(self) -> Dict[str, AccessPattern]:
        """Analyze memory access patterns by function."""
        # Simplified implementation
        return {"overall": AccessPattern.UNKNOWN}

    def _count_hierarchy_transitions(self) -> Dict[str, int]:
        """Count memory hierarchy transitions."""
        # Simplified implementation
        transitions = {
            "L1_to_L2": 0,
            "L2_to_L3": 0,
            "L3_to_RAM": 0,
            "RAM_to_Disk": 0
        }

        # Estimate based on memory growth
        if self.memory_timeline:
            mem_values = [m[1] for m in self.memory_timeline]
            max_mem = max(mem_values) if mem_values else 0

            if max_mem > 32 * 1024:  # > L1
                transitions["L1_to_L2"] += 1
            if max_mem > 256 * 1024:  # > L2
                transitions["L2_to_L3"] += 1
            if max_mem > 8 * 1024 * 1024:  # > L3
                transitions["L3_to_RAM"] += 1
            if max_mem > 1024 * 1024 * 1024:  # > 1GB
                transitions["RAM_to_Disk"] += 1

        return transitions

    def _find_optimization_opportunities(self,
                                         hotspots: List[Hotspot],
                                         bottlenecks: List[BottleneckAnalysis]) -> List[Dict[str, Any]]:
        """Find SpaceTime optimization opportunities."""
        opportunities = []

        # Check for large memory allocations
        for hotspot in hotspots:
            if hotspot.memory_allocated > 10 * 1024 * 1024:  # 10MB
                opportunities.append({
                    "type": "large_allocation",
                    "location": f"{hotspot.file_path}:{hotspot.line_number}",
                    "function": hotspot.function,
                    "memory": hotspot.memory_allocated,
                    "suggestion": "Use SpaceTimeArray or SpaceTimeDict for large collections",
                    "potential_savings": f"{hotspot.memory_allocated * 0.9 / (1024*1024):.1f}MB"
                })

        # Check for memory growth patterns
        memory_bottleneck = next((b for b in bottlenecks if b.type == "memory"), None)
        if memory_bottleneck:
            opportunities.append({
                "type": "memory_growth",
                "severity": memory_bottleneck.severity,
                "suggestion": "Implement streaming processing with Stream class",
                "example": "Stream.from_file('data.csv').map(process).chunk(√n).foreach(save)"
            })

        return opportunities

    def _generate_summary(self, duration: float, peak_memory: int,
                          hotspots: List[Hotspot],
                          bottlenecks: List[BottleneckAnalysis]) -> str:
        """Generate human-readable summary."""
        summary_parts = [
            "Profile Summary",
            "===============",
            f"Duration: {duration:.2f}s",
            f"Peak Memory: {peak_memory / (1024*1024):.1f}MB",
            f"Hotspots Found: {len(hotspots)}",
            f"Bottlenecks: {len(bottlenecks)}",
        ]

        if bottlenecks:
            summary_parts.append("\nMain Bottlenecks:")
            for b in bottlenecks[:3]:
                summary_parts.append(f"- {b.type.upper()}: {b.description}")

        if hotspots:
            summary_parts.append("\nTop Hotspots:")
            for h in hotspots[:3]:
                summary_parts.append(f"- {h.function} ({h.cpu_time:.2f}s, {h.memory_allocated/(1024*1024):.1f}MB)")

        # Add SpaceTime recommendation
        if peak_memory > 100 * 1024 * 1024:  # 100MB
            summary_parts.append("\nSpaceTime Optimization Potential: HIGH")
            summary_parts.append("Consider using SpaceTime collections and algorithms for √n memory reduction")

        return "\n".join(summary_parts)
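A minimal sketch of profiling a function with the SpaceTimeProfiler above; the workload is arbitrary and the report fields are those of ProfilingReport:

from sqrtspace_spacetime.profiler import SpaceTimeProfiler

def workload():
    return sum(i * i for i in range(2_000_000))

profiler = SpaceTimeProfiler(sample_interval=0.05)
result, report = profiler.profile(workload)
print(report.summary)        # duration, peak memory, hotspots, bottlenecks
report.save("profile.json")  # full report as JSON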
2
src/sqrtspace_spacetime/py.typed
Normal file
@ -0,0 +1,2 @@
# Marker file for PEP 561
# This package supports type hints
27
src/sqrtspace_spacetime/streams/__init__.py
Normal file
@ -0,0 +1,27 @@
"""Streaming operations with √n memory usage."""

from sqrtspace_spacetime.streams.stream import (
    Stream,
    FileStream,
    CSVStream,
    JSONLStream,
)
from sqrtspace_spacetime.streams.operators import (
    StreamOperator,
    MapOperator,
    FilterOperator,
    FlatMapOperator,
    ChunkOperator,
)

__all__ = [
    "Stream",
    "FileStream",
    "CSVStream",
    "JSONLStream",
    "StreamOperator",
    "MapOperator",
    "FilterOperator",
    "FlatMapOperator",
    "ChunkOperator",
]
169
src/sqrtspace_spacetime/streams/operators.py
Normal file
@ -0,0 +1,169 @@
"""
Stream operators for transformation.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Iterator, List, TypeVar, Optional

T = TypeVar('T')
U = TypeVar('U')


class StreamOperator(ABC):
    """Base class for stream operators."""

    @abstractmethod
    def apply(self, iterator: Iterator[T]) -> Iterator[Any]:
        """Apply operator to iterator."""
        pass


class MapOperator(StreamOperator):
    """Map each element to a new value."""

    def __init__(self, func: Callable[[T], U]):
        self.func = func

    def apply(self, iterator: Iterator[T]) -> Iterator[U]:
        for item in iterator:
            yield self.func(item)


class FilterOperator(StreamOperator):
    """Filter elements by predicate."""

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        for item in iterator:
            if self.predicate(item):
                yield item


class FlatMapOperator(StreamOperator):
    """Map each element to multiple elements."""

    def __init__(self, func: Callable[[T], Iterable[U]]):
        self.func = func

    def apply(self, iterator: Iterator[T]) -> Iterator[U]:
        for item in iterator:
            result = self.func(item)
            if hasattr(result, '__iter__'):
                yield from result
            else:
                yield result


class ChunkOperator(StreamOperator):
    """Group elements into fixed-size chunks."""

    def __init__(self, size: int):
        self.size = max(1, size)

    def apply(self, iterator: Iterator[T]) -> Iterator[List[T]]:
        chunk = []

        for item in iterator:
            chunk.append(item)

            if len(chunk) >= self.size:
                yield chunk
                chunk = []

        # Don't forget the last chunk
        if chunk:
            yield chunk


class WindowOperator(StreamOperator):
    """Sliding window over stream."""

    def __init__(self, size: int, slide: int = 1):
        self.size = max(1, size)
        self.slide = max(1, slide)

    def apply(self, iterator: Iterator[T]) -> Iterator[List[T]]:
        window = []

        for item in iterator:
            window.append(item)

            if len(window) >= self.size:
                yield window.copy()

                # Slide window
                for _ in range(min(self.slide, len(window))):
                    window.pop(0)


class TakeWhileOperator(StreamOperator):
    """Take elements while predicate is true."""

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        for item in iterator:
            if self.predicate(item):
                yield item
            else:
                break


class DropWhileOperator(StreamOperator):
    """Drop elements while predicate is true."""

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        # Track state locally so the operator can be reused across iterations
        dropping = True
        for item in iterator:
            if dropping and self.predicate(item):
                continue
            dropping = False
            yield item


class DistinctOperator(StreamOperator):
    """Remove duplicate elements."""

    def __init__(self, key_func: Optional[Callable[[T], Any]] = None):
        self.key_func = key_func or (lambda x: x)

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        seen = set()

        for item in iterator:
            key = self.key_func(item)
            if key not in seen:
                seen.add(key)
                yield item


class TakeOperator(StreamOperator):
    """Take first n elements."""

    def __init__(self, n: int):
        self.n = n

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        for i, item in enumerate(iterator):
            if i >= self.n:
                break
            yield item


class SkipOperator(StreamOperator):
    """Skip first n elements."""

    def __init__(self, n: int):
        self.n = n

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        for i, item in enumerate(iterator):
            if i >= self.n:
                yield item
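A small sketch composing the operators above by hand; the Stream class in the next file wires them together, so direct composition like this is mainly useful for testing:

from sqrtspace_spacetime.streams.operators import (
    MapOperator, FilterOperator, ChunkOperator,
)

pipeline = [
    FilterOperator(lambda x: x % 2 == 0),  # keep evens
    MapOperator(lambda x: x * x),          # square them
    ChunkOperator(4),                      # batch into fours
]

it = iter(range(20))
for op in pipeline:
    it = op.apply(it)
print(list(it))  # [[0, 4, 16, 36], [64, 100, 144, 196], [256, 324]]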
298
src/sqrtspace_spacetime/streams/stream.py
Normal file
@ -0,0 +1,298 @@
|
||||
"""
|
||||
Memory-efficient streaming operations.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterable, Iterator, List, Optional,
|
||||
TypeVar, Union, AsyncIterator, Tuple
|
||||
)
|
||||
|
||||
from sqrtspace_spacetime.config import config
|
||||
from sqrtspace_spacetime.streams.operators import (
|
||||
MapOperator, FilterOperator, FlatMapOperator, ChunkOperator,
|
||||
TakeOperator, SkipOperator
|
||||
)
|
||||
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
|
||||
|
||||
class Stream(Iterable[T]):
|
||||
"""
|
||||
A lazy, memory-efficient stream for processing large datasets.
|
||||
"""
|
||||
|
||||
def __init__(self, source: Union[Iterable[T], Iterator[T], Callable[[], Iterator[T]]]):
|
||||
"""
|
||||
Initialize stream.
|
||||
|
||||
Args:
|
||||
source: Data source (iterable, iterator, or callable returning iterator)
|
||||
"""
|
||||
if callable(source):
|
||||
self._source = source
|
||||
elif hasattr(source, '__iter__'):
|
||||
self._source = lambda: iter(source)
|
||||
else:
|
||||
raise TypeError("Source must be iterable or callable")
|
||||
|
||||
self._operators: List[Any] = []
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
"""Create iterator with all operators applied."""
|
||||
iterator = self._source()
|
||||
|
||||
# Apply operators in sequence
|
||||
for op in self._operators:
|
||||
iterator = op.apply(iterator)
|
||||
|
||||
return iterator
|
||||
|
||||
    # Transformation operators

    def map(self, func: Callable[[T], U]) -> 'Stream[U]':
        """Apply function to each element."""
        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(MapOperator(func))
        return new_stream

    def filter(self, predicate: Callable[[T], bool]) -> 'Stream[T]':
        """Keep only elements matching predicate."""
        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(FilterOperator(predicate))
        return new_stream

    def flat_map(self, func: Callable[[T], Iterable[U]]) -> 'Stream[U]':
        """Map each element to multiple elements."""
        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(FlatMapOperator(func))
        return new_stream

    def chunk(self, size: Optional[int] = None) -> 'Stream[List[T]]':
        """Group elements into chunks."""
        if size is None:
            # √n chunking needs the total size, which is unknown for a
            # lazy stream, so fall back to a reasonable fixed default.
            size = 1000

        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(ChunkOperator(size))
        return new_stream

    def take(self, n: int) -> 'Stream[T]':
        """Take first n elements."""
        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(TakeOperator(n))
        return new_stream

    def skip(self, n: int) -> 'Stream[T]':
        """Skip first n elements."""
        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(SkipOperator(n))
        return new_stream

    def distinct(self) -> 'Stream[T]':
        """Remove duplicate elements.

        Note: keeps every unique element seen so far in memory, so this
        is only memory-efficient when the number of distinct values is small.
        """
        def distinct_op(iterator):
            seen = set()
            for item in iterator:
                if item not in seen:
                    seen.add(item)
                    yield item

        new_stream = Stream(self._source)
        new_stream._operators = self._operators.copy()
        new_stream._operators.append(distinct_op)
        return new_stream
    # Terminal operators

    def collect(self) -> List[T]:
        """Collect all elements into a list."""
        return list(self)

    def reduce(self, func: Callable[[U, T], U], initial: U) -> U:
        """Reduce stream to a single value."""
        result = initial
        for item in self:
            result = func(result, item)
        return result

    def count(self) -> int:
        """Count elements."""
        return sum(1 for _ in self)

    def first(self) -> Optional[T]:
        """Get first element, or None if the stream is empty."""
        for item in self:
            return item
        return None

    def foreach(self, func: Callable[[T], None]) -> None:
        """Apply function to each element."""
        for item in self:
            func(item)

    def group_by(self, key_func: Callable[[T], Any]) -> Dict[Any, List[T]]:
        """Group elements by key."""
        from sqrtspace_spacetime.algorithms import external_groupby
        return external_groupby(self, key_func)

    def sort(self, key: Optional[Callable[[T], Any]] = None, reverse: bool = False) -> List[T]:
        """Sort elements."""
        from sqrtspace_spacetime.algorithms import external_sort_key, external_sort

        if key:
            return external_sort_key(self, key=key, reverse=reverse)
        else:
            return external_sort(self, reverse=reverse)
    def to_file(self, path: Union[str, Path], mode: str = 'w') -> None:
        """Write stream to file, one element per line."""
        path = Path(path)

        with open(path, mode) as f:
            for item in self:
                f.write(str(item) + '\n')

    def to_csv(self, path: Union[str, Path], headers: Optional[List[str]] = None) -> None:
        """Write stream to CSV file."""
        path = Path(path)

        with open(path, 'w', newline='') as f:
            writer = None

            for item in self:
                if writer is None:
                    # Initialize writer based on the first item
                    if isinstance(item, dict):
                        writer = csv.DictWriter(f, fieldnames=headers or list(item.keys()))
                        writer.writeheader()
                    else:
                        writer = csv.writer(f)
                        if headers:
                            writer.writerow(headers)

                if isinstance(item, dict):
                    writer.writerow(item)
                elif isinstance(item, (list, tuple)):
                    writer.writerow(item)
                else:
                    writer.writerow([item])

    def to_jsonl(self, path: Union[str, Path]) -> None:
        """Write stream to JSON Lines file."""
        path = Path(path)

        with open(path, 'w') as f:
            for item in self:
                f.write(json.dumps(item) + '\n')
    # Async support

    async def async_foreach(self, func: Callable[[T], Any]) -> None:
        """Apply a (possibly async) function to each element."""
        for item in self:
            if asyncio.iscoroutinefunction(func):
                await func(item)
            else:
                func(item)

    # Factory methods

    @classmethod
    def from_iterable(cls, iterable: Iterable[T]) -> 'Stream[T]':
        """Create stream from iterable."""
        return cls(iterable)

    @classmethod
    def from_file(cls, path: Union[str, Path], mode: str = 'r') -> 'Stream[str]':
        """Create stream from file."""
        return FileStream(path, mode)

    @classmethod
    def from_csv(cls, path: Union[str, Path], headers: bool = True, **kwargs) -> 'Stream[Dict[str, Any]]':
        """Create stream from CSV file."""
        return CSVStream(path, headers=headers, **kwargs)

    @classmethod
    def from_jsonl(cls, path: Union[str, Path]) -> 'Stream[Any]':
        """Create stream from JSON Lines file."""
        return JSONLStream(path)

    @classmethod
    def range(cls, *args) -> 'Stream[int]':
        """Create stream of integers (same arguments as built-in range)."""
        return cls(lambda: iter(range(*args)))

    @classmethod
    def infinite(cls, func: Callable[[], T]) -> 'Stream[T]':
        """Create an infinite stream by calling func repeatedly."""
        def generator():
            while True:
                yield func()
        return cls(generator)
class FileStream(Stream[str]):
    """Stream lines from a file."""

    def __init__(self, path: Union[str, Path], mode: str = 'r', encoding: str = 'utf-8'):
        self.path = Path(path)
        self.mode = mode
        self.encoding = encoding

        def file_iterator():
            with open(self.path, self.mode, encoding=self.encoding) as f:
                for line in f:
                    yield line.rstrip('\n\r')

        super().__init__(file_iterator)


class CSVStream(Stream[Dict[str, Any]]):
    """Stream rows from CSV file."""

    def __init__(self, path: Union[str, Path], headers: bool = True, **csv_kwargs):
        self.path = Path(path)
        self.headers = headers
        self.csv_kwargs = csv_kwargs

        def csv_iterator():
            with open(self.path, 'r', newline='') as f:
                if self.headers:
                    reader = csv.DictReader(f, **self.csv_kwargs)
                else:
                    reader = csv.reader(f, **self.csv_kwargs)

                for row in reader:
                    yield row

        super().__init__(csv_iterator)


class JSONLStream(Stream[Any]):
    """Stream objects from JSON Lines file."""

    def __init__(self, path: Union[str, Path]):
        self.path = Path(path)

        def jsonl_iterator():
            with open(self.path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        yield json.loads(line)

        super().__init__(jsonl_iterator)
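For reference, a minimal usage sketch of the Stream API defined above; the data file names are hypothetical:

# Build a lazy pipeline; nothing runs until a terminal operator is called.
from sqrtspace_spacetime.streams.stream import Stream

squares = (
    Stream.range(1, 11)
    .map(lambda x: x * x)
    .filter(lambda x: x % 2 == 0)
    .collect()
)
print(squares)  # [4, 16, 36, 64, 100]

# Stream a JSONL file, keep high-scoring events, write them out as CSV.
Stream.from_jsonl("events.jsonl") \
    .filter(lambda e: e.get("score", 0) > 0.9) \
    .to_csv("top_events.csv")

Each intermediate call returns a new Stream with a copied operator list, so partially built pipelines can be shared and re-iterated without mutating each other.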
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
# Ubiquity SpaceTime Test Suite
234
tests/test_external_algorithms.py
Normal file
@@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Tests for external algorithms with memory pressure.
"""

import unittest
import random
import gc
import psutil
import time
from sqrtspace_spacetime import external_sort, external_groupby, SpaceTimeConfig


class TestExternalAlgorithms(unittest.TestCase):
    """Test external algorithms under memory constraints."""

    def setUp(self):
        """Set up test environment."""
        SpaceTimeConfig.set_defaults(
            memory_limit=100 * 1024 * 1024,  # 100MB limit
            chunk_strategy='sqrt_n'
        )
        self.process = psutil.Process()

    def test_external_sort_small(self):
        """Test external sort with a small dataset."""
        data = [random.randint(1, 1000) for _ in range(1000)]
        sorted_data = external_sort(data)

        # Verify sorting
        self.assertEqual(len(sorted_data), len(data))
        for i in range(len(sorted_data) - 1):
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])

        # Verify all elements present
        self.assertEqual(sorted(data), sorted_data)
    def test_external_sort_large_with_memory_tracking(self):
        """Test external sort with a large dataset and memory tracking."""
        n = 1_000_000  # 1 million items

        # Generate data
        print(f"\nGenerating {n:,} random integers...")
        data = [random.randint(1, 10_000_000) for _ in range(n)]

        # Track memory before sorting
        gc.collect()
        memory_before = self.process.memory_info().rss / 1024 / 1024
        peak_memory = memory_before

        # Sort, then sample memory while verifying the result
        print("Sorting with external_sort...")
        start_time = time.time()

        memory_samples = []

        def monitor_memory():
            current = self.process.memory_info().rss / 1024 / 1024
            memory_samples.append(current)
            return current

        # Sort data
        sorted_data = external_sort(data)

        # Measure final state
        gc.collect()
        memory_after = self.process.memory_info().rss / 1024 / 1024
        elapsed = time.time() - start_time

        # Sample memory during verification
        for i in range(0, len(sorted_data) - 1, 10000):
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])
            if i % 100000 == 0:
                peak_memory = max(peak_memory, monitor_memory())

        # Calculate statistics
        memory_increase = memory_after - memory_before
        theoretical_sqrt_n = int(n ** 0.5)

        print(f"\nExternal Sort Statistics:")
        print(f"  Items sorted:     {n:,}")
        print(f"  Time taken:       {elapsed:.2f} seconds")
        print(f"  Memory before:    {memory_before:.1f} MB")
        print(f"  Memory after:     {memory_after:.1f} MB")
        print(f"  Peak memory:      {peak_memory:.1f} MB")
        print(f"  Memory increase:  {memory_increase:.1f} MB")
        print(f"  Theoretical √n:   {theoretical_sqrt_n:,} items")
        print(f"  Items per MB:     {n / max(memory_increase, 0.1):,.0f}")

        # Verify memory efficiency.
        # With 1M items, √n = 1000, so the working set should be far
        # smaller than the full dataset.
        self.assertLess(memory_increase, 50, f"Memory increase {memory_increase:.1f} MB is too high")

        # Verify correctness on a random sample of adjacent pairs
        sample_indices = random.sample(range(len(sorted_data) - 1), min(1000, len(sorted_data) - 1))
        for i in sample_indices:
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])
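For intuition, the back-of-envelope arithmetic behind the 50 MB bound; the per-item byte count below is an assumption for illustration, not a measurement of the library:

# Illustrative √n estimate for n = 1,000,000 integers.
n = 1_000_000
sqrt_n = int(n ** 0.5)              # 1,000 items resident at a time
bytes_per_int = 28                  # assumed CPython small-int footprint
buffer_kb = sqrt_n * bytes_per_int / 1024
full_mb = n * bytes_per_int / 1024 / 1024
print(f"√n buffer ≈ {buffer_kb:.0f} KB vs full dataset ≈ {full_mb:.0f} MB")
# A √n working set is ~27 KB, three orders of magnitude below the full ~27 MB.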
    def test_external_groupby_memory_efficiency(self):
        """Test external groupby with memory tracking."""
        n = 100_000

        # Generate data with a limited number of groups
        print(f"\nGenerating {n:,} items for groupby...")
        categories = [f"category_{i}" for i in range(100)]
        data = [
            {
                "id": i,
                "category": random.choice(categories),
                "value": random.randint(1, 1000),
                "data": f"data_{i}" * 10  # Make items larger
            }
            for i in range(n)
        ]

        # Track memory
        gc.collect()
        memory_before = self.process.memory_info().rss / 1024 / 1024

        # Group by category
        print("Grouping by category...")
        start_time = time.time()
        grouped = external_groupby(data, key_func=lambda x: x["category"])
        elapsed = time.time() - start_time

        # Measure memory
        gc.collect()
        memory_after = self.process.memory_info().rss / 1024 / 1024
        memory_increase = memory_after - memory_before

        print(f"\nExternal GroupBy Statistics:")
        print(f"  Items grouped:    {n:,}")
        print(f"  Groups created:   {len(grouped)}")
        print(f"  Time taken:       {elapsed:.2f} seconds")
        print(f"  Memory increase:  {memory_increase:.1f} MB")
        print(f"  Items per MB:     {n / max(memory_increase, 0.1):,.0f}")

        # Verify correctness
        self.assertEqual(len(grouped), len(categories))
        total_items = sum(len(group) for group in grouped.values())
        self.assertEqual(total_items, n)

        # Verify grouping
        for category, items in grouped.items():
            for item in items[:10]:  # Check first 10 items in each group
                self.assertEqual(item["category"], category)

        # Memory should be reasonable
        self.assertLess(memory_increase, 100, f"Memory increase {memory_increase:.1f} MB is too high")
    def test_stress_test_combined_operations(self):
        """Stress test with combined operations."""
        n = 50_000

        print(f"\nRunning stress test with {n:,} items...")

        # Generate complex data
        data = []
        for i in range(n):
            data.append({
                "id": i,
                "group": f"group_{i % 50}",
                "value": random.randint(1, 1000),
                "score": random.random(),
                "text": f"This is item {i} with some text" * 5
            })

        # Track initial memory
        gc.collect()
        initial_memory = self.process.memory_info().rss / 1024 / 1024

        # Operation 1: Group by
        print("  1. Grouping data...")
        grouped = external_groupby(data, key_func=lambda x: x["group"])

        # Operation 2: Sort each group
        print("  2. Sorting each group...")
        for group_key, group_items in grouped.items():
            # Sort by value
            sorted_items = external_sort(
                group_items,
                key=lambda x: x["value"]
            )
            grouped[group_key] = sorted_items

        # Operation 3: Extract top items from each group
        print("  3. Extracting top items...")
        top_items = []
        for group_items in grouped.values():
            # Get top 10 by value
            top_items.extend(group_items[-10:])

        # Operation 4: Final sort
        print("  4. Final sort of top items...")
        final_sorted = external_sort(
            top_items,
            key=lambda x: x["score"],
            reverse=True
        )

        # Measure final memory
        gc.collect()
        final_memory = self.process.memory_info().rss / 1024 / 1024
        total_memory_increase = final_memory - initial_memory

        print(f"\nStress Test Results:")
        print(f"  Initial memory:      {initial_memory:.1f} MB")
        print(f"  Final memory:        {final_memory:.1f} MB")
        print(f"  Total increase:      {total_memory_increase:.1f} MB")
        print(f"  Groups processed:    {len(grouped)}")
        print(f"  Top items selected:  {len(top_items)}")

        # Verify results
        self.assertEqual(len(grouped), 50)  # 50 groups
        self.assertEqual(len(top_items), 50 * 10)  # Top 10 from each
        self.assertEqual(len(final_sorted), len(top_items))

        # Verify sorting
        for i in range(len(final_sorted) - 1):
            self.assertGreaterEqual(
                final_sorted[i]["score"],
                final_sorted[i + 1]["score"]
            )

        # Memory should still be reasonable after all operations
        self.assertLess(
            total_memory_increase,
            150,
            f"Memory increase {total_memory_increase:.1f} MB is too high"
        )


if __name__ == "__main__":
    unittest.main()
309
tests/test_memory_pressure.py
Normal file
@@ -0,0 +1,309 @@
#!/usr/bin/env python3
"""
Memory pressure tests to verify √n behavior under constrained memory.
"""

import unittest
import gc
import os
import psutil
import resource
import tempfile
import shutil
import random
import time
from sqrtspace_spacetime import (
    SpaceTimeArray, SpaceTimeDict, external_sort,
    external_groupby, SpaceTimeConfig
)


class TestMemoryPressure(unittest.TestCase):
    """Test √n memory behavior under real memory constraints."""

    def setUp(self):
        """Set up test environment."""
        self.temp_dir = tempfile.mkdtemp()
        self.process = psutil.Process()

        # Configure strict memory limits
        SpaceTimeConfig.set_defaults(
            storage_path=self.temp_dir,
            memory_limit=50 * 1024 * 1024,  # 50MB limit
            chunk_strategy='sqrt_n',
            compression='gzip'
        )

    def tearDown(self):
        """Clean up test environment."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    def test_array_under_memory_pressure(self):
        """Test SpaceTimeArray behavior when memory is constrained."""
        print("\n=== Testing SpaceTimeArray under memory pressure ===")

        # Create large objects that will force spillover
        large_object_size = 1024  # 1KB per object
        n_objects = 100_000  # Total: ~100MB if all in memory

        array = SpaceTimeArray(threshold='auto')

        # Track metrics
        spillovers = 0
        max_memory = 0
        start_time = time.time()

        # Add objects and monitor memory
        for i in range(n_objects):
            # Create a large object
            obj = {
                'id': i,
                'data': 'x' * large_object_size,
                'timestamp': time.time()
            }
            array.append(obj)

            # Monitor every 1000 items
            if i % 1000 == 0:
                gc.collect()
                current_memory = self.process.memory_info().rss / 1024 / 1024
                max_memory = max(max_memory, current_memory)

                if i > 0:
                    hot_count = len(array._hot_data)
                    cold_count = len(array._cold_indices)
                    print(f"  Items: {i:,} | Memory: {current_memory:.1f}MB | "
                          f"Hot: {hot_count} | Cold: {cold_count}")

                    # Check if spillover is happening
                    if cold_count > spillovers:
                        spillovers = cold_count

        elapsed = time.time() - start_time

        # Verify all data is accessible
        print("\nVerifying data accessibility...")
        sample_indices = random.sample(range(n_objects), min(100, n_objects))
        for idx in sample_indices:
            obj = array[idx]
            self.assertEqual(obj['id'], idx)
            self.assertEqual(len(obj['data']), large_object_size)

        # Calculate statistics
        theoretical_sqrt_n = int(n_objects ** 0.5)
        actual_hot_items = len(array._hot_data)

        print(f"\nResults:")
        print(f"  Total items:        {n_objects:,}")
        print(f"  Time taken:         {elapsed:.2f} seconds")
        print(f"  Max memory used:    {max_memory:.1f} MB")
        print(f"  Theoretical √n:     {theoretical_sqrt_n:,}")
        print(f"  Actual hot items:   {actual_hot_items:,}")
        print(f"  Cold items:         {len(array._cold_indices):,}")
        print(f"  Memory efficiency:  {n_objects / max_memory:.0f} items/MB")

        # Assertions
        self.assertEqual(len(array), n_objects)
        self.assertLess(max_memory, 150)  # Total RSS stays well below baseline + the full ~100MB dataset
        self.assertGreater(spillovers, 0)  # Should have spilled to disk
        self.assertLessEqual(actual_hot_items, theoretical_sqrt_n * 2)  # Within 2x of √n
    def test_dict_with_memory_limit(self):
        """Test SpaceTimeDict with a strict memory limit."""
        print("\n=== Testing SpaceTimeDict under memory pressure ===")

        # Create dictionary with explicit threshold
        cache = SpaceTimeDict(threshold=1000)  # Keep only 1000 items in memory

        n_items = 50_000
        value_size = 500  # 500 bytes per value

        # Track evictions
        evictions = 0
        start_time = time.time()

        # Add items
        for i in range(n_items):
            key = f"key_{i:06d}"
            value = {
                'id': i,
                'data': 'v' * value_size,
                'accessed': 0
            }
            cache[key] = value

            # Check for evictions
            if i % 1000 == 0 and i > 0:
                current_hot = len(cache._hot_data)
                current_cold = len(cache._cold_keys)
                if current_cold > evictions:
                    evictions = current_cold
                print(f"  Items: {i:,} | Hot: {current_hot} | Cold: {current_cold}")

        elapsed = time.time() - start_time

        # Test access patterns (LRU behavior)
        print("\nTesting LRU behavior...")
        # Access some old items
        for i in range(0, 100, 10):
            key = f"key_{i:06d}"
            value = cache[key]
            value['accessed'] += 1

        # Add more items to trigger eviction
        for i in range(n_items, n_items + 1000):
            cache[f"key_{i:06d}"] = {'id': i, 'data': 'x' * value_size}

        # Recently used items should still be hot
        stats = cache.get_stats()

        print(f"\nResults:")
        print(f"  Total items:  {len(cache):,}")
        print(f"  Time taken:   {elapsed:.2f} seconds")
        print(f"  Hot items:    {len(cache._hot_data)}")
        print(f"  Cold items:   {len(cache._cold_keys)}")
        print(f"  Stats:        {stats}")

        # Verify all items accessible
        sample_keys = random.sample([f"key_{i:06d}" for i in range(n_items)], 100)
        for key in sample_keys:
            self.assertIn(key, cache)
            value = cache[key]
            self.assertIsNotNone(value)
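The LRU behavior this test exercises can be pictured with a small standalone model. This is a conceptual sketch using only the standard library, not SpaceTimeDict's actual implementation; the class and attribute names are made up:

from collections import OrderedDict

class TinyLRU:
    """Illustrative threshold-based LRU cache (hypothetical names)."""
    def __init__(self, threshold: int):
        self.threshold = threshold
        self._hot = OrderedDict()   # in-memory items, oldest first
        self._cold = {}             # stands in for the on-disk store

    def __setitem__(self, key, value):
        self._hot[key] = value
        self._hot.move_to_end(key)
        while len(self._hot) > self.threshold:
            old_key, old_val = self._hot.popitem(last=False)
            self._cold[old_key] = old_val   # "spill" the coldest entry

    def __getitem__(self, key):
        if key in self._hot:
            self._hot.move_to_end(key)      # refresh recency on access
            return self._hot[key]
        value = self._cold.pop(key)          # cold hit: promote back to hot
        self[key] = value
        return value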
    def test_algorithm_memory_scaling(self):
        """Test that algorithms scale with √n memory usage."""
        print("\n=== Testing algorithm memory scaling ===")

        datasets = [10_000, 40_000, 90_000, 160_000]  # n, 4n, 9n, 16n
        results = []

        for n in datasets:
            print(f"\nTesting with n = {n:,}")

            # Generate data
            data = [random.randint(1, 1_000_000) for _ in range(n)]

            # Measure memory for sorting
            gc.collect()
            mem_before = self.process.memory_info().rss / 1024 / 1024

            sorted_data = external_sort(data)

            gc.collect()
            mem_after = self.process.memory_info().rss / 1024 / 1024
            mem_used = mem_after - mem_before

            # Verify correctness
            self.assertEqual(len(sorted_data), n)
            for i in range(min(1000, len(sorted_data) - 1)):
                self.assertLessEqual(sorted_data[i], sorted_data[i + 1])

            sqrt_n = int(n ** 0.5)
            results.append({
                'n': n,
                'sqrt_n': sqrt_n,
                'memory_used': mem_used,
                'ratio': mem_used / max(sqrt_n * 8 / 1024 / 1024, 0.001)  # assuming ~8 bytes per int
            })

            print(f"  √n = {sqrt_n:,}")
            print(f"  Memory used: {mem_used:.2f} MB")
            print(f"  Ratio to theoretical: {results[-1]['ratio']:.2f}x")

        # Verify √n scaling
        print("\nScaling Analysis:")
        print("n        | √n      | Memory (MB) | Ratio")
        print("---------|---------|-------------|-------")
        for r in results:
            print(f"{r['n']:8,} | {r['sqrt_n']:7,} | {r['memory_used']:11.2f} | {r['ratio']:6.2f}x")

        # Memory should scale roughly with √n:
        # as n increases 4x, memory should increase ~2x.
        for i in range(1, len(results)):
            n_ratio = results[i]['n'] / results[i-1]['n']
            mem_ratio = results[i]['memory_used'] / max(results[i-1]['memory_used'], 0.1)
            expected_ratio = n_ratio ** 0.5

            print(f"\nn increased {n_ratio:.1f}x, memory increased {mem_ratio:.1f}x "
                  f"(expected ~{expected_ratio:.1f}x)")

            # Allow some variance due to overheads
            self.assertLess(mem_ratio, expected_ratio * 3,
                            f"Memory scaling worse than √n: {mem_ratio:.1f}x vs {expected_ratio:.1f}x")
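The expected ratios in the assertion come straight from the square root; a quick idealized check, assuming memory is exactly proportional to √n:

sizes = [10_000, 40_000, 90_000, 160_000]
base = sizes[0] ** 0.5
for n in sizes:
    print(f"n = {n:>7,}  ->  relative memory ≈ {n ** 0.5 / base:.1f}x")
# Quadrupling n doubles the expected memory: 1x, 2x, 3x, 4x.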
    def test_concurrent_memory_pressure(self):
        """Test behavior under concurrent access with memory pressure."""
        print("\n=== Testing concurrent access under memory pressure ===")

        import threading
        import queue

        array = SpaceTimeArray(threshold=500)
        errors = queue.Queue()
        n_threads = 4
        items_per_thread = 25_000

        def worker(thread_id, start_idx):
            try:
                for i in range(items_per_thread):
                    item = {
                        'thread': thread_id,
                        'index': start_idx + i,
                        'data': f"thread_{thread_id}_item_{i}" * 50
                    }
                    array.append(item)

                    # Occasionally read random items
                    if i % 100 == 0 and len(array) > 10:
                        idx = random.randint(0, len(array) - 1)
                        _ = array[idx]
            except Exception as e:
                errors.put((thread_id, str(e)))

        # Start threads
        threads = []
        start_time = time.time()

        for i in range(n_threads):
            t = threading.Thread(
                target=worker,
                args=(i, i * items_per_thread)
            )
            threads.append(t)
            t.start()

        # Monitor memory while threads run
        max_memory = 0
        while any(t.is_alive() for t in threads):
            current_memory = self.process.memory_info().rss / 1024 / 1024
            max_memory = max(max_memory, current_memory)
            time.sleep(0.1)

        # Wait for completion
        for t in threads:
            t.join()

        elapsed = time.time() - start_time

        # Check for errors
        error_list = []
        while not errors.empty():
            error_list.append(errors.get())

        print(f"\nResults:")
        print(f"  Threads:           {n_threads}")
        print(f"  Total items:       {n_threads * items_per_thread:,}")
        print(f"  Time taken:        {elapsed:.2f} seconds")
        print(f"  Max memory:        {max_memory:.1f} MB")
        print(f"  Errors:            {len(error_list)}")
        print(f"  Final array size:  {len(array):,}")

        # Assertions
        self.assertEqual(len(error_list), 0, f"Thread errors: {error_list}")
        self.assertEqual(len(array), n_threads * items_per_thread)
        self.assertLess(max_memory, 200)  # Should handle memory pressure


if __name__ == "__main__":
    unittest.main()
202
tests/test_spacetime_array.py
Normal file
@@ -0,0 +1,202 @@
#!/usr/bin/env python3
"""
Tests for SpaceTimeArray with memory pressure simulation.
"""

import unittest
import tempfile
import shutil
import os
import gc
import psutil
from sqrtspace_spacetime import SpaceTimeArray, SpaceTimeConfig


class TestSpaceTimeArray(unittest.TestCase):
    """Test SpaceTimeArray functionality."""

    def setUp(self):
        """Set up test environment."""
        self.temp_dir = tempfile.mkdtemp()
        SpaceTimeConfig.set_defaults(
            storage_path=self.temp_dir,
            memory_limit=50 * 1024 * 1024,  # 50MB for testing
            chunk_strategy='sqrt_n'
        )

    def tearDown(self):
        """Clean up test environment."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    def test_basic_operations(self):
        """Test basic array operations."""
        array = SpaceTimeArray(threshold=100)

        # Test append
        for i in range(50):
            array.append(f"item_{i}")

        self.assertEqual(len(array), 50)
        self.assertEqual(array[0], "item_0")
        self.assertEqual(array[49], "item_49")

        # Test negative indexing
        self.assertEqual(array[-1], "item_49")
        self.assertEqual(array[-50], "item_0")

        # Test slice
        slice_result = array[10:20]
        self.assertEqual(len(slice_result), 10)
        self.assertEqual(slice_result[0], "item_10")
    def test_automatic_spillover(self):
        """Test automatic spillover to disk."""
        # Create array with a small threshold
        array = SpaceTimeArray(threshold=10)

        # Add more items than the threshold
        for i in range(100):
            array.append(f"value_{i}")

        # Check that spillover happened
        self.assertEqual(len(array), 100)
        self.assertGreater(len(array._cold_indices), 0)
        self.assertLessEqual(len(array._hot_data), array.threshold)

        # Verify all items are accessible
        for i in range(100):
            self.assertEqual(array[i], f"value_{i}")
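The spillover contract being tested here (at most `threshold` hot items, everything still addressable) can be modeled in a few lines. This is a conceptual sketch, not SpaceTimeArray's code; the JSONL cold-store format and linear cold scan are assumptions for clarity:

import json, os, tempfile

class TinySpillArray:
    """Conceptual append-only spillover model (hypothetical names)."""
    def __init__(self, threshold: int):
        self.threshold = threshold
        self._hot = []          # tail of the array, kept in memory
        fd, self._cold_path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        self._cold_count = 0    # items already written to disk

    def append(self, item):
        self._hot.append(item)
        if len(self._hot) > self.threshold:
            # Spill the oldest hot item; cold file line j = global index j.
            with open(self._cold_path, "a") as f:
                f.write(json.dumps(self._hot.pop(0)) + "\n")
            self._cold_count += 1

    def __len__(self):
        return self._cold_count + len(self._hot)

    def __getitem__(self, i):
        if i >= self._cold_count:            # still hot
            return self._hot[i - self._cold_count]
        with open(self._cold_path) as f:     # cold: linear scan for clarity
            for line_no, line in enumerate(f):
                if line_no == i:
                    return json.loads(line)
        raise IndexError(i)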
    def test_memory_pressure_handling(self):
        """Test behavior under memory pressure."""
        # Create array with auto threshold
        array = SpaceTimeArray()

        # Generate large data items
        large_item = "x" * 10000  # 10KB string

        # Add items until memory pressure is detected
        for i in range(1000):
            array.append(f"{large_item}_{i}")

            # Check memory usage periodically
            if i % 100 == 0:
                process = psutil.Process()
                memory_mb = process.memory_info().rss / 1024 / 1024
                # Ensure we're not using excessive memory
                self.assertLess(memory_mb, 200, f"Memory usage too high at iteration {i}")

        # Verify all items still accessible
        self.assertEqual(len(array), 1000)
        self.assertTrue(array[0].endswith("_0"))
        self.assertTrue(array[999].endswith("_999"))
    def test_large_dataset_sqrt_n_memory(self):
        """Test √n memory usage with a large dataset."""
        # Configure for sqrt_n strategy
        SpaceTimeConfig.set_defaults(chunk_strategy='sqrt_n')

        n = 10000  # Total items
        sqrt_n = int(n ** 0.5)  # Expected in-memory items

        array = SpaceTimeArray()

        # Track initial memory
        gc.collect()
        process = psutil.Process()
        initial_memory = process.memory_info().rss

        # Add n items
        for i in range(n):
            array.append({"id": i, "data": f"item_{i}" * 10})

        # Force garbage collection
        gc.collect()

        # Check memory usage
        final_memory = process.memory_info().rss
        memory_increase_mb = (final_memory - initial_memory) / 1024 / 1024

        # Verify sqrt_n behavior
        self.assertEqual(len(array), n)
        self.assertLessEqual(len(array._hot_data), sqrt_n * 2)  # Allow some buffer
        self.assertGreater(len(array._cold_indices), n - sqrt_n * 2)

        # Memory should be much less than storing all items.
        # Rough estimate: each item ~100 bytes, so n items ≈ 1MB in total;
        # with only ~√n items resident, roughly 10KB should stay in memory.
        self.assertLess(memory_increase_mb, 10, f"Memory increase {memory_increase_mb}MB is too high")

        # Verify random access still works
        import random
        for _ in range(100):
            idx = random.randint(0, n - 1)
            self.assertEqual(array[idx]["id"], idx)
    def test_persistence_across_sessions(self):
        """Test data persistence when the array is recreated."""
        storage_path = os.path.join(self.temp_dir, "persist_test")

        # Create and populate array
        array1 = SpaceTimeArray(threshold=10, storage_path=storage_path)
        for i in range(50):
            array1.append(f"persistent_{i}")

        # Force spillover
        array1._check_and_spill()
        del array1

        # Create new array with the same storage path
        array2 = SpaceTimeArray(threshold=10, storage_path=storage_path)

        # Data should be accessible
        self.assertEqual(len(array2), 50)
        for i in range(50):
            self.assertEqual(array2[i], f"persistent_{i}")
    def test_concurrent_access(self):
        """Test thread-safe access to the array."""
        import threading

        array = SpaceTimeArray(threshold=100)
        errors = []

        def writer(start, count):
            try:
                for i in range(start, start + count):
                    array.append(f"thread_{i}")
            except Exception as e:
                errors.append(e)

        def reader(count):
            try:
                for _ in range(count):
                    if len(array) > 0:
                        _ = array[0]  # Just access, don't verify
            except Exception as e:
                errors.append(e)

        # Create threads
        threads = []
        for i in range(5):
            t = threading.Thread(target=writer, args=(i * 100, 100))
            threads.append(t)

        for i in range(3):
            t = threading.Thread(target=reader, args=(50,))
            threads.append(t)

        # Run threads
        for t in threads:
            t.start()

        for t in threads:
            t.join()

        # Check for errors
        self.assertEqual(len(errors), 0, f"Thread errors: {errors}")
        self.assertEqual(len(array), 500)


if __name__ == "__main__":
    unittest.main()
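Since every test module ends in a unittest.main() guard, the suite can be run per file or through standard unittest discovery:

python -m unittest discover tests
python tests/test_spacetime_array.py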