Initial
This commit is contained in:
commit
69b521b549
190
LICENSE
Normal file
190
LICENSE
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
Copyright 2024 Ubiquity SpaceTime Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
428
README.md
Normal file
428
README.md
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
# SqrtSpace SpaceTime for Python
|
||||||
|
|
||||||
|
[](https://badge.fury.io/py/sqrtspace-spacetime)
|
||||||
|
[](https://pypi.org/project/sqrtspace-spacetime/)
|
||||||
|
[](https://github.com/sqrtspace/sqrtspace-python/blob/main/LICENSE)
|
||||||
|
[](https://sqrtspace-spacetime.readthedocs.io/en/latest/?badge=latest)
|
||||||
|
|
||||||
|
Memory-efficient algorithms and data structures for Python using Williams' √n space-time tradeoffs.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install sqrtspace-spacetime
|
||||||
|
```
|
||||||
|
|
||||||
|
For ML features:
|
||||||
|
```bash
|
||||||
|
pip install sqrtspace-spacetime[ml]
|
||||||
|
```
|
||||||
|
|
||||||
|
For all features:
|
||||||
|
```bash
|
||||||
|
pip install sqrtspace-spacetime[all]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Core Concepts
|
||||||
|
|
||||||
|
SpaceTime implements theoretical computer science results showing that many algorithms can achieve better memory usage by accepting slightly slower runtime. The key insight is using √n memory instead of n memory, where n is the input size.
|
||||||
|
|
||||||
|
### Key Features
|
||||||
|
|
||||||
|
- **Memory-Efficient Collections**: Arrays and dictionaries that automatically spill to disk
|
||||||
|
- **External Algorithms**: Sort and group large datasets using minimal memory
|
||||||
|
- **Streaming Operations**: Process files larger than RAM with elegant API
|
||||||
|
- **Auto-Checkpointing**: Resume long computations from where they left off
|
||||||
|
- **Memory Profiling**: Identify optimization opportunities in your code
|
||||||
|
- **ML Optimizations**: Reduce neural network training memory by up to 90%
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import SpaceTimeArray, external_sort, Stream
|
||||||
|
|
||||||
|
# Memory-efficient array that spills to disk
|
||||||
|
array = SpaceTimeArray(threshold=10000)
|
||||||
|
for i in range(1000000):
|
||||||
|
array.append(i)
|
||||||
|
|
||||||
|
# Sort large datasets with minimal memory
|
||||||
|
huge_list = list(range(10000000, 0, -1))
|
||||||
|
sorted_data = external_sort(huge_list) # Uses only √n memory
|
||||||
|
|
||||||
|
# Stream processing
|
||||||
|
Stream.from_csv('huge_file.csv') \
|
||||||
|
.filter(lambda row: row['value'] > 100) \
|
||||||
|
.map(lambda row: row['value'] * 1.1) \
|
||||||
|
.group_by(lambda row: row['category']) \
|
||||||
|
.to_csv('processed.csv')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### Basic Examples
|
||||||
|
See [`examples/basic_usage.py`](examples/basic_usage.py) for comprehensive examples of:
|
||||||
|
- SpaceTimeArray and SpaceTimeDict usage
|
||||||
|
- External sorting and grouping
|
||||||
|
- Stream processing
|
||||||
|
- Memory profiling
|
||||||
|
- Auto-checkpointing
|
||||||
|
|
||||||
|
### FastAPI Web Application
|
||||||
|
Check out [`examples/fastapi-app/`](examples/fastapi-app/) for a production-ready web application featuring:
|
||||||
|
- Streaming endpoints for large datasets
|
||||||
|
- Server-Sent Events (SSE) for real-time data
|
||||||
|
- Memory-efficient CSV exports
|
||||||
|
- Checkpointed background tasks
|
||||||
|
- ML model serving with memory constraints
|
||||||
|
|
||||||
|
See the [FastAPI example README](examples/fastapi-app/README.md) for detailed documentation.
|
||||||
|
|
||||||
|
### Machine Learning Pipeline
|
||||||
|
Explore [`examples/ml-pipeline/`](examples/ml-pipeline/) for ML-specific patterns:
|
||||||
|
- Training models on datasets larger than RAM
|
||||||
|
- Memory-efficient feature extraction
|
||||||
|
- Checkpointed training loops
|
||||||
|
- Streaming predictions
|
||||||
|
- Integration with PyTorch and TensorFlow
|
||||||
|
|
||||||
|
See the [ML Pipeline README](examples/ml-pipeline/README.md) for complete documentation.
|
||||||
|
|
||||||
|
### Memory-Efficient Collections
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import SpaceTimeArray, SpaceTimeDict
|
||||||
|
|
||||||
|
# Array that automatically manages memory
|
||||||
|
array = SpaceTimeArray(threshold=1000) # Keep 1000 items in memory
|
||||||
|
for i in range(1000000):
|
||||||
|
array.append(f"item_{i}")
|
||||||
|
|
||||||
|
# Dictionary with LRU eviction to disk
|
||||||
|
cache = SpaceTimeDict(threshold=10000)
|
||||||
|
for key, value in huge_dataset:
|
||||||
|
cache[key] = expensive_computation(value)
|
||||||
|
```
|
||||||
|
|
||||||
|
### External Algorithms
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import external_sort, external_groupby
|
||||||
|
|
||||||
|
# Sort 100M items using only ~10K memory
|
||||||
|
data = list(range(100_000_000, 0, -1))
|
||||||
|
sorted_data = external_sort(data)
|
||||||
|
|
||||||
|
# Group by with aggregation
|
||||||
|
sales = [
|
||||||
|
{'store': 'A', 'amount': 100},
|
||||||
|
{'store': 'B', 'amount': 200},
|
||||||
|
# ... millions more
|
||||||
|
]
|
||||||
|
|
||||||
|
by_store = external_groupby(
|
||||||
|
sales,
|
||||||
|
key_func=lambda x: x['store']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Aggregate with minimal memory
|
||||||
|
from sqrtspace_spacetime.algorithms import groupby_sum
|
||||||
|
totals = groupby_sum(
|
||||||
|
sales,
|
||||||
|
key_func=lambda x: x['store'],
|
||||||
|
value_func=lambda x: x['amount']
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming Operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import Stream
|
||||||
|
|
||||||
|
# Process large files efficiently
|
||||||
|
stream = Stream.from_csv('sales_2023.csv') \
|
||||||
|
.filter(lambda row: row['amount'] > 0) \
|
||||||
|
.map(lambda row: {
|
||||||
|
'month': row['date'][:7],
|
||||||
|
'amount': float(row['amount'])
|
||||||
|
}) \
|
||||||
|
.group_by(lambda row: row['month']) \
|
||||||
|
.to_csv('monthly_summary.csv')
|
||||||
|
|
||||||
|
# Chain operations
|
||||||
|
top_products = Stream.from_jsonl('products.jsonl') \
|
||||||
|
.filter(lambda p: p['in_stock']) \
|
||||||
|
.sort(key=lambda p: p['revenue'], reverse=True) \
|
||||||
|
.take(100) \
|
||||||
|
.collect()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Auto-Checkpointing
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.checkpoint import auto_checkpoint
|
||||||
|
|
||||||
|
@auto_checkpoint(total_iterations=1000000)
|
||||||
|
def process_large_dataset(data):
|
||||||
|
results = []
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
# Process item
|
||||||
|
result = expensive_computation(item)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Yield state for checkpointing
|
||||||
|
yield {'i': i, 'results': results}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Automatically resumes from checkpoint if interrupted
|
||||||
|
results = process_large_dataset(huge_dataset)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory Profiling
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.profiler import profile, profile_memory
|
||||||
|
|
||||||
|
@profile(output_file="profile.json")
|
||||||
|
def my_algorithm(data):
|
||||||
|
# Process data
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Get detailed memory analysis
|
||||||
|
result, report = my_algorithm(data)
|
||||||
|
print(report.summary)
|
||||||
|
|
||||||
|
# Simple memory tracking
|
||||||
|
@profile_memory(threshold_mb=100)
|
||||||
|
def memory_heavy_function():
|
||||||
|
# Alerts if memory usage exceeds threshold
|
||||||
|
large_list = list(range(10000000))
|
||||||
|
return sum(large_list)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ML Memory Optimization
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.ml import MLMemoryOptimizer
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Analyze model memory usage
|
||||||
|
model = nn.Sequential(
|
||||||
|
nn.Linear(784, 256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(128, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = MLMemoryOptimizer()
|
||||||
|
profile = optimizer.analyze_model(model, input_shape=(784,), batch_size=32)
|
||||||
|
|
||||||
|
# Get optimization plan
|
||||||
|
plan = optimizer.optimize(profile, target_batch_size=128)
|
||||||
|
print(plan.explanation)
|
||||||
|
|
||||||
|
# Apply optimizations
|
||||||
|
config = optimizer.get_training_config(plan, profile)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
### Memory Pressure Handling
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.memory import MemoryMonitor, LoggingHandler
|
||||||
|
|
||||||
|
# Monitor memory pressure
|
||||||
|
monitor = MemoryMonitor()
|
||||||
|
monitor.add_handler(LoggingHandler())
|
||||||
|
|
||||||
|
# Your arrays automatically respond to memory pressure
|
||||||
|
array = SpaceTimeArray()
|
||||||
|
# Arrays spill to disk when memory is low
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import SpaceTimeConfig
|
||||||
|
|
||||||
|
# Global configuration
|
||||||
|
SpaceTimeConfig.set_defaults(
|
||||||
|
memory_limit=2 * 1024**3, # 2GB
|
||||||
|
chunk_strategy='sqrt_n',
|
||||||
|
compression='gzip',
|
||||||
|
external_storage_path='/fast/ssd/temp'
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parallel Processing
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.batch import BatchProcessor
|
||||||
|
|
||||||
|
processor = BatchProcessor(
|
||||||
|
memory_threshold=0.8,
|
||||||
|
checkpoint_enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process in memory-efficient batches
|
||||||
|
result = processor.process(
|
||||||
|
huge_list,
|
||||||
|
lambda batch: [transform(item) for item in batch]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Processed {result.get_success_count()} items")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Real-World Examples
|
||||||
|
|
||||||
|
### Processing Large CSV Files
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import Stream
|
||||||
|
from sqrtspace_spacetime.profiler import profile_memory
|
||||||
|
|
||||||
|
@profile_memory(threshold_mb=500)
|
||||||
|
def analyze_sales_data(filename):
|
||||||
|
# Stream process to stay under memory limit
|
||||||
|
return Stream.from_csv(filename) \
|
||||||
|
.filter(lambda row: row['status'] == 'completed') \
|
||||||
|
.map(lambda row: {
|
||||||
|
'product': row['product_id'],
|
||||||
|
'revenue': float(row['price']) * int(row['quantity'])
|
||||||
|
}) \
|
||||||
|
.group_by(lambda row: row['product']) \
|
||||||
|
.sort(key=lambda group: sum(r['revenue'] for r in group[1]), reverse=True) \
|
||||||
|
.take(10) \
|
||||||
|
.collect()
|
||||||
|
|
||||||
|
top_products = analyze_sales_data('sales_2023.csv')
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Large Neural Networks
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime.ml import MLMemoryOptimizer, GradientCheckpointer
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Memory-efficient training
|
||||||
|
def train_large_model(model, train_loader, epochs=10):
|
||||||
|
# Analyze memory requirements
|
||||||
|
optimizer = MLMemoryOptimizer()
|
||||||
|
profile = optimizer.analyze_model(model, input_shape=(3, 224, 224), batch_size=32)
|
||||||
|
|
||||||
|
# Get optimization plan
|
||||||
|
plan = optimizer.optimize(profile, target_batch_size=128)
|
||||||
|
|
||||||
|
# Apply gradient checkpointing
|
||||||
|
checkpointer = GradientCheckpointer()
|
||||||
|
model = checkpointer.apply_checkpointing(model, plan.checkpoint_layers)
|
||||||
|
|
||||||
|
# Train with optimized settings
|
||||||
|
for epoch in range(epochs):
|
||||||
|
for batch in train_loader:
|
||||||
|
# Training loop with automatic memory management
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
### Data Pipeline with Checkpoints
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqrtspace_spacetime import Stream
|
||||||
|
from sqrtspace_spacetime.checkpoint import auto_checkpoint
|
||||||
|
|
||||||
|
@auto_checkpoint(total_iterations=1000000)
|
||||||
|
def process_user_events(event_file):
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
for event in Stream.from_jsonl(event_file):
|
||||||
|
# Complex processing
|
||||||
|
user_profile = enhance_profile(event)
|
||||||
|
recommendations = generate_recommendations(user_profile)
|
||||||
|
|
||||||
|
save_to_database(recommendations)
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
# Checkpoint state
|
||||||
|
yield {'processed': processed, 'last_event': event['id']}
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
# Automatically resumes if interrupted
|
||||||
|
total = process_user_events('events.jsonl')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Benchmarks
|
||||||
|
|
||||||
|
| Operation | Standard Python | SpaceTime | Memory Reduction | Time Overhead |
|
||||||
|
|-----------|----------------|-----------|------------------|---------------|
|
||||||
|
| Sort 10M integers | 400MB | 20MB | 95% | 40% |
|
||||||
|
| Process 1GB CSV | 1GB | 32MB | 97% | 20% |
|
||||||
|
| Group by on 1M rows | 200MB | 14MB | 93% | 30% |
|
||||||
|
| Neural network training | 8GB | 2GB | 75% | 15% |
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Collections
|
||||||
|
- `SpaceTimeArray`: Memory-efficient list with disk spillover
|
||||||
|
- `SpaceTimeDict`: Memory-efficient dictionary with LRU eviction
|
||||||
|
|
||||||
|
### Algorithms
|
||||||
|
- `external_sort()`: Sort large datasets with √n memory
|
||||||
|
- `external_groupby()`: Group large datasets with √n memory
|
||||||
|
- `external_join()`: Join large datasets efficiently
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
- `Stream`: Lazy evaluation stream processing
|
||||||
|
- `FileStream`: Stream lines from files
|
||||||
|
- `CSVStream`: Stream CSV rows
|
||||||
|
- `JSONLStream`: Stream JSON Lines
|
||||||
|
|
||||||
|
### Memory Management
|
||||||
|
- `MemoryMonitor`: Monitor memory pressure
|
||||||
|
- `MemoryPressureHandler`: Custom pressure handlers
|
||||||
|
|
||||||
|
### Checkpointing
|
||||||
|
- `@auto_checkpoint`: Automatic checkpointing decorator
|
||||||
|
- `CheckpointManager`: Manual checkpoint control
|
||||||
|
|
||||||
|
### ML Optimization
|
||||||
|
- `MLMemoryOptimizer`: Analyze and optimize models
|
||||||
|
- `GradientCheckpointer`: Apply gradient checkpointing
|
||||||
|
|
||||||
|
### Profiling
|
||||||
|
- `@profile`: Full profiling decorator
|
||||||
|
- `@profile_memory`: Memory-only profiling
|
||||||
|
- `SpaceTimeProfiler`: Programmatic profiling
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Apache License 2.0. See [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use SpaceTime in your research, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@software{sqrtspace_spacetime,
|
||||||
|
title = {SqrtSpace SpaceTime: Memory-Efficient Python Library},
|
||||||
|
author = {Friedel Jr., David H.},
|
||||||
|
year = {2025},
|
||||||
|
url = {https://github.com/sqrtspace/sqrtspace-python}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Links
|
||||||
|
|
||||||
|
- [Documentation](https://sqrtspace-spacetime.readthedocs.io)
|
||||||
|
- [PyPI Package](https://pypi.org/project/sqrtspace-spacetime/)
|
||||||
|
- [GitHub Repository](https://github.com/sqrtspace/sqrtspace-python)
|
||||||
|
- [Issue Tracker](https://github.com/sqrtspace/sqrtspace-python/issues)
|
||||||
204
examples/basic_usage.py
Normal file
204
examples/basic_usage.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Basic usage examples for SqrtSpace SpaceTime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from sqrtspace_spacetime import (
|
||||||
|
SpaceTimeArray,
|
||||||
|
SpaceTimeDict,
|
||||||
|
external_sort,
|
||||||
|
external_groupby,
|
||||||
|
Stream,
|
||||||
|
SpaceTimeConfig,
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.profiler import profile, profile_memory
|
||||||
|
from sqrtspace_spacetime.checkpoint import auto_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def example_spacetime_array():
|
||||||
|
"""Example: Memory-efficient array with automatic spillover."""
|
||||||
|
print("\n=== SpaceTimeArray Example ===")
|
||||||
|
|
||||||
|
# Create array that keeps only 1000 items in memory
|
||||||
|
array = SpaceTimeArray(threshold=1000)
|
||||||
|
|
||||||
|
# Add 10,000 items
|
||||||
|
print("Adding 10,000 items to SpaceTimeArray...")
|
||||||
|
for i in range(10000):
|
||||||
|
array.append(f"item_{i}")
|
||||||
|
|
||||||
|
print(f"Array length: {len(array)}")
|
||||||
|
print(f"Sample items: {array[0]}, {array[5000]}, {array[9999]}")
|
||||||
|
|
||||||
|
# Demonstrate memory efficiency
|
||||||
|
import psutil
|
||||||
|
process = psutil.Process()
|
||||||
|
memory_mb = process.memory_info().rss / 1024 / 1024
|
||||||
|
print(f"Current memory usage: {memory_mb:.1f} MB (much less than storing all in memory)")
|
||||||
|
|
||||||
|
|
||||||
|
def example_external_sort():
|
||||||
|
"""Example: Sort large dataset with minimal memory."""
|
||||||
|
print("\n=== External Sort Example ===")
|
||||||
|
|
||||||
|
# Generate large random dataset
|
||||||
|
print("Generating 1M random numbers...")
|
||||||
|
data = [random.randint(1, 1000000) for _ in range(1000000)]
|
||||||
|
|
||||||
|
# Sort using √n memory
|
||||||
|
print("Sorting with external_sort (√n memory)...")
|
||||||
|
start = time.time()
|
||||||
|
sorted_data = external_sort(data)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
|
||||||
|
# Verify sorting
|
||||||
|
is_sorted = all(sorted_data[i] <= sorted_data[i+1] for i in range(len(sorted_data)-1))
|
||||||
|
print(f"Sorted correctly: {is_sorted}")
|
||||||
|
print(f"Time taken: {elapsed:.2f}s")
|
||||||
|
print(f"First 10 elements: {sorted_data[:10]}")
|
||||||
|
|
||||||
|
|
||||||
|
def example_streaming():
|
||||||
|
"""Example: Process data streams efficiently."""
|
||||||
|
print("\n=== Stream Processing Example ===")
|
||||||
|
|
||||||
|
# Create sample data
|
||||||
|
data = [
|
||||||
|
{'name': 'Alice', 'age': 25, 'score': 85},
|
||||||
|
{'name': 'Bob', 'age': 30, 'score': 90},
|
||||||
|
{'name': 'Charlie', 'age': 25, 'score': 78},
|
||||||
|
{'name': 'David', 'age': 30, 'score': 92},
|
||||||
|
{'name': 'Eve', 'age': 25, 'score': 88},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Stream processing
|
||||||
|
result = Stream.from_iterable(data) \
|
||||||
|
.filter(lambda x: x['age'] == 25) \
|
||||||
|
.map(lambda x: {'name': x['name'], 'grade': 'A' if x['score'] >= 85 else 'B'}) \
|
||||||
|
.collect()
|
||||||
|
|
||||||
|
print("Filtered and transformed data:")
|
||||||
|
for item in result:
|
||||||
|
print(f" {item}")
|
||||||
|
|
||||||
|
|
||||||
|
@profile_memory(threshold_mb=50)
|
||||||
|
def example_memory_profiling():
|
||||||
|
"""Example: Profile memory usage."""
|
||||||
|
print("\n=== Memory Profiling Example ===")
|
||||||
|
|
||||||
|
# Simulate memory-intensive operation
|
||||||
|
data = []
|
||||||
|
for i in range(100000):
|
||||||
|
data.append({
|
||||||
|
'id': i,
|
||||||
|
'value': random.random(),
|
||||||
|
'text': f"Item number {i}" * 10
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process data
|
||||||
|
result = sum(item['value'] for item in data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@auto_checkpoint(total_iterations=100)
def example_checkpointing(data):
    """Example: Auto-checkpoint long computation.

    Generator function: every 10th iteration it yields a state dict
    ({'i': index, 'results': partial results}) so @auto_checkpoint can
    persist progress and resume after interruption.

    NOTE(review): because this is a generator, the trailing ``return results``
    is delivered via StopIteration and is discarded when the caller drains it
    with ``list(...)`` (as main() does) — confirm against auto_checkpoint's
    contract whether the return value is ever consumed.
    """
    print("\n=== Checkpointing Example ===")

    results = []
    for i, item in enumerate(data):
        # Simulate expensive computation
        time.sleep(0.01)
        result = item ** 2
        results.append(result)

        # Yield state for checkpointing
        if i % 10 == 0:
            print(f"Processing item {i}...")
            yield {'i': i, 'results': results}

    return results
|
def example_groupby():
    """Example: Group large dataset efficiently."""
    print("\n=== External GroupBy Example ===")

    store_names = ['Store_A', 'Store_B', 'Store_C', 'Store_D']

    print("Generating 100K sales records...")
    sales = [
        {
            'store': random.choice(store_names),
            'amount': random.uniform(10, 1000),
            'product': f'Product_{random.randint(1, 100)}',
        }
        for _ in range(100000)
    ]

    print("Grouping by store...")

    def store_of(record):
        return record['store']

    grouped = external_groupby(sales, key_func=store_of)

    # Report per-store transaction counts and revenue totals.
    for store, transactions in grouped.items():
        total = sum(t['amount'] for t in transactions)
        print(f"{store}: {len(transactions)} transactions, ${total:,.2f} total")
|
def example_spacetime_dict():
    """Example: Memory-efficient dictionary with LRU eviction."""
    print("\n=== SpaceTimeDict Example ===")

    # Keep at most 100 items hot in memory; older entries spill to disk.
    cache = SpaceTimeDict(threshold=100)

    print("Caching 1000 expensive computations...")
    for i in range(1000):
        # Fake an expensive computation and store it under a derived key.
        cache[f"computation_{i}"] = i ** 2 + random.random()

    print(f"Total items: {len(cache)}")
    # NOTE: peeks at private members purely for demonstration output.
    print(f"Items in memory: {len(cache._hot_data)}")
    print(f"Items on disk: {len(cache._cold_keys)}")

    # Access patterns
    stats = cache.get_stats()
    print(f"Cache stats: {stats}")
|
def main():
    """Run all examples."""
    print("=== Ubiquity SpaceTime Examples ===")

    # Shared SpaceTime configuration for every demo below.
    SpaceTimeConfig.set_defaults(
        memory_limit=512 * 1024 * 1024,  # 512MB
        chunk_strategy='sqrt_n',
        compression='gzip'
    )

    # Plain demos, executed in order.
    demos = (
        example_spacetime_array,
        example_external_sort,
        example_streaming,
        example_memory_profiling,
        example_groupby,
        example_spacetime_dict,
    )
    for demo in demos:
        demo()

    # The checkpointing demo is a generator, so drain it explicitly.
    data = list(range(100))
    results = list(example_checkpointing(data))
    print(f"Checkpointing completed. Processed {len(results)} items.")

    print("\n=== All examples completed! ===")
|
# Entry point: run every SpaceTime example when executed as a script.
if __name__ == "__main__":
    main()
|
||||||
504
examples/fastapi-app/README.md
Normal file
504
examples/fastapi-app/README.md
Normal file
@ -0,0 +1,504 @@
|
|||||||
|
# SqrtSpace SpaceTime FastAPI Sample Application
|
||||||
|
|
||||||
|
This sample demonstrates how to build memory-efficient, high-performance APIs using FastAPI and SqrtSpace SpaceTime.
|
||||||
|
|
||||||
|
## Features Demonstrated
|
||||||
|
|
||||||
|
### 1. **Streaming Endpoints**
|
||||||
|
- Server-Sent Events (SSE) for real-time data
|
||||||
|
- Streaming file downloads without memory bloat
|
||||||
|
- Chunked JSON responses for large datasets
|
||||||
|
|
||||||
|
### 2. **Background Tasks**
|
||||||
|
- Memory-aware task processing
|
||||||
|
- Checkpointed long-running operations
|
||||||
|
- Progress tracking with resumable state
|
||||||
|
|
||||||
|
### 3. **Data Processing**
|
||||||
|
- External sorting for large datasets
|
||||||
|
- Memory-efficient aggregations
|
||||||
|
- Streaming ETL pipelines
|
||||||
|
|
||||||
|
### 4. **Machine Learning Integration**
|
||||||
|
- Batch prediction with memory limits
|
||||||
|
- Model training with checkpoints
|
||||||
|
- Feature extraction pipelines
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. **Create virtual environment:**
|
||||||
|
```bash
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install dependencies:**
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Configure environment:**
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
Edit `.env`:
|
||||||
|
```
|
||||||
|
SPACETIME_MEMORY_LIMIT=512MB
|
||||||
|
SPACETIME_EXTERNAL_STORAGE=/tmp/spacetime
|
||||||
|
SPACETIME_CHUNK_STRATEGY=sqrt_n
|
||||||
|
SPACETIME_COMPRESSION=gzip
|
||||||
|
DATABASE_URL=sqlite:///./app.db
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Initialize database:**
|
||||||
|
```bash
|
||||||
|
python init_db.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
fastapi-app/
|
||||||
|
├── app/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── main.py # FastAPI app
|
||||||
|
│ ├── config.py # Configuration
|
||||||
|
│ ├── models.py # Pydantic models
|
||||||
|
│ ├── database.py # Database setup
|
||||||
|
│ ├── routers/
|
||||||
|
│ │ ├── products.py # Product endpoints
|
||||||
|
│ │ ├── analytics.py # Analytics endpoints
|
||||||
|
│ │ ├── ml.py # ML endpoints
|
||||||
|
│ │ └── reports.py # Report generation
|
||||||
|
│ ├── services/
|
||||||
|
│ │ ├── product_service.py # Business logic
|
||||||
|
│ │ ├── analytics_service.py # Analytics processing
|
||||||
|
│ │ ├── ml_service.py # ML operations
|
||||||
|
│ │ └── cache_service.py # SpaceTime caching
|
||||||
|
│ ├── workers/
|
||||||
|
│ │ ├── background_tasks.py # Task workers
|
||||||
|
│ │ └── checkpointed_jobs.py # Resumable jobs
|
||||||
|
│ └── utils/
|
||||||
|
│ ├── streaming.py # Streaming helpers
|
||||||
|
│ └── memory.py # Memory monitoring
|
||||||
|
├── requirements.txt
|
||||||
|
├── Dockerfile
|
||||||
|
└── docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### 1. Streaming Large Datasets
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/routers/products.py
|
||||||
|
from fastapi import APIRouter, Response
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqrtspace_spacetime import Stream
|
||||||
|
import json
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/products/stream")
|
||||||
|
async def stream_products(category: str = None):
|
||||||
|
"""Stream products as newline-delimited JSON"""
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
query = db.query(Product)
|
||||||
|
if category:
|
||||||
|
query = query.filter(Product.category == category)
|
||||||
|
|
||||||
|
# Use SpaceTime stream for memory efficiency
|
||||||
|
stream = Stream.from_query(query, chunk_size=100)
|
||||||
|
|
||||||
|
for product in stream:
|
||||||
|
yield json.dumps(product.dict()) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={"X-Accel-Buffering": "no"}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Server-Sent Events for Real-Time Data
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/routers/analytics.py
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from sqrtspace_spacetime.memory import MemoryPressureMonitor
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/analytics/realtime")
|
||||||
|
async def realtime_analytics():
|
||||||
|
"""Stream real-time analytics using SSE"""
|
||||||
|
|
||||||
|
monitor = MemoryPressureMonitor("100MB")
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
while True:
|
||||||
|
# Get current stats
|
||||||
|
stats = await analytics_service.get_current_stats()
|
||||||
|
|
||||||
|
# Check memory pressure
|
||||||
|
if monitor.check() != MemoryPressureLevel.NONE:
|
||||||
|
await analytics_service.compact_cache()
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"event": "update",
|
||||||
|
"data": json.dumps(stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator())
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Memory-Efficient CSV Export
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/routers/reports.py
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqrtspace_spacetime.file import CsvWriter
|
||||||
|
import io
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/reports/export/csv")
|
||||||
|
async def export_csv(start_date: date, end_date: date):
|
||||||
|
"""Export large dataset as CSV with streaming"""
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
# Create in-memory buffer
|
||||||
|
output = io.StringIO()
|
||||||
|
writer = CsvWriter(output)
|
||||||
|
|
||||||
|
# Write headers
|
||||||
|
writer.writerow(["Date", "Orders", "Revenue", "Customers"])
|
||||||
|
|
||||||
|
# Stream data in chunks
|
||||||
|
async for batch in analytics_service.get_daily_stats_batched(
|
||||||
|
start_date, end_date, batch_size=100
|
||||||
|
):
|
||||||
|
for row in batch:
|
||||||
|
writer.writerow([
|
||||||
|
row.date,
|
||||||
|
row.order_count,
|
||||||
|
row.total_revenue,
|
||||||
|
row.unique_customers
|
||||||
|
])
|
||||||
|
|
||||||
|
# Yield buffer content
|
||||||
|
output.seek(0)
|
||||||
|
data = output.read()
|
||||||
|
output.seek(0)
|
||||||
|
output.truncate()
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="text/csv",
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename=report_{start_date}_{end_date}.csv"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Checkpointed Background Tasks
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/workers/checkpointed_jobs.py
|
||||||
|
from sqrtspace_spacetime.checkpoint import CheckpointManager, auto_checkpoint
|
||||||
|
from sqrtspace_spacetime.collections import SpaceTimeArray
|
||||||
|
|
||||||
|
class DataProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
self.checkpoint_manager = CheckpointManager()
|
||||||
|
|
||||||
|
@auto_checkpoint(total_iterations=10000)
|
||||||
|
async def process_large_dataset(self, dataset_id: str):
|
||||||
|
"""Process dataset with automatic checkpointing"""
|
||||||
|
|
||||||
|
# Initialize or restore state
|
||||||
|
results = SpaceTimeArray(threshold=1000)
|
||||||
|
processed_count = 0
|
||||||
|
|
||||||
|
# Get data in batches
|
||||||
|
async for batch in self.get_data_batches(dataset_id):
|
||||||
|
for item in batch:
|
||||||
|
# Process item
|
||||||
|
result = await self.process_item(item)
|
||||||
|
results.append(result)
|
||||||
|
processed_count += 1
|
||||||
|
|
||||||
|
# Yield state for checkpointing
|
||||||
|
if processed_count % 100 == 0:
|
||||||
|
yield {
|
||||||
|
'processed': processed_count,
|
||||||
|
'results': results,
|
||||||
|
'last_item_id': item.id
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Machine Learning with Memory Constraints
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/services/ml_service.py
|
||||||
|
from sqrtspace_spacetime.ml import SpaceTimeOptimizer
|
||||||
|
from sqrtspace_spacetime.streams import Stream
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class MLService:
|
||||||
|
def __init__(self):
|
||||||
|
self.optimizer = SpaceTimeOptimizer(
|
||||||
|
memory_limit="256MB",
|
||||||
|
checkpoint_frequency=100
|
||||||
|
)
|
||||||
|
|
||||||
|
async def train_model(self, training_data_path: str):
|
||||||
|
"""Train model with memory-efficient data loading"""
|
||||||
|
|
||||||
|
# Stream training data
|
||||||
|
data_stream = Stream.from_csv(
|
||||||
|
training_data_path,
|
||||||
|
chunk_size=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process in mini-batches
|
||||||
|
for epoch in range(10):
|
||||||
|
for batch in data_stream.batch(32):
|
||||||
|
X = np.array([item.features for item in batch])
|
||||||
|
y = np.array([item.label for item in batch])
|
||||||
|
|
||||||
|
# Train step with automatic checkpointing
|
||||||
|
loss = self.optimizer.step(
|
||||||
|
self.model,
|
||||||
|
X, y,
|
||||||
|
epoch=epoch
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.optimizer.should_checkpoint():
|
||||||
|
await self.save_checkpoint(epoch)
|
||||||
|
|
||||||
|
async def batch_predict(self, input_data):
|
||||||
|
"""Memory-efficient batch prediction"""
|
||||||
|
|
||||||
|
results = SpaceTimeArray(threshold=1000)
|
||||||
|
|
||||||
|
# Process in chunks to avoid memory issues
|
||||||
|
for chunk in Stream.from_iterable(input_data).chunk(100):
|
||||||
|
predictions = self.model.predict(chunk)
|
||||||
|
results.extend(predictions)
|
||||||
|
|
||||||
|
return results
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Advanced Caching with SpaceTime
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app/services/cache_service.py
|
||||||
|
from sqrtspace_spacetime.collections import SpaceTimeDict
|
||||||
|
from sqrtspace_spacetime.memory import MemoryPressureMonitor
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
class SpaceTimeCache:
|
||||||
|
def __init__(self):
|
||||||
|
self.hot_cache = SpaceTimeDict(threshold=1000)
|
||||||
|
self.monitor = MemoryPressureMonitor("128MB")
|
||||||
|
self.stats = {
|
||||||
|
'hits': 0,
|
||||||
|
'misses': 0,
|
||||||
|
'evictions': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get(self, key: str):
|
||||||
|
"""Get with automatic tier management"""
|
||||||
|
|
||||||
|
if key in self.hot_cache:
|
||||||
|
self.stats['hits'] += 1
|
||||||
|
return self.hot_cache[key]
|
||||||
|
|
||||||
|
self.stats['misses'] += 1
|
||||||
|
|
||||||
|
# Load from database
|
||||||
|
value = await self.load_from_db(key)
|
||||||
|
|
||||||
|
# Add to cache if memory allows
|
||||||
|
if self.monitor.can_allocate(len(str(value))):
|
||||||
|
self.hot_cache[key] = value
|
||||||
|
else:
|
||||||
|
# Trigger cleanup
|
||||||
|
self.cleanup()
|
||||||
|
self.stats['evictions'] += len(self.hot_cache) // 2
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Remove least recently used items"""
|
||||||
|
# SpaceTimeDict handles LRU automatically
|
||||||
|
self.hot_cache.evict_cold_items(0.5)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
### Products API
|
||||||
|
- `GET /products` - Paginated list
|
||||||
|
- `GET /products/stream` - Stream all products (NDJSON)
|
||||||
|
- `GET /products/search` - Memory-efficient search
|
||||||
|
- `POST /products/bulk-update` - Checkpointed bulk updates
|
||||||
|
- `GET /products/export/csv` - Streaming CSV export
|
||||||
|
|
||||||
|
### Analytics API
|
||||||
|
- `GET /analytics/summary` - Current statistics
|
||||||
|
- `GET /analytics/realtime` - SSE stream of live data
|
||||||
|
- `GET /analytics/trends` - Historical trends
|
||||||
|
- `POST /analytics/aggregate` - Custom aggregations
|
||||||
|
|
||||||
|
### ML API
|
||||||
|
- `POST /ml/train` - Train model (async with progress)
|
||||||
|
- `POST /ml/predict/batch` - Batch predictions
|
||||||
|
- `GET /ml/models/{id}/status` - Training status
|
||||||
|
- `POST /ml/features/extract` - Feature extraction pipeline
|
||||||
|
|
||||||
|
### Reports API
|
||||||
|
- `POST /reports/generate` - Generate large report
|
||||||
|
- `GET /reports/{id}/progress` - Check progress
|
||||||
|
- `GET /reports/{id}/download` - Download completed report
|
||||||
|
|
||||||
|
## Running the Application
|
||||||
|
|
||||||
|
### Development
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Production
|
||||||
|
```bash
|
||||||
|
gunicorn app.main:app -w 4 -k uvicorn.workers.UvicornWorker \
|
||||||
|
--bind 0.0.0.0:8000 \
|
||||||
|
--timeout 300 \
|
||||||
|
--max-requests 1000 \
|
||||||
|
--max-requests-jitter 50
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Docker
|
||||||
|
```bash
|
||||||
|
docker-compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Configuration
|
||||||
|
|
||||||
|
### 1. Nginx Configuration
|
||||||
|
```nginx
|
||||||
|
location /products/stream {
|
||||||
|
proxy_pass http://backend;
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_read_timeout 3600;
|
||||||
|
proxy_http_version 1.1;
|
||||||
|
proxy_set_header Connection "";
|
||||||
|
}
|
||||||
|
|
||||||
|
location /analytics/realtime {
|
||||||
|
proxy_pass http://backend;
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
|
proxy_read_timeout 86400;
|
||||||
|
proxy_http_version 1.1;
|
||||||
|
proxy_set_header Connection "";
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Worker Configuration
|
||||||
|
```python
|
||||||
|
# app/config.py
|
||||||
|
WORKER_CONFIG = {
|
||||||
|
'memory_limit': os.getenv('WORKER_MEMORY_LIMIT', '512MB'),
|
||||||
|
'checkpoint_interval': 100,
|
||||||
|
'batch_size': 1000,
|
||||||
|
'external_storage': '/tmp/spacetime-workers'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Monitoring
|
||||||
|
|
||||||
|
### Memory Usage Endpoint
|
||||||
|
```python
|
||||||
|
@router.get("/system/memory")
|
||||||
|
async def memory_stats():
|
||||||
|
"""Get current memory statistics"""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"current_usage_mb": memory_monitor.current_usage_mb,
|
||||||
|
"peak_usage_mb": memory_monitor.peak_usage_mb,
|
||||||
|
"available_mb": memory_monitor.available_mb,
|
||||||
|
"pressure_level": memory_monitor.pressure_level,
|
||||||
|
"cache_stats": cache_service.get_stats(),
|
||||||
|
"external_files": len(os.listdir(EXTERNAL_STORAGE))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prometheus Metrics
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram, Gauge
|
||||||
|
|
||||||
|
stream_requests = Counter('spacetime_stream_requests_total', 'Total streaming requests')
|
||||||
|
memory_usage = Gauge('spacetime_memory_usage_bytes', 'Current memory usage')
|
||||||
|
processing_time = Histogram('spacetime_processing_seconds', 'Processing time')
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Unit Tests
|
||||||
|
```bash
|
||||||
|
pytest tests/unit -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration Tests
|
||||||
|
```bash
|
||||||
|
pytest tests/integration -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Load Testing
|
||||||
|
```bash
|
||||||
|
locust -f tests/load/locustfile.py --host http://localhost:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Always use streaming** for large responses
|
||||||
|
2. **Configure memory limits** based on container size
|
||||||
|
3. **Enable checkpointing** for long-running tasks
|
||||||
|
4. **Monitor memory pressure** in production
|
||||||
|
5. **Use external storage** on fast SSDs
|
||||||
|
6. **Set appropriate timeouts** for streaming endpoints
|
||||||
|
7. **Implement circuit breakers** for memory protection
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### High Memory Usage
|
||||||
|
- Reduce chunk sizes
|
||||||
|
- Enable more aggressive spillover
|
||||||
|
- Check for memory leaks in custom code
|
||||||
|
|
||||||
|
### Slow Streaming
|
||||||
|
- Ensure proxy buffering is disabled
|
||||||
|
- Check network latency
|
||||||
|
- Optimize chunk sizes
|
||||||
|
|
||||||
|
### Failed Checkpoints
|
||||||
|
- Verify storage permissions
|
||||||
|
- Check disk space
|
||||||
|
- Monitor checkpoint frequency
|
||||||
|
|
||||||
|
## Learn More
|
||||||
|
|
||||||
|
- [SqrtSpace SpaceTime Docs](https://github.com/MarketAlly/Ubiquity)
|
||||||
|
- [FastAPI Documentation](https://fastapi.tiangolo.com)
|
||||||
|
- [Streaming Best Practices](https://example.com/streaming)
|
||||||
137
examples/fastapi-app/app/main.py
Normal file
137
examples/fastapi-app/app/main.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
FastAPI application demonstrating SqrtSpace SpaceTime integration
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqrtspace_spacetime import SpaceTimeConfig
|
||||||
|
from sqrtspace_spacetime.memory import MemoryPressureMonitor
|
||||||
|
|
||||||
|
from .config import settings
|
||||||
|
from .routers import products, analytics, ml, reports
|
||||||
|
from .services.cache_service import SpaceTimeCache
|
||||||
|
from .utils.memory import memory_monitor_middleware
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global instances
|
||||||
|
cache = SpaceTimeCache()
|
||||||
|
memory_monitor = MemoryPressureMonitor(settings.SPACETIME_MEMORY_LIMIT)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager"""
    logger.info("Starting FastAPI with SqrtSpace SpaceTime")

    # Push deployment settings into SpaceTime before serving any request.
    spacetime_defaults = {
        "memory_limit": settings.SPACETIME_MEMORY_LIMIT,
        "external_storage": settings.SPACETIME_EXTERNAL_STORAGE,
        "chunk_strategy": settings.SPACETIME_CHUNK_STRATEGY,
        "compression": settings.SPACETIME_COMPRESSION,
    }
    SpaceTimeConfig.set_defaults(**spacetime_defaults)

    # Share the process-wide singletons with request handlers via app.state.
    app.state.cache = cache
    app.state.memory_monitor = memory_monitor

    yield  # --- application serves requests here ---

    logger.info("Shutting down...")
    cache.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
# Application instance, wired to the lifespan manager above.
app = FastAPI(
    title="SqrtSpace SpaceTime FastAPI Demo",
    description="Memory-efficient API with √n space-time tradeoffs",
    version="1.0.0",
    lifespan=lifespan,
)

# Wide-open CORS — demo-only; restrict origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-request memory monitoring.
app.middleware("http")(memory_monitor_middleware)

# Feature routers, one URL prefix per domain area.
for module, prefix, tag in (
    (products, "/products", "products"),
    (analytics, "/analytics", "analytics"),
    (ml, "/ml", "machine-learning"),
    (reports, "/reports", "reports"),
):
    app.include_router(module.router, prefix=prefix, tags=[tag])
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
async def root():
    """Root endpoint"""
    # Landing payload: service banner, docs link, live memory snapshot.
    payload = {
        "message": "SqrtSpace SpaceTime FastAPI Demo",
        "docs": "/docs",
        "memory_usage": memory_monitor.get_memory_info(),
    }
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    info = memory_monitor.get_memory_info()

    # Summarize process memory plus current pressure level.
    memory_summary = {
        "usage_mb": info["used_mb"],
        "available_mb": info["available_mb"],
        "percentage": info["percentage"],
        "pressure": memory_monitor.check().value,
    }

    return {
        "status": "healthy",
        "memory": memory_summary,
        "cache": cache.get_stats(),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/system/memory")
async def system_memory():
    """Detailed memory statistics.

    Returns a three-part report:
      * ``process``   — RSS/VMS, CPU percent, and thread count for this worker
      * ``spacetime`` — configured limits plus current pressure and cache stats
      * ``system``    — host-wide virtual memory and swap usage
    """
    import psutil
    import os

    process = psutil.Process(os.getpid())

    # Sample each psutil source once instead of re-querying per field.
    mem = process.memory_info()
    vm = psutil.virtual_memory()

    return {
        "process": {
            "rss_mb": mem.rss / 1024 / 1024,
            "vms_mb": mem.vms / 1024 / 1024,
            # interval=None is non-blocking (CPU percent since the previous
            # call); interval=0.1 stalled the asyncio event loop for 100ms
            # on every request to this endpoint.
            "cpu_percent": process.cpu_percent(interval=None),
            "num_threads": process.num_threads()
        },
        "spacetime": {
            "memory_limit": settings.SPACETIME_MEMORY_LIMIT,
            "external_storage": settings.SPACETIME_EXTERNAL_STORAGE,
            "pressure_level": memory_monitor.check().value,
            "cache_stats": cache.get_stats()
        },
        "system": {
            "total_memory_mb": vm.total / 1024 / 1024,
            "available_memory_mb": vm.available / 1024 / 1024,
            "memory_percent": vm.percent,
            "swap_percent": psutil.swap_memory().percent
        }
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Dev entry point: serve directly with uvicorn (use gunicorn in production,
# as described in the project README).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
260
examples/fastapi-app/app/routers/products.py
Normal file
260
examples/fastapi-app/app/routers/products.py
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
"""
|
||||||
|
Product endpoints demonstrating streaming and memory-efficient operations
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter, Query, Response, HTTPException, BackgroundTasks
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from typing import Optional, List
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqrtspace_spacetime import Stream, external_sort
|
||||||
|
from sqrtspace_spacetime.checkpoint import CheckpointManager
|
||||||
|
|
||||||
|
from ..models import Product, ProductUpdate, BulkUpdateRequest, ImportStatus
|
||||||
|
from ..services.product_service import ProductService
|
||||||
|
from ..database import get_db
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
product_service = ProductService()
|
||||||
|
checkpoint_manager = CheckpointManager()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
async def list_products(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None
):
    """Get paginated list of products"""
    # Assemble the filter dict from whichever query params were supplied.
    filters = {}
    if category:
        filters['category'] = category
    for key, value in (('min_price', min_price), ('max_price', max_price)):
        if value is not None:
            filters[key] = value

    return await product_service.get_products(skip, limit, filters)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stream")
async def stream_products(
    category: Optional[str] = None,
    format: str = Query("ndjson", regex="^(ndjson|json)$")
):
    """
    Stream all products as NDJSON or JSON array.
    Memory-efficient streaming for large datasets.
    """

    # One record per line, newline-delimited.
    async def generate_ndjson():
        async for item in product_service.stream_products(category):
            yield json.dumps(item.dict()) + "\n"

    # A single JSON array, emitted incrementally with comma separators.
    async def generate_json():
        yield "["
        emitted = False
        async for item in product_service.stream_products(category):
            if emitted:
                yield ","
            yield json.dumps(item.dict())
            emitted = True
        yield "]"

    if format == "ndjson":
        body, media_type = generate_ndjson(), "application/x-ndjson"
    else:
        body, media_type = generate_json(), "application/json"

    # X-Accel-Buffering disables nginx response buffering for true streaming.
    return StreamingResponse(
        body,
        media_type=media_type,
        headers={"X-Accel-Buffering": "no"}
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/export/csv")
async def export_csv(
    category: Optional[str] = None,
    columns: Optional[List[str]] = Query(None)
):
    """Export products as CSV with streaming.

    Args:
        category: Optional category filter; None exports everything.
        columns: CSV columns to emit; defaults to the standard product fields.

    Returns:
        StreamingResponse producing the CSV in ~100-row batches so large
        catalogs never have to fit in memory at once.
    """

    if not columns:
        columns = ["id", "name", "sku", "category", "price", "stock", "created_at"]

    async def generate():
        # Reusable in-memory buffer: write a chunk, flush it out, truncate.
        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=columns)

        # Header row first.
        writer.writeheader()
        output.seek(0)
        yield output.read()
        output.seek(0)
        output.truncate()

        # Stream products in batches to bound memory usage.
        batch_count = 0
        async for batch in product_service.stream_products_batched(category, batch_size=100):
            for product in batch:
                writer.writerow({col: getattr(product, col) for col in columns})

            output.seek(0)
            data = output.read()
            output.seek(0)
            output.truncate()
            yield data

            batch_count += 1
            if batch_count % 10 == 0:
                # Yield empty string to keep connection alive
                yield ""

    filename = f"products_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"

    return StreamingResponse(
        generate(),
        media_type="text/csv",
        headers={
            # Fix: the header previously contained a literal "(unknown)"
            # placeholder; use the timestamped filename computed above.
            "Content-Disposition": f"attachment; filename={filename}",
            "X-Accel-Buffering": "no"
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search")
async def search_products(
    q: str = Query(..., min_length=2),
    sort_by: str = Query("relevance", regex="^(relevance|price_asc|price_desc|name)$"),
    limit: int = Query(100, ge=1, le=1000)
):
    """
    Search products with memory-efficient sorting.
    Uses external sort for large result sets.
    """
    results = await product_service.search_products(q, sort_by, limit)

    # Use external sort if results are large
    # NOTE(review): `limit` is capped at 1000 and is already passed to the
    # service above, so this branch looks unreachable unless the service can
    # return more rows than requested — confirm against ProductService.
    if len(results) > 1000:
        # Each sort mode maps to a key function; negation emulates descending
        # order for numeric keys.
        sort_key = {
            'price_asc': lambda x: x['price'],
            'price_desc': lambda x: -x['price'],
            'name': lambda x: x['name'],
            'relevance': lambda x: -x['relevance_score']
        }[sort_by]

        results = external_sort(results, key_func=sort_key)

    # `total` reports the full (pre-truncation) result count.
    return {"results": results[:limit], "total": len(results)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bulk-update")
async def bulk_update_prices(
    request: BulkUpdateRequest,
    background_tasks: BackgroundTasks
):
    """
    Bulk update product prices with checkpointing.
    Can be resumed if interrupted.
    """
    job_id = f"bulk_update_{datetime.now().timestamp()}"

    # A surviving checkpoint means a previous run was interrupted: resume it.
    existing = checkpoint_manager.restore(job_id)
    if existing:
        return {
            "message": "Resuming previous job",
            "job_id": job_id,
            "progress": existing.get("progress", 0)
        }

    # Otherwise kick the update off in the background and return immediately.
    background_tasks.add_task(
        product_service.bulk_update_prices,
        request,
        job_id
    )

    return {
        "message": "Bulk update started",
        "job_id": job_id,
        "status_url": f"/products/bulk-update/{job_id}/status"
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/bulk-update/{job_id}/status")
|
||||||
|
async def bulk_update_status(job_id: str):
|
||||||
|
"""Check status of bulk update job"""
|
||||||
|
checkpoint = checkpoint_manager.restore(job_id)
|
||||||
|
|
||||||
|
if not checkpoint:
|
||||||
|
raise HTTPException(status_code=404, detail="Job not found")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"job_id": job_id,
|
||||||
|
"status": checkpoint.get("status", "running"),
|
||||||
|
"progress": checkpoint.get("progress", 0),
|
||||||
|
"total": checkpoint.get("total", 0),
|
||||||
|
"updated": checkpoint.get("updated", 0),
|
||||||
|
"errors": checkpoint.get("errors", [])
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import/csv")
|
||||||
|
async def import_csv(
|
||||||
|
file_url: str,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
):
|
||||||
|
"""Import products from CSV file"""
|
||||||
|
import_id = f"import_{datetime.now().timestamp()}"
|
||||||
|
|
||||||
|
background_tasks.add_task(
|
||||||
|
product_service.import_from_csv,
|
||||||
|
file_url,
|
||||||
|
import_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "Import started",
|
||||||
|
"import_id": import_id,
|
||||||
|
"status_url": f"/products/import/{import_id}/status"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/import/{import_id}/status")
|
||||||
|
async def import_status(import_id: str):
|
||||||
|
"""Check status of import job"""
|
||||||
|
status = await product_service.get_import_status(import_id)
|
||||||
|
|
||||||
|
if not status:
|
||||||
|
raise HTTPException(status_code=404, detail="Import job not found")
|
||||||
|
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/statistics")
|
||||||
|
async def product_statistics():
|
||||||
|
"""
|
||||||
|
Get product statistics using memory-efficient aggregations.
|
||||||
|
Uses external grouping for large datasets.
|
||||||
|
"""
|
||||||
|
stats = await product_service.calculate_statistics()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_products": stats["total_products"],
|
||||||
|
"total_value": stats["total_value"],
|
||||||
|
"by_category": stats["by_category"],
|
||||||
|
"price_distribution": stats["price_distribution"],
|
||||||
|
"stock_alerts": stats["stock_alerts"],
|
||||||
|
"processing_info": {
|
||||||
|
"memory_used_mb": stats["memory_used_mb"],
|
||||||
|
"external_operations": stats["external_operations"]
|
||||||
|
}
|
||||||
|
}
|
||||||
232
examples/ml-pipeline/README.md
Normal file
232
examples/ml-pipeline/README.md
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
# Machine Learning Pipeline with SqrtSpace SpaceTime
|
||||||
|
|
||||||
|
This example demonstrates how to build memory-efficient machine learning pipelines using SqrtSpace SpaceTime for handling large datasets that don't fit in memory.
|
||||||
|
|
||||||
|
## Features Demonstrated
|
||||||
|
|
||||||
|
### 1. **Memory-Efficient Data Loading**
|
||||||
|
- Streaming data loading from CSV files
|
||||||
|
- Automatic memory pressure monitoring
|
||||||
|
- Chunked processing with configurable batch sizes
|
||||||
|
|
||||||
|
### 2. **Feature Engineering at Scale**
|
||||||
|
- Checkpointed feature extraction
|
||||||
|
- Statistical feature computation
|
||||||
|
- Memory-aware transformations
|
||||||
|
|
||||||
|
### 3. **External Algorithms for ML**
|
||||||
|
- External sorting for data preprocessing
|
||||||
|
- External grouping for metrics calculation
|
||||||
|
- Stratified sampling with memory constraints
|
||||||
|
|
||||||
|
### 4. **Model Training with Constraints**
|
||||||
|
- Mini-batch training with memory limits
|
||||||
|
- Automatic garbage collection triggers
|
||||||
|
- Progress checkpointing for resumability
|
||||||
|
|
||||||
|
### 5. **Distributed-Ready Components**
|
||||||
|
- Serializable pipeline components
|
||||||
|
- Checkpoint-based fault tolerance
|
||||||
|
- Streaming predictions
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install sqrtspace-spacetime scikit-learn pandas numpy joblib psutil
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python ml_pipeline_example.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
1. Generate a synthetic dataset (100K samples, 50 features)
|
||||||
|
2. Load data using streaming
|
||||||
|
3. Preprocess with external sorting
|
||||||
|
4. Extract features with checkpointing
|
||||||
|
5. Train a Random Forest model
|
||||||
|
6. Evaluate using external grouping
|
||||||
|
7. Save the model checkpoint
|
||||||
|
|
||||||
|
## Key Components
|
||||||
|
|
||||||
|
### SpaceTimeFeatureExtractor
|
||||||
|
|
||||||
|
A scikit-learn compatible transformer that:
|
||||||
|
- Extracts features using streaming computation
|
||||||
|
- Maintains statistics in SpaceTime collections
|
||||||
|
- Supports checkpointing for resumability
|
||||||
|
|
||||||
|
```python
|
||||||
|
extractor = SpaceTimeFeatureExtractor(max_features=1000)
|
||||||
|
extractor.fit(data_stream) # Automatically checkpointed
|
||||||
|
transformed = extractor.transform(test_stream)
|
||||||
|
```
|
||||||
|
|
||||||
|
### MemoryEfficientMLPipeline
|
||||||
|
|
||||||
|
Complete pipeline that handles:
|
||||||
|
- Data loading with memory monitoring
|
||||||
|
- Preprocessing with external algorithms
|
||||||
|
- Training with batch processing
|
||||||
|
- Evaluation with memory-efficient metrics
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipeline = MemoryEfficientMLPipeline(memory_limit="512MB")
|
||||||
|
pipeline.train_with_memory_constraints(X_train, y_train)
|
||||||
|
metrics = pipeline.evaluate_with_external_grouping(X_test, y_test)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory Monitoring
|
||||||
|
|
||||||
|
Automatic memory pressure detection:
|
||||||
|
```python
|
||||||
|
monitor = MemoryPressureMonitor("512MB")
|
||||||
|
if monitor.should_cleanup():
|
||||||
|
gc.collect()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Custom Feature Extractors
|
||||||
|
|
||||||
|
```python
|
||||||
|
class CustomFeatureExtractor(SpaceTimeFeatureExtractor):
|
||||||
|
def extract_features(self, batch):
|
||||||
|
# Your custom feature logic
|
||||||
|
features = []
|
||||||
|
for sample in batch:
|
||||||
|
# Complex feature engineering
|
||||||
|
features.append(self.compute_features(sample))
|
||||||
|
return features
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming Predictions
|
||||||
|
|
||||||
|
```python
|
||||||
|
def predict_streaming(model, data_path):
|
||||||
|
predictions = SpaceTimeArray(threshold=10000)
|
||||||
|
|
||||||
|
for chunk in pd.read_csv(data_path, chunksize=1000):
|
||||||
|
X = chunk.values
|
||||||
|
y_pred = model.predict(X)
|
||||||
|
predictions.extend(y_pred)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cross-Validation with Memory Limits
|
||||||
|
|
||||||
|
```python
|
||||||
|
def memory_efficient_cv(X, y, model, cv=5):
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
# External sort for stratified splitting
|
||||||
|
sorted_indices = external_sort(
|
||||||
|
list(enumerate(y)),
|
||||||
|
key_func=lambda x: x[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
fold_size = len(y) // cv
|
||||||
|
for i in range(cv):
|
||||||
|
# Get fold indices
|
||||||
|
test_start = i * fold_size
|
||||||
|
test_end = (i + 1) * fold_size
|
||||||
|
|
||||||
|
# Train/test split
|
||||||
|
train_indices = sorted_indices[:test_start] + sorted_indices[test_end:]
|
||||||
|
test_indices = sorted_indices[test_start:test_end]
|
||||||
|
|
||||||
|
# Train and evaluate
|
||||||
|
model.fit(X[train_indices], y[train_indices])
|
||||||
|
score = model.score(X[test_indices], y[test_indices])
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Tips
|
||||||
|
|
||||||
|
1. **Tune Chunk Sizes**: Larger chunks are more efficient but use more memory
|
||||||
|
2. **Use Compression**: Enable LZ4 compression for numerical data
|
||||||
|
3. **Monitor Checkpoints**: Too frequent checkpointing can slow down processing
|
||||||
|
4. **Profile Memory**: Use the `@profile_memory` decorator to find bottlenecks
|
||||||
|
5. **External Storage**: Use SSDs for external algorithm temporary files
|
||||||
|
|
||||||
|
## Integration with Popular ML Libraries
|
||||||
|
|
||||||
|
### PyTorch DataLoader
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SpaceTimeDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, data_path, transform=None):
|
||||||
|
self.data = SpaceTimeArray.from_file(data_path)
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = self.data[idx]
|
||||||
|
if self.transform:
|
||||||
|
sample = self.transform(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# Use with DataLoader
|
||||||
|
dataset = SpaceTimeDataset('large_dataset.pkl')
|
||||||
|
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
|
||||||
|
```
|
||||||
|
|
||||||
|
### TensorFlow tf.data
|
||||||
|
|
||||||
|
```python
|
||||||
|
def create_tf_dataset(file_path, batch_size=32):
|
||||||
|
def generator():
|
||||||
|
stream = Stream.from_csv(file_path)
|
||||||
|
for item in stream:
|
||||||
|
yield item['features'], item['label']
|
||||||
|
|
||||||
|
dataset = tf.data.Dataset.from_generator(
|
||||||
|
generator,
|
||||||
|
output_types=(tf.float32, tf.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
On a machine with 8GB RAM processing a 50GB dataset:
|
||||||
|
|
||||||
|
| Operation | Traditional | SpaceTime | Memory Used |
|
||||||
|
|-----------|------------|-----------|-------------|
|
||||||
|
| Data Loading | OOM | 42s | 512MB |
|
||||||
|
| Feature Extraction | OOM | 156s | 512MB |
|
||||||
|
| Model Training | OOM | 384s | 512MB |
|
||||||
|
| Evaluation | 89s | 95s | 512MB |
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Out of Memory Errors
|
||||||
|
- Reduce chunk sizes
|
||||||
|
- Lower memory limit for earlier spillover
|
||||||
|
- Enable compression
|
||||||
|
|
||||||
|
### Slow Performance
|
||||||
|
- Increase memory limit if possible
|
||||||
|
- Use faster external storage (SSD)
|
||||||
|
- Optimize feature extraction logic
|
||||||
|
|
||||||
|
### Checkpoint Recovery
|
||||||
|
- Check checkpoint directory permissions
|
||||||
|
- Ensure enough disk space
|
||||||
|
- Monitor checkpoint file sizes
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
- Explore distributed training with checkpoint coordination
|
||||||
|
- Implement custom external algorithms
|
||||||
|
- Build real-time ML pipelines with streaming
|
||||||
|
- Integrate with cloud storage for data loading
|
||||||
413
examples/ml-pipeline/ml_pipeline_example.py
Normal file
413
examples/ml-pipeline/ml_pipeline_example.py
Normal file
@ -0,0 +1,413 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Machine Learning Pipeline with SqrtSpace SpaceTime
|
||||||
|
|
||||||
|
Demonstrates memory-efficient ML workflows including:
|
||||||
|
- Large dataset processing
|
||||||
|
- Feature extraction with checkpointing
|
||||||
|
- Model training with memory constraints
|
||||||
|
- Batch prediction with streaming
|
||||||
|
- Cross-validation with external sorting
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.model_selection import cross_val_score
|
||||||
|
import joblib
|
||||||
|
import time
|
||||||
|
from typing import Iterator, Tuple, List, Dict, Any
|
||||||
|
|
||||||
|
from sqrtspace_spacetime import (
|
||||||
|
SpaceTimeArray,
|
||||||
|
SpaceTimeDict,
|
||||||
|
Stream,
|
||||||
|
external_sort,
|
||||||
|
external_groupby,
|
||||||
|
SpaceTimeConfig
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.checkpoint import auto_checkpoint, CheckpointManager
|
||||||
|
from sqrtspace_spacetime.memory import MemoryPressureMonitor, profile_memory
|
||||||
|
from sqrtspace_spacetime.ml import SpaceTimeOptimizer
|
||||||
|
from sqrtspace_spacetime.profiler import profile
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide SpaceTime configuration for ML workloads: cap memory at 1GB,
# size working chunks as sqrt(n), and compress spilled data with LZ4 (fast
# enough for numerical arrays).
SpaceTimeConfig.set_defaults(
    memory_limit=1024 * 1024 * 1024,  # 1GB
    chunk_strategy='sqrt_n',
    compression='lz4'  # Fast compression for numerical data
)
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceTimeFeatureExtractor(BaseEstimator, TransformerMixin):
    """Memory-efficient feature extractor using SpaceTime.

    Scikit-learn-compatible transformer (fit/transform protocol via
    BaseEstimator/TransformerMixin) that accumulates per-feature running
    sums and counts in spill-to-disk SpaceTime collections, so the full
    dataset never has to be resident in memory at once.
    """

    def __init__(self, max_features: int = 1000):
        # Spill threshold for the per-feature accumulator arrays in fit().
        self.max_features = max_features
        # Persisted per-feature statistics ('mean_<i>' keys written by fit).
        self.feature_stats = SpaceTimeDict(threshold=100)
        self.checkpoint_manager = CheckpointManager()

    @auto_checkpoint(total_iterations=10000)
    def fit(self, X: Iterator[np.ndarray], y=None):
        """Fit extractor on streaming data.

        NOTE(review): the body contains ``yield`` statements, so this is a
        generator function — calling it only creates a generator object.
        It presumably relies on @auto_checkpoint to drive the generator and
        persist each yielded state dict; confirm against that decorator's
        contract, otherwise the body never executes and ``feature_means_``
        is never set.
        """

        print("Extracting features from training data...")

        # Accumulate statistics in SpaceTime collections
        feature_sums = SpaceTimeArray(threshold=self.max_features)
        feature_counts = SpaceTimeArray(threshold=self.max_features)

        for batch_idx, batch in enumerate(X):
            for row in batch:
                # Update running statistics
                # Grow the accumulators lazily to the widest row seen so far.
                if len(feature_sums) < len(row):
                    feature_sums.extend([0] * (len(row) - len(feature_sums)))
                    feature_counts.extend([0] * (len(row) - len(feature_counts)))

                for i, value in enumerate(row):
                    feature_sums[i] += value
                    feature_counts[i] += 1

            # Checkpoint every 100 batches
            if batch_idx % 100 == 0:
                yield {
                    'batch_idx': batch_idx,
                    'feature_sums': feature_sums,
                    'feature_counts': feature_counts
                }

        # Calculate means
        # (features that never received a value fall back to mean 0)
        self.feature_means_ = []
        for i in range(len(feature_sums)):
            mean = feature_sums[i] / feature_counts[i] if feature_counts[i] > 0 else 0
            self.feature_means_.append(mean)
            self.feature_stats[f'mean_{i}'] = mean

        return self

    def transform(self, X: Iterator[np.ndarray]) -> Iterator[np.ndarray]:
        """Transform streaming data.

        Lazily yields mean-centered copies of each incoming batch, using
        the per-feature means computed by ``fit`` (requires
        ``feature_means_`` to exist, i.e. a completed fit).
        """

        for batch in X:
            # Normalize using stored means
            transformed = np.array(batch)
            for i, mean in enumerate(self.feature_means_):
                transformed[:, i] -= mean

            yield transformed
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEfficientMLPipeline:
    """Complete ML pipeline with memory management.

    Wires together streaming data loading, external-memory preprocessing,
    checkpointed feature extraction, batch training, and external-grouping
    evaluation, all bounded by a configurable memory limit.
    """

    def __init__(self, memory_limit: str = "512MB"):
        self.memory_monitor = MemoryPressureMonitor(memory_limit)
        self.checkpoint_manager = CheckpointManager()
        self.feature_extractor = SpaceTimeFeatureExtractor()
        # n_jobs=-1: use all available cores for the forest.
        self.model = RandomForestClassifier(n_estimators=100, n_jobs=-1)
        # NOTE(review): self.optimizer is constructed but never used by any
        # method in this class — confirm whether it is still needed.
        self.optimizer = SpaceTimeOptimizer(
            memory_limit=memory_limit,
            checkpoint_frequency=100
        )

    @profile_memory(threshold_mb=256)
    def load_data_streaming(self, file_path: str, chunk_size: int = 1000) -> Iterator:
        """Load large dataset in memory-efficient chunks.

        Yields (X, y) numpy pairs per CSV chunk; expects the CSV to have a
        'target' column, all remaining columns are treated as features.
        """

        print(f"Loading data from {file_path} in chunks of {chunk_size}...")

        # Simulate loading large CSV in chunks
        for chunk_idx, chunk in enumerate(pd.read_csv(file_path, chunksize=chunk_size)):
            # Convert to numpy array
            X = chunk.drop('target', axis=1).values
            y = chunk['target'].values

            # Check memory pressure
            if self.memory_monitor.should_cleanup():
                print(f"Memory pressure detected at chunk {chunk_idx}, triggering cleanup")
                import gc
                gc.collect()

            yield X, y

    def preprocess_with_external_sort(self, data_iterator: Iterator) -> Tuple[SpaceTimeArray, SpaceTimeArray]:
        """Preprocess and sort data using external algorithms.

        Drains the iterator into SpaceTime arrays, then orders all samples
        by target value via external_sort (for stratified splitting later).
        Returns the reordered (X, y) pair.
        """

        print("Preprocessing data with external sorting...")

        X_all = SpaceTimeArray(threshold=10000)
        y_all = SpaceTimeArray(threshold=10000)

        # Collect all data
        for X_batch, y_batch in data_iterator:
            X_all.extend(X_batch.tolist())
            y_all.extend(y_batch.tolist())

        # Sort by target value for stratified splitting
        print(f"Sorting {len(y_all)} samples by target value...")

        # Create index pairs
        indexed_data = [(i, y) for i, y in enumerate(y_all)]

        # External sort by target value
        sorted_indices = external_sort(
            indexed_data,
            key_func=lambda x: x[1]
        )

        # Reorder data
        X_sorted = SpaceTimeArray(threshold=10000)
        y_sorted = SpaceTimeArray(threshold=10000)

        for idx, _ in sorted_indices:
            X_sorted.append(X_all[idx])
            y_sorted.append(y_all[idx])

        return X_sorted, y_sorted

    def extract_features_checkpointed(self, X: SpaceTimeArray) -> SpaceTimeArray:
        """Extract features with checkpointing.

        Computes five summary statistics per sample (mean/std/min/max/median)
        in batches of 100, saving a checkpoint every 1000 samples.

        NOTE(review): job_id embeds int(time.time()), so a re-run after a
        crash gets a *different* id and the restore() below can only match
        within the same second — resume is effectively never triggered.
        Consider deriving the id from the input instead.
        """

        print("Extracting features with checkpointing...")

        job_id = f"feature_extraction_{int(time.time())}"

        # Check for existing checkpoint
        checkpoint = self.checkpoint_manager.restore(job_id)
        start_idx = checkpoint.get('last_idx', 0) if checkpoint else 0

        features = SpaceTimeArray(threshold=10000)

        # Load partial results if resuming
        if checkpoint and 'features' in checkpoint:
            features = checkpoint['features']

        # Process in batches
        batch_size = 100
        for i in range(start_idx, len(X), batch_size):
            batch = X[i:i + batch_size]

            # Simulate feature extraction
            batch_features = []
            for sample in batch:
                # Example: statistical features
                features_dict = {
                    'mean': np.mean(sample),
                    'std': np.std(sample),
                    'min': np.min(sample),
                    'max': np.max(sample),
                    'median': np.median(sample)
                }
                batch_features.append(list(features_dict.values()))

            features.extend(batch_features)

            # Checkpoint every 1000 samples
            if (i + batch_size) % 1000 == 0:
                self.checkpoint_manager.save(job_id, {
                    'last_idx': i + batch_size,
                    'features': features
                })
                print(f"Checkpoint saved at index {i + batch_size}")

        # Clean up checkpoint
        self.checkpoint_manager.delete(job_id)

        return features

    @profile
    def train_with_memory_constraints(self, X: SpaceTimeArray, y: SpaceTimeArray):
        """Train model with memory-aware batch processing.

        NOTE(review): RandomForestClassifier has no partial_fit, so in
        practice the else-branch runs: the *entire* dataset is materialized
        once at epoch 0 / batch 0 for a single fit() call (defeating the
        mini-batching), and every later epoch breaks out immediately.  The
        mini-batch path is only exercised if self.model is swapped for an
        estimator that supports partial_fit.
        """

        print("Training model with memory constraints...")

        # Convert to numpy arrays in batches
        batch_size = min(1000, len(X))

        for epoch in range(3):  # Multiple epochs
            print(f"\nEpoch {epoch + 1}/3")

            # Shuffle data
            indices = list(range(len(X)))
            np.random.shuffle(indices)

            # Train in mini-batches
            for i in range(0, len(X), batch_size):
                batch_indices = indices[i:i + batch_size]

                X_batch = np.array([X[idx] for idx in batch_indices])
                y_batch = np.array([y[idx] for idx in batch_indices])

                # Partial fit (for models that support it)
                if hasattr(self.model, 'partial_fit'):
                    self.model.partial_fit(X_batch, y_batch)
                else:
                    # For RandomForest, we'll fit on full data once
                    if epoch == 0 and i == 0:
                        # Collect all data for initial fit
                        X_train = np.array(X.to_list())
                        y_train = np.array(y.to_list())
                        self.model.fit(X_train, y_train)
                    break

                # Check memory
                if self.memory_monitor.should_cleanup():
                    import gc
                    gc.collect()
                    print(f"Memory cleanup at batch {i // batch_size}")

    def evaluate_with_external_grouping(self, X: SpaceTimeArray, y: SpaceTimeArray) -> Dict[str, float]:
        """Evaluate model using external grouping for metrics.

        Predicts in batches of 1000, then computes accuracy and a
        (actual, predicted) class distribution via external_groupby so the
        grouping itself can spill to disk for large test sets.
        """

        print("Evaluating model performance...")

        # Make predictions in batches
        predictions = SpaceTimeArray(threshold=10000)

        batch_size = 1000
        for i in range(0, len(X), batch_size):
            X_batch = np.array(X[i:i + batch_size])
            y_pred = self.model.predict(X_batch)
            predictions.extend(y_pred.tolist())

        # Group by actual vs predicted for confusion matrix
        results = []
        for i in range(len(y)):
            results.append({
                'actual': y[i],
                'predicted': predictions[i],
                'correct': y[i] == predictions[i]
            })

        # Use external groupby for metrics
        accuracy_groups = external_groupby(
            results,
            key_func=lambda x: x['correct']
        )

        correct_count = len(accuracy_groups.get(True, []))
        total_count = len(results)
        accuracy = correct_count / total_count if total_count > 0 else 0

        # Class-wise metrics
        class_groups = external_groupby(
            results,
            key_func=lambda x: (x['actual'], x['predicted'])
        )

        return {
            'accuracy': accuracy,
            'total_samples': total_count,
            'correct_predictions': correct_count,
            'class_distribution': {str(k): len(v) for k, v in class_groups.items()}
        }

    def save_model_checkpoint(self, path: str):
        """Save model with metadata.

        Serializes the fitted model, the feature extractor, and a metadata
        dict (timestamp, memory limit, feature statistics) via joblib.
        """

        checkpoint = {
            'model': self.model,
            'feature_extractor': self.feature_extractor,
            'metadata': {
                'timestamp': time.time(),
                'memory_limit': self.memory_monitor.memory_limit,
                'feature_stats': dict(self.feature_extractor.feature_stats)
            }
        }

        joblib.dump(checkpoint, path)
        print(f"Model saved to {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_synthetic_data(n_samples: int = 100000, n_features: int = 50,
                            path: str = 'synthetic_data.csv') -> None:
    """Generate a synthetic binary-classification dataset as a CSV file.

    Writes a header row (``feature_0..feature_{n_features-1},target``)
    followed by one row per sample.  Data is produced and written in
    fixed-size chunks so peak memory stays bounded regardless of
    ``n_samples``.

    Args:
        n_samples: Number of data rows to generate.
        n_features: Number of feature columns per row.
        path: Output CSV path.  New parameter; defaults to the previously
            hard-coded ``synthetic_data.csv`` for backward compatibility.
    """
    print(f"Generating synthetic dataset: {n_samples} samples, {n_features} features...")

    # Rows generated per numpy allocation; keeps memory bounded.
    chunk_size = 10000

    with open(path, 'w') as f:
        # Header row.
        headers = [f'feature_{i}' for i in range(n_features)] + ['target']
        f.write(','.join(headers) + '\n')

        for start in range(0, n_samples, chunk_size):
            rows = min(chunk_size, n_samples - start)

            X = np.random.randn(rows, n_features)

            # Binary label: positive iff the first (up to) 10 features sum
            # above zero, so the target is genuinely learnable.
            y = (X[:, :10].sum(axis=1) > 0).astype(int)

            # Batch the writes: build each row once, flush the chunk once,
            # instead of one f.write call per row.
            lines = [
                ','.join(map(str, list(X[j]) + [y[j]])) + '\n'
                for j in range(rows)
            ]
            f.writelines(lines)

            if (start + chunk_size) % 50000 == 0:
                print(f"Generated {start + chunk_size} samples...")

    print("Synthetic data generation complete!")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run complete ML pipeline example.

    End-to-end demo: generate data, stream-load, externally sort, extract
    features with checkpointing, train, evaluate, and save the model.
    Writes 'synthetic_data.csv' and 'spacetime_model.joblib' to the CWD.
    """

    print("=== SqrtSpace SpaceTime ML Pipeline Example ===\n")

    # Generate synthetic data
    generate_synthetic_data(n_samples=100000, n_features=50)

    # Create pipeline
    pipeline = MemoryEfficientMLPipeline(memory_limit="512MB")

    # Load and preprocess data
    print("\n1. Loading data with streaming...")
    data_iterator = pipeline.load_data_streaming('synthetic_data.csv', chunk_size=5000)

    print("\n2. Preprocessing with external sort...")
    X_sorted, y_sorted = pipeline.preprocess_with_external_sort(data_iterator)
    print(f"Loaded {len(X_sorted)} samples")

    print("\n3. Extracting features with checkpointing...")
    X_features = pipeline.extract_features_checkpointed(X_sorted)

    print("\n4. Training model with memory constraints...")
    # Split data (80/20)
    # NOTE(review): elsewhere SpaceTimeArray is constructed as
    # SpaceTimeArray(threshold=...); here a data slice is passed
    # positionally — confirm the constructor accepts initial data as its
    # first positional argument.
    split_idx = int(0.8 * len(X_features))
    X_train = SpaceTimeArray(X_features[:split_idx])
    y_train = SpaceTimeArray(y_sorted[:split_idx])
    X_test = SpaceTimeArray(X_features[split_idx:])
    y_test = SpaceTimeArray(y_sorted[split_idx:])

    pipeline.train_with_memory_constraints(X_train, y_train)

    print("\n5. Evaluating with external grouping...")
    metrics = pipeline.evaluate_with_external_grouping(X_test, y_test)

    print("\n=== Results ===")
    print(f"Test Accuracy: {metrics['accuracy']:.4f}")
    print(f"Total Test Samples: {metrics['total_samples']}")
    print(f"Correct Predictions: {metrics['correct_predictions']}")

    print("\n6. Saving model checkpoint...")
    pipeline.save_model_checkpoint('spacetime_model.joblib')

    # Memory statistics
    print("\n=== Memory Statistics ===")
    memory_info = pipeline.memory_monitor.get_memory_info()
    print(f"Peak Memory Usage: {memory_info['peak_mb']:.2f} MB")
    print(f"Current Memory Usage: {memory_info['used_mb']:.2f} MB")
    print(f"Memory Limit: {memory_info['limit_mb']:.2f} MB")

    print("\n=== Pipeline Complete! ===")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
95
pyproject.toml
Normal file
95
pyproject.toml
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "sqrtspace-spacetime"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = [
|
||||||
|
{name = "David H. Friedel Jr.", email = "dfriedel@marketally.com"},
|
||||||
|
{name = "SqrtSpace Contributors"}
|
||||||
|
]
|
||||||
|
description = "Memory-efficient algorithms and data structures using Williams' √n space-time tradeoffs"
|
||||||
|
readme = "README.md"
|
||||||
|
license = {text = "Apache-2.0"}
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: Apache Software License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
"Topic :: System :: Archiving :: Compression",
|
||||||
|
"Topic :: Database",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
]
|
||||||
|
keywords = ["memory", "efficiency", "algorithms", "spacetime", "external-memory", "streaming"]
|
||||||
|
dependencies = [
|
||||||
|
"numpy>=1.20.0",
|
||||||
|
"psutil>=5.8.0",
|
||||||
|
"aiofiles>=0.8.0",
|
||||||
|
"tqdm>=4.62.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.0.0",
|
||||||
|
"pytest-asyncio>=0.20.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"black>=22.0.0",
|
||||||
|
"flake8>=5.0.0",
|
||||||
|
"mypy>=0.990",
|
||||||
|
"sphinx>=5.0.0",
|
||||||
|
"sphinx-rtd-theme>=1.0.0",
|
||||||
|
]
|
||||||
|
pandas = ["pandas>=1.3.0"]
|
||||||
|
dask = ["dask[complete]>=2022.1.0"]
|
||||||
|
ray = ["ray>=2.0.0"]
|
||||||
|
all = ["sqrtspace-spacetime[pandas,dask,ray]"]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/sqrtspace/sqrtspace-python"
|
||||||
|
Documentation = "https://sqrtspace-spacetime.readthedocs.io"
|
||||||
|
Repository = "https://github.com/sqrtspace/sqrtspace-python.git"
|
||||||
|
Issues = "https://github.com/sqrtspace/sqrtspace-python/issues"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
spacetime = "sqrtspace_spacetime.cli:main"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
sqrtspace_spacetime = ["py.typed"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 88
|
||||||
|
target-version = ['py38']
|
||||||
|
include = '\.pyi?$'
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.8"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py", "*_test.py"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
addopts = "-v --cov=sqrtspace_spacetime --cov-report=html --cov-report=term"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["src/sqrtspace_spacetime"]
|
||||||
|
omit = ["*/tests/*", "*/__init__.py"]
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
precision = 2
|
||||||
|
show_missing = true
|
||||||
|
skip_covered = false
|
||||||
31
requirements-dev.txt
Normal file
31
requirements-dev.txt
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Development dependencies
|
||||||
|
-r requirements.txt
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest>=7.0.0
|
||||||
|
pytest-asyncio>=0.20.0
|
||||||
|
pytest-cov>=4.0.0
|
||||||
|
pytest-xdist>=3.0.0
|
||||||
|
|
||||||
|
# Code quality
|
||||||
|
black>=22.0.0
|
||||||
|
flake8>=5.0.0
|
||||||
|
mypy>=0.990
|
||||||
|
isort>=5.10.0
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
sphinx>=5.0.0
|
||||||
|
sphinx-rtd-theme>=1.0.0
|
||||||
|
sphinx-autodoc-typehints>=1.19.0
|
||||||
|
|
||||||
|
# ML frameworks (optional)
|
||||||
|
torch>=1.10.0
|
||||||
|
tensorflow>=2.8.0
|
||||||
|
|
||||||
|
# Visualization (optional)
|
||||||
|
matplotlib>=3.5.0
|
||||||
|
seaborn>=0.11.0
|
||||||
|
|
||||||
|
# Build tools
|
||||||
|
build>=0.8.0
|
||||||
|
twine>=4.0.0
|
||||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Core dependencies
|
||||||
|
numpy>=1.20.0
|
||||||
|
psutil>=5.8.0
|
||||||
|
aiofiles>=0.8.0
|
||||||
|
tqdm>=4.62.0
|
||||||
18
setup.py
Normal file
18
setup.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Setup script for SqrtSpace SpaceTime.
|
||||||
|
|
||||||
|
This is a compatibility shim for older pip versions.
|
||||||
|
The actual package configuration is in pyproject.toml.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
# Read the contents of README file
|
||||||
|
from pathlib import Path
|
||||||
|
this_directory = Path(__file__).parent
|
||||||
|
long_description = (this_directory / "README.md").read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
setup(
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
)
|
||||||
31
src/sqrtspace_spacetime/__init__.py
Normal file
31
src/sqrtspace_spacetime/__init__.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Ubiquity SpaceTime: Memory-efficient algorithms using √n space-time tradeoffs.
|
||||||
|
|
||||||
|
This package implements Williams' theoretical computer science results showing
|
||||||
|
that many algorithms can achieve better memory usage by accepting slightly
|
||||||
|
slower runtime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import SpaceTimeConfig
|
||||||
|
from sqrtspace_spacetime.collections import SpaceTimeArray, SpaceTimeDict
|
||||||
|
from sqrtspace_spacetime.algorithms import external_sort, external_groupby
|
||||||
|
from sqrtspace_spacetime.streams import Stream
|
||||||
|
from sqrtspace_spacetime.memory import MemoryMonitor, MemoryPressureLevel
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
__author__ = "Ubiquity SpaceTime Contributors"
|
||||||
|
__license__ = "Apache-2.0"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SpaceTimeConfig",
|
||||||
|
"SpaceTimeArray",
|
||||||
|
"SpaceTimeDict",
|
||||||
|
"external_sort",
|
||||||
|
"external_groupby",
|
||||||
|
"Stream",
|
||||||
|
"MemoryMonitor",
|
||||||
|
"MemoryPressureLevel",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Configure default settings
|
||||||
|
SpaceTimeConfig.set_defaults()
|
||||||
9
src/sqrtspace_spacetime/algorithms/__init__.py
Normal file
9
src/sqrtspace_spacetime/algorithms/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""External memory algorithms using √n space-time tradeoffs."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.algorithms.external_sort import external_sort
|
||||||
|
from sqrtspace_spacetime.algorithms.external_groupby import external_groupby
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"external_sort",
|
||||||
|
"external_groupby",
|
||||||
|
]
|
||||||
265
src/sqrtspace_spacetime/algorithms/external_groupby.py
Normal file
265
src/sqrtspace_spacetime/algorithms/external_groupby.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
External group-by algorithm using √n memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.collections import SpaceTimeDict
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
K = TypeVar('K')
|
||||||
|
V = TypeVar('V')
|
||||||
|
|
||||||
|
|
||||||
|
class GroupByStrategy(Enum):
    """Group-by strategies."""
    # Accumulate groups in a hash map, spilling cold entries to disk.
    HASH_BASED = "hash_based"
    # Externally sort by key, then collect consecutive equal-key runs.
    SORT_BASED = "sort_based"
    # Pick between the two from an estimate of group cardinality.
    ADAPTIVE = "adaptive"
|
||||||
|
|
||||||
|
|
||||||
|
def external_groupby(
    data: Iterable[T],
    key_func: Callable[[T], K],
    strategy: GroupByStrategy = GroupByStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> Dict[K, List[T]]:
    """
    Group data by key using external memory.

    Args:
        data: Iterable of items to group
        key_func: Function to extract group key
        strategy: Grouping strategy
        storage_path: Path for temporary storage

    Returns:
        Dictionary mapping keys to lists of items
    """
    storage_path = storage_path or config.external_storage_path

    # Materialize the input once so its size is known.
    items = data if isinstance(data, list) else list(data)

    # Small inputs: plain in-memory grouping, no external machinery.
    if len(items) <= 10000:
        grouped = defaultdict(list)
        for element in items:
            grouped[key_func(element)].append(element)
        return dict(grouped)

    # Resolve ADAPTIVE to a concrete strategy before dispatching.
    chosen = strategy
    if chosen == GroupByStrategy.ADAPTIVE:
        chosen = _choose_groupby_strategy(items, key_func)

    if chosen == GroupByStrategy.HASH_BASED:
        return _hash_based_groupby(items, key_func, storage_path)
    return _sort_based_groupby(items, key_func, storage_path)
|
||||||
|
|
||||||
|
|
||||||
|
def external_groupby_aggregate(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V],
    agg_func: Callable[[V, V], V],
    initial: Optional[V] = None,
    storage_path: Optional[str] = None
) -> Dict[K, V]:
    """
    Group by with aggregation using external memory.

    Args:
        data: Iterable of items
        key_func: Function to extract group key
        value_func: Function to extract value for aggregation
        agg_func: Aggregation function (e.g., sum, max)
        initial: Initial value for aggregation
        storage_path: Path for temporary storage

    Returns:
        Dictionary mapping keys to aggregated values
    """
    # SpaceTimeDict keeps the running aggregates within the memory budget.
    aggregates = SpaceTimeDict(storage_path=storage_path)

    for element in data:
        group_key = key_func(element)
        group_value = value_func(element)

        if group_key not in aggregates:
            # First value for this key: fold it into `initial` when given.
            aggregates[group_key] = (
                group_value if initial is None else agg_func(initial, group_value)
            )
        else:
            aggregates[group_key] = agg_func(aggregates[group_key], group_value)

    # Snapshot items() before building the plain dict so iteration is not
    # disturbed by any spill/promote activity inside SpaceTimeDict.
    return dict(list(aggregates.items()))
|
||||||
|
|
||||||
|
|
||||||
|
def _choose_groupby_strategy(data: List[T], key_func: Callable[[T], K]) -> GroupByStrategy:
    """Choose grouping strategy based on data characteristics."""
    # Probe evenly spaced elements to estimate key cardinality.
    sample_size = min(1000, len(data))
    step = max(1, len(data) // sample_size)
    sampled_keys = {key_func(data[i]) for i in range(0, len(data), step)}

    # Scale the sampled cardinality up to the full dataset.
    estimated_groups = len(sampled_keys) * (len(data) / sample_size)

    # Few groups relative to the data -> the hash table stays small.
    if estimated_groups < len(data) / 10:
        return GroupByStrategy.HASH_BASED
    return GroupByStrategy.SORT_BASED
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_based_groupby(
    data: List[T],
    key_func: Callable[[T], K],
    storage_path: str
) -> Dict[K, List[T]]:
    """
    Hash-based grouping with spillover to disk.
    """
    chunk_size = config.calculate_chunk_size(len(data))

    # SpaceTimeDict spills cold groups to disk past the threshold.
    buckets = SpaceTimeDict(threshold=chunk_size // 10, storage_path=storage_path)

    for element in data:
        bucket_key = key_func(element)
        if bucket_key not in buckets:
            buckets[bucket_key] = [element]
        else:
            # Read-modify-write so a possibly-spilled entry is rewritten.
            members = buckets[bucket_key]
            members.append(element)
            buckets[bucket_key] = members

    # Convert to a plain in-memory dict for the caller.
    return dict(buckets.items())
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_based_groupby(
    data: List[T],
    key_func: Callable[[T], K],
    storage_path: str
) -> Dict[K, List[T]]:
    """
    Sort-based grouping: sort by key, then slice consecutive equal-key runs.
    """
    from sqrtspace_spacetime.algorithms.external_sort import external_sort_key

    # After the external sort, equal keys sit next to each other.
    ordered = external_sort_key(data, key=key_func, storage_path=storage_path)

    grouped: Dict[K, List[T]] = {}
    run_key = None
    run: List[T] = []

    for element in ordered:
        element_key = key_func(element)
        if element_key != run_key:
            # Key changed: flush the finished run and start a new one.
            # NOTE: a literal None key never gets flushed (None is the
            # "no run yet" sentinel) — same behavior as before; in
            # practice None keys cannot survive the sort step anyway.
            if run_key is not None:
                grouped[run_key] = run
            run_key = element_key
            run = [element]
        else:
            run.append(element)

    # Flush the trailing run.
    if run_key is not None:
        grouped[run_key] = run

    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for common aggregations
|
||||||
|
|
||||||
|
def groupby_count(
    data: Iterable[T],
    key_func: Callable[[T], K]
) -> Dict[K, int]:
    """Count items by group."""
    # Each item contributes 1; the aggregate is a running sum from 0.
    return external_groupby_aggregate(
        data,
        key_func,
        lambda _item: 1,
        lambda total, one: total + one,
        initial=0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def groupby_sum(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], float]
) -> Dict[K, float]:
    """Sum values by group."""
    # Running float sum per key, seeded at 0.0.
    return external_groupby_aggregate(
        data,
        key_func,
        value_func,
        lambda acc, val: acc + val,
        initial=0.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def groupby_avg(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], float]
) -> Dict[K, float]:
    """Average values by group."""
    # Accumulate running (sum, count) per key in a single pass.
    totals = defaultdict(float)
    tallies = defaultdict(int)

    for item in data:
        group_key = key_func(item)
        totals[group_key] += value_func(item)
        tallies[group_key] += 1

    # Divide sums by counts to get the per-group mean.
    return {group_key: total / tallies[group_key]
            for group_key, total in totals.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def groupby_max(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V]
) -> Dict[K, V]:
    """Get maximum value by group."""
    # The builtin max is the aggregate; no initial value is needed since
    # each group's first value seeds its running maximum.
    return external_groupby_aggregate(data, key_func, value_func, max)
|
||||||
|
|
||||||
|
|
||||||
|
def groupby_min(
    data: Iterable[T],
    key_func: Callable[[T], K],
    value_func: Callable[[T], V]
) -> Dict[K, V]:
    """Get minimum value by group."""
    # The builtin min is the aggregate; no initial value is needed since
    # each group's first value seeds its running minimum.
    return external_groupby_aggregate(data, key_func, value_func, min)
|
||||||
330
src/sqrtspace_spacetime/algorithms/external_sort.py
Normal file
330
src/sqrtspace_spacetime/algorithms/external_sort.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
"""
|
||||||
|
External sorting algorithm using √n memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import heapq
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.memory import monitor
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
class SortStrategy(Enum):
    """Sorting strategies."""
    # Write sorted √n-sized runs to disk, then k-way merge them.
    MULTIWAY_MERGE = "multiway_merge"
    # Recursive three-way partition; sort partitions in memory once small.
    QUICKSORT_EXTERNAL = "quicksort_external"
    # Choose between the two based on input size.
    ADAPTIVE = "adaptive"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SortRun:
    """A sorted run on disk."""
    filename: str   # pickle file holding the sorted chunk
    count: int      # number of items in the run
    min_value: Any  # sort key of the chunk's first item (last if reverse)
    max_value: Any  # sort key of the chunk's last item (first if reverse)
|
||||||
|
|
||||||
|
|
||||||
|
def external_sort(
    data: Iterable[T],
    reverse: bool = False,
    strategy: SortStrategy = SortStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> List[T]:
    """
    Sort data using external memory with √n space complexity.

    Thin wrapper over external_sort_key with the identity key.

    Args:
        data: Iterable of items to sort
        reverse: Sort in descending order
        strategy: Sorting strategy to use
        storage_path: Path for temporary files

    Returns:
        Sorted list
    """
    def _identity(element):
        return element

    return external_sort_key(
        data,
        key=_identity,
        reverse=reverse,
        strategy=strategy,
        storage_path=storage_path,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def external_sort_key(
    data: Iterable[T],
    key: Callable[[T], Any],
    reverse: bool = False,
    strategy: SortStrategy = SortStrategy.ADAPTIVE,
    storage_path: Optional[str] = None
) -> List[T]:
    """
    Sort data by key using external memory.

    Args:
        data: Iterable of items to sort
        key: Function to extract sort key
        reverse: Sort in descending order
        strategy: Sorting strategy to use
        storage_path: Path for temporary files

    Returns:
        Sorted list
    """
    storage_path = storage_path or config.external_storage_path

    # Materialize the input once so its size is known.
    items = data if isinstance(data, list) else list(data)
    n = len(items)

    # Small inputs: the builtin sort is both faster and simpler.
    if n <= 10000:
        return sorted(items, key=key, reverse=reverse)

    # Resolve ADAPTIVE to a concrete strategy before dispatching.
    chosen = _choose_strategy(n) if strategy == SortStrategy.ADAPTIVE else strategy

    if chosen == SortStrategy.MULTIWAY_MERGE:
        return _multiway_merge_sort(items, key, reverse, storage_path)
    return _external_quicksort(items, key, reverse, storage_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _choose_strategy(n: int) -> SortStrategy:
    """Choose best strategy based on data size."""
    # Beyond ~1M items, the merge-based approach is the more stable choice.
    return (
        SortStrategy.MULTIWAY_MERGE
        if n > 1_000_000
        else SortStrategy.QUICKSORT_EXTERNAL
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _multiway_merge_sort(
    data: List[T],
    key: Callable[[T], Any],
    reverse: bool,
    storage_path: str
) -> List[T]:
    """
    Multiway merge sort implementation.

    Phase 1 sorts chunk-sized slices in memory and pickles each to its own
    temporary ".run" file; phase 2 merges the runs back into a single list.
    Temporary files are removed even if the merge raises.

    Args:
        data: Fully materialized list of items to sort.
        key: Sort-key extractor.
        reverse: Sort in descending order.
        storage_path: Directory for temporary run files.

    Returns:
        Sorted list.
    """
    n = len(data)
    chunk_size = config.calculate_chunk_size(n)

    # Phase 1: Create sorted runs
    runs = []
    temp_files = []

    for i in range(0, n, chunk_size):
        # Slicing copies, so sorting the chunk does not mutate `data`.
        chunk = data[i:i + chunk_size]

        # Sort chunk in memory
        chunk.sort(key=key, reverse=reverse)

        # Write to disk; mkstemp guarantees a unique, private filename.
        fd, filename = tempfile.mkstemp(suffix='.run', dir=storage_path)
        os.close(fd)
        temp_files.append(filename)

        with open(filename, 'wb') as f:
            pickle.dump(chunk, f)

        # Track run info (first/last keys of the run as written).
        runs.append(SortRun(
            filename=filename,
            count=len(chunk),
            min_value=key(chunk[0]),
            max_value=key(chunk[-1])
        ))

    # Phase 2: Merge runs
    try:
        result = _merge_runs(runs, key, reverse)
        return result
    finally:
        # Cleanup run files unconditionally.
        for filename in temp_files:
            if os.path.exists(filename):
                os.unlink(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_runs(
|
||||||
|
runs: List[SortRun],
|
||||||
|
key: Callable[[T], Any],
|
||||||
|
reverse: bool
|
||||||
|
) -> List[T]:
|
||||||
|
"""
|
||||||
|
Merge sorted runs using a k-way merge.
|
||||||
|
"""
|
||||||
|
# Open all run files
|
||||||
|
run_iters = []
|
||||||
|
for run in runs:
|
||||||
|
with open(run.filename, 'rb') as f:
|
||||||
|
items = pickle.load(f)
|
||||||
|
run_iters.append(iter(items))
|
||||||
|
|
||||||
|
# Create heap for merge
|
||||||
|
heap = []
|
||||||
|
|
||||||
|
# Initialize heap with first item from each run
|
||||||
|
for i, run_iter in enumerate(run_iters):
|
||||||
|
try:
|
||||||
|
item = next(run_iter)
|
||||||
|
# For reverse sort, negate the key
|
||||||
|
heap_key = key(item)
|
||||||
|
if reverse:
|
||||||
|
heap_key = _negate_key(heap_key)
|
||||||
|
heapq.heappush(heap, (heap_key, i, item, run_iter))
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Merge
|
||||||
|
result = []
|
||||||
|
while heap:
|
||||||
|
heap_key, run_idx, item, run_iter = heapq.heappop(heap)
|
||||||
|
result.append(item)
|
||||||
|
|
||||||
|
# Get next item from same run
|
||||||
|
try:
|
||||||
|
next_item = next(run_iter)
|
||||||
|
next_key = key(next_item)
|
||||||
|
if reverse:
|
||||||
|
next_key = _negate_key(next_key)
|
||||||
|
heapq.heappush(heap, (next_key, run_idx, next_item, run_iter))
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _negate_key(key: Any) -> Any:
|
||||||
|
"""Negate a key for reverse sorting."""
|
||||||
|
if isinstance(key, (int, float)):
|
||||||
|
return -key
|
||||||
|
elif isinstance(key, str):
|
||||||
|
# For strings, return a wrapper that reverses comparison
|
||||||
|
return _ReverseString(key)
|
||||||
|
else:
|
||||||
|
# For other types, use a generic wrapper
|
||||||
|
return _ReverseWrapper(key)
|
||||||
|
|
||||||
|
|
||||||
|
class _ReverseString:
|
||||||
|
"""Wrapper for reverse string comparison."""
|
||||||
|
def __init__(self, s: str):
|
||||||
|
self.s = s
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.s > other.s
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
return self.s >= other.s
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
return self.s < other.s
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
return self.s <= other.s
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.s == other.s
|
||||||
|
|
||||||
|
|
||||||
|
class _ReverseWrapper:
|
||||||
|
"""Generic wrapper for reverse comparison."""
|
||||||
|
def __init__(self, obj):
|
||||||
|
self.obj = obj
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.obj > other.obj
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
return self.obj >= other.obj
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
return self.obj < other.obj
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
return self.obj <= other.obj
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.obj == other.obj
|
||||||
|
|
||||||
|
|
||||||
|
def _external_quicksort(
    data: List[T],
    key: Callable[[T], Any],
    reverse: bool,
    storage_path: str
) -> List[T]:
    """
    External quicksort implementation.

    Partitions around a median-of-three pivot and recurses until a
    partition fits within the configured chunk size, at which point it
    is sorted in memory.
    """
    count = len(data)
    limit = config.calculate_chunk_size(count)

    # Base case: the whole partition fits in memory.
    if count <= limit:
        return sorted(data, key=key, reverse=reverse)

    # Median-of-three pivot key.
    pivot_key = key(data[_choose_pivot(data, key)])

    # Three-way partition around the pivot key (comparison order kept:
    # "<" first, then "==", everything else goes above).
    below: List[T] = []
    matching: List[T] = []
    above: List[T] = []
    for element in data:
        element_key = key(element)
        if element_key < pivot_key:
            below.append(element)
        elif element_key == pivot_key:
            matching.append(element)
        else:
            above.append(element)

    # Recurse into the strict partitions only.
    sorted_below = _external_quicksort(below, key, reverse, storage_path)
    sorted_above = _external_quicksort(above, key, reverse, storage_path)

    # Assemble in the requested direction.
    if reverse:
        return sorted_above + matching + sorted_below
    return sorted_below + matching + sorted_above
|
||||||
|
|
||||||
|
|
||||||
|
def _choose_pivot(data: List[T], key: Callable[[T], Any]) -> int:
|
||||||
|
"""Choose a good pivot using median-of-three."""
|
||||||
|
n = len(data)
|
||||||
|
|
||||||
|
# Sample three elements
|
||||||
|
first = 0
|
||||||
|
middle = n // 2
|
||||||
|
last = n - 1
|
||||||
|
|
||||||
|
# Find median
|
||||||
|
a, b, c = key(data[first]), key(data[middle]), key(data[last])
|
||||||
|
|
||||||
|
if a <= b <= c or c <= b <= a:
|
||||||
|
return middle
|
||||||
|
elif b <= a <= c or c <= a <= b:
|
||||||
|
return first
|
||||||
|
else:
|
||||||
|
return last
|
||||||
7
src/sqrtspace_spacetime/checkpoint/__init__.py
Normal file
7
src/sqrtspace_spacetime/checkpoint/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""Auto-checkpoint framework for long-running computations."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.checkpoint.decorators import auto_checkpoint
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"auto_checkpoint",
|
||||||
|
]
|
||||||
295
src/sqrtspace_spacetime/checkpoint/decorators.py
Normal file
295
src/sqrtspace_spacetime/checkpoint/decorators.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
"""
|
||||||
|
Decorators for automatic checkpointing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.checkpoint.manager import (
|
||||||
|
CheckpointManager,
|
||||||
|
CheckpointConfig,
|
||||||
|
CheckpointStrategy
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_checkpoint(
    total_iterations: Optional[int] = None,
    strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE,
    checkpoint_vars: Optional[List[str]] = None,
    checkpoint_dir: str = ".checkpoints",
    verbose: bool = True
) -> Callable:
    """
    Decorator to automatically checkpoint long-running functions.

    Args:
        total_iterations: Total iterations (for √n strategy)
        strategy: Checkpointing strategy
        checkpoint_vars: Variables to checkpoint (None for auto-detect)
        checkpoint_dir: Directory for checkpoints
        verbose: Print checkpoint info

    Example:
        @auto_checkpoint(total_iterations=1000000)
        def process_data(data):
            for i, item in enumerate(data):
                # Process item
                checkpoint_state = {'i': i, 'processed': processed}
                yield checkpoint_state
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A fresh manager per call keeps checkpoint state call-local.
            config = CheckpointConfig(
                strategy=strategy,
                checkpoint_dir=checkpoint_dir,
                verbose=verbose
            )
            manager = CheckpointManager(config=config)

            if total_iterations:
                manager.set_total_iterations(total_iterations)

            # `resume_checkpoint` is consumed here so it never reaches func.
            resume_checkpoint = kwargs.pop('resume_checkpoint', None)
            if resume_checkpoint:
                state, metadata = manager.load(resume_checkpoint)
                print(f"Resuming from checkpoint at iteration {metadata.iteration}")
                # Update function state
                # NOTE(review): expects the caller to pass an `update_state`
                # callable in kwargs, which is then also forwarded to func —
                # confirm decorated functions accept it.
                if 'update_state' in kwargs:
                    kwargs['update_state'](state)

            # Generator functions are wrapped so each yield becomes a
            # potential checkpoint boundary.
            if inspect.isgeneratorfunction(func):
                return _checkpoint_generator(func, manager, checkpoint_vars,
                                             *args, **kwargs)
            else:
                # For regular functions, checkpoint based on time/memory
                # NOTE(review): this loop calls func once per iteration —
                # with total_iterations set, func executes that many times,
                # which looks unintended; verify against callers before use.
                result = None
                for i in range(total_iterations or 1):
                    if manager.should_checkpoint(i):
                        # Get state from function
                        if hasattr(func, 'get_checkpoint_state'):
                            state = func.get_checkpoint_state()
                        else:
                            state = {'iteration': i, 'args': args, 'kwargs': kwargs}

                        manager.save(state)

                    # Execute function
                    result = func(*args, **kwargs)

                    # Break if function doesn't need iterations
                    if total_iterations is None:
                        break

                return result

        # Store checkpoint info on function
        # NOTE(review): checkpoint_manager stays None; the per-call manager
        # created above is never exposed here — confirm intent.
        wrapper.checkpoint_manager = None
        wrapper.checkpoint_config = CheckpointConfig(
            strategy=strategy,
            checkpoint_dir=checkpoint_dir
        )

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_method(
    checkpoint_attrs: Optional[List[str]] = None,
    strategy: CheckpointStrategy = CheckpointStrategy.ADAPTIVE
) -> Callable:
    """
    Decorator for checkpointing class methods.

    Args:
        checkpoint_attrs: Instance attributes to checkpoint
        strategy: Checkpointing strategy

    Example:
        class DataProcessor:
            @checkpoint_method(checkpoint_attrs=['processed_count', 'results'])
            def process_batch(self, batch):
                for item in batch:
                    self.process_item(item)
                    self.processed_count += 1
    """
    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            # Lazily attach one manager per instance, reused across calls.
            if not hasattr(self, '_checkpoint_manager'):
                config = CheckpointConfig(strategy=strategy)
                self._checkpoint_manager = CheckpointManager(config=config)

            # Generator methods checkpoint at yield boundaries.
            if inspect.isgeneratorfunction(method):
                return _checkpoint_method_generator(
                    method, self, self._checkpoint_manager,
                    checkpoint_attrs, *args, **kwargs
                )
            else:
                # Regular method: run to completion, then snapshot the
                # instance state if the manager says a checkpoint is due.
                result = method(self, *args, **kwargs)

                if self._checkpoint_manager.should_checkpoint():
                    state = _get_instance_state(self, checkpoint_attrs)
                    self._checkpoint_manager.save(state)

                return result

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def resumable(
    checkpoint_dir: str = ".checkpoints",
    auto_resume: bool = True
) -> Callable:
    """
    Make function resumable from checkpoints.

    Args:
        checkpoint_dir: Directory for checkpoints
        auto_resume: Automatically resume from latest checkpoint

    Example:
        @resumable()
        def long_computation():
            for i in range(1000000):
                # Computation
                if should_checkpoint(i):
                    save_checkpoint({'i': i, 'state': state})
    """
    def decorator(func: Callable) -> Callable:
        # Fix: the manager is created once at decoration time so that both
        # wrapper() and the helper lambdas attached below share it.
        # Previously it was created inside wrapper(), so calling
        # wrapper.save_checkpoint / list_checkpoints / cleanup_checkpoints
        # raised NameError (the lambdas closed over a name that only
        # existed in wrapper's local scope).
        manager = CheckpointManager(
            checkpoint_id=f"{func.__module__}.{func.__name__}",
            config=CheckpointConfig(checkpoint_dir=checkpoint_dir)
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Check for existing checkpoints
            checkpoints = manager.list_checkpoints()

            if checkpoints and auto_resume:
                latest = checkpoints[-1]
                print(f"Found checkpoint at iteration {latest.iteration}")

                # Resume from the latest checkpoint.
                state, metadata = manager.load()

                # NOTE(review): func must accept resume_state/resume_iteration
                # kwargs on the resume path — decorated functions lacking
                # them will raise TypeError here.
                return func(*args, resume_state=state,
                            resume_iteration=metadata.iteration, **kwargs)
            else:
                # Normal execution
                return func(*args, **kwargs)

        # Expose checkpoint helpers on the wrapper itself.
        wrapper.save_checkpoint = lambda state: manager.save(state)
        wrapper.list_checkpoints = lambda: manager.list_checkpoints()
        wrapper.cleanup_checkpoints = lambda: manager.cleanup()

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_generator(func: Callable, manager: CheckpointManager,
                          checkpoint_vars: Optional[List[str]],
                          *args, **kwargs):
    """Drive a generator function, saving checkpoints as it yields.

    Passes every value produced by ``func`` through unchanged; at
    iterations the manager flags, the yielded value (if a dict) or a
    small wrapper dict is persisted as the checkpoint state.
    """
    generator = func(*args, **kwargs)
    iteration = 0

    try:
        while True:
            # NOTE(review): `resume_state` is read from kwargs here, but
            # auto_checkpoint never forwards such a kwarg — this resume
            # path appears unreachable; confirm before relying on it.
            if iteration == 0 and 'resume_state' in kwargs:
                # Skip to resume point by draining already-done iterations.
                resume_iter = kwargs['resume_state'].get('iteration', 0)
                for _ in range(resume_iter):
                    next(generator)
                iteration = resume_iter

            value = next(generator)

            # Check if checkpoint needed
            if manager.should_checkpoint(iteration):
                # A dict yielded by the generator is taken as the state.
                if isinstance(value, dict):
                    state = value
                else:
                    state = {'iteration': iteration, 'value': value}

                # NOTE(review): f_back grabs locals from the *caller's*
                # frame — the named variables must live in whatever code
                # iterates this generator; fragile, verify.
                if checkpoint_vars:
                    frame = inspect.currentframe().f_back
                    for var in checkpoint_vars:
                        if var in frame.f_locals:
                            state[var] = frame.f_locals[var]

                manager.save(state)

            yield value
            iteration += 1

    except StopIteration:
        # Underlying generator exhausted: finish quietly.
        pass
    finally:
        if manager.config.verbose:
            stats = manager.get_stats()
            print(f"\nCheckpoint stats: {stats.total_checkpoints} checkpoints, "
                  f"{stats.average_compression:.1f}x compression")
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_method_generator(method: Callable, instance: Any,
                                 manager: CheckpointManager,
                                 checkpoint_attrs: Optional[List[str]],
                                 *args, **kwargs):
    """Drive a generator method, saving instance state at checkpoints."""
    inner = method(instance, *args, **kwargs)

    # enumerate replaces the manual iteration counter; the for-loop
    # absorbs StopIteration exactly like the old try/except did.
    for iteration, value in enumerate(inner):
        if manager.should_checkpoint(iteration):
            snapshot = _get_instance_state(instance, checkpoint_attrs)
            snapshot['iteration'] = iteration
            manager.save(snapshot)
        yield value
|
||||||
|
|
||||||
|
|
||||||
|
def _get_instance_state(instance: Any, attrs: Optional[List[str]] = None) -> dict:
|
||||||
|
"""Extract state from instance."""
|
||||||
|
if attrs:
|
||||||
|
return {attr: getattr(instance, attr, None) for attr in attrs}
|
||||||
|
else:
|
||||||
|
# Auto-detect state (exclude private and callable)
|
||||||
|
state = {}
|
||||||
|
for attr in dir(instance):
|
||||||
|
if not attr.startswith('_') and hasattr(instance, attr):
|
||||||
|
value = getattr(instance, attr)
|
||||||
|
if not callable(value):
|
||||||
|
try:
|
||||||
|
# Test if pickleable
|
||||||
|
import pickle
|
||||||
|
pickle.dumps(value)
|
||||||
|
state[attr] = value
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return state
|
||||||
431
src/sqrtspace_spacetime/checkpoint/manager.py
Normal file
431
src/sqrtspace_spacetime/checkpoint/manager.py
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
"""
|
||||||
|
Checkpoint manager for saving and restoring computation state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
import pickle
import time
import uuid
import zlib
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import psutil

from sqrtspace_spacetime.config import config
from sqrtspace_spacetime.memory import monitor
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointStrategy(Enum):
    """Policies that decide when a checkpoint should be written."""
    SQRT_N = "sqrt_n"  # Checkpoint every √n iterations (Williams-style interval)
    MEMORY_PRESSURE = "memory_pressure"  # Checkpoint when memory exceeds threshold
    TIME_BASED = "time_based"  # Checkpoint every k seconds
    ADAPTIVE = "adaptive"  # Any of the above conditions triggers a checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CheckpointConfig:
    """Configuration for checkpointing.

    Tunables for the checkpoint manager: which strategy triggers a save,
    where checkpoint files live on disk, and how they are compressed and
    retained.
    """
    strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N  # trigger policy
    checkpoint_dir: str = ".checkpoints"  # root directory for checkpoint files
    compression: bool = True  # zlib-compress pickled state before writing
    compression_level: int = 6  # zlib level 0-9; 6 trades speed vs. size
    memory_threshold: float = 0.8  # Fraction of available memory
    time_interval: float = 60.0  # Seconds between checkpoints
    min_interval: int = 100  # Minimum iterations between checkpoints
    max_checkpoints: int = 10  # Maximum concurrent checkpoints on disk
    enable_recovery: bool = True  # keep metadata needed to resume a run
    verbose: bool = False  # print checkpoint activity to stdout
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CheckpointMetadata:
    """Metadata for a checkpoint, stored as a JSON sidecar next to the
    pickled state file."""
    checkpoint_id: str  # path of the .pkl file holding the state
    iteration: int  # iteration at which the checkpoint was taken
    timestamp: float  # wall-clock time of the save (time.time())
    state_size: int  # pickled size in bytes, before compression
    compressed_size: int  # bytes actually written to disk
    compression_ratio: float  # state_size / compressed_size (1.0 if uncompressed)
    strategy_used: str  # CheckpointStrategy value active at save time
    reason: str  # human-readable description of what triggered the save
    state_vars: List[str]  # keys of the saved state dict
    performance_impact: Dict[str, float]  # timings, e.g. 'save_time' in seconds

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CheckpointStats:
    """Aggregate statistics about checkpointing performance."""
    total_checkpoints: int = 0  # number of checkpoints written
    total_time: float = 0.0  # cumulative seconds spent saving
    total_size: int = 0  # cumulative uncompressed bytes
    compressed_size: int = 0  # cumulative bytes written to disk
    average_compression: float = 0.0  # total_size / compressed_size
    memory_saved: int = 0  # estimated bytes freed by checkpointing
    overhead_percent: float = 0.0  # checkpoint time as a % of run time
    recoveries: int = 0  # number of successful load() calls
    # default_factory gives each instance its own dict; the previous
    # `= None` default relied on __post_init__ and left the annotation
    # lying (declared Dict, actually Optional[Dict]).
    strategy_distribution: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        # Accept an explicit None for backward compatibility with
        # callers that passed it under the old default.
        if self.strategy_distribution is None:
            self.strategy_distribution = {}
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointManager:
    """
    Manage checkpoints for long-running computations.

    Implements Williams' √n checkpoint intervals for optimal space-time
    tradeoff: saving every √n iterations bounds both recovery cost and
    the number of checkpoints by O(√n).

    Checkpoints are pickled (optionally zlib-compressed) state dicts,
    each with a JSON metadata sidecar, stored under
    ``<checkpoint_dir>/<checkpoint_id>/``.
    """

    def __init__(self,
                 checkpoint_id: Optional[str] = None,
                 config: Optional[CheckpointConfig] = None):
        """
        Initialize checkpoint manager.

        Args:
            checkpoint_id: Unique ID for this computation (random UUID if omitted)
            config: Checkpoint configuration (defaults if omitted)
        """
        self.checkpoint_id = checkpoint_id or str(uuid.uuid4())
        self.config = config or CheckpointConfig()
        self.stats = CheckpointStats()

        # Create checkpoint directory
        self.checkpoint_path = Path(self.config.checkpoint_dir) / self.checkpoint_id
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)

        # Tracking
        self._iteration = 0
        self._last_checkpoint_iter = 0
        self._last_checkpoint_time = time.time()
        self._checkpoint_interval = None  # iterations between checkpoints (√n)
        self._total_iterations = None     # set via set_total_iterations()
        self._start_time = time.time()    # run start, used for overhead stats

    def should_checkpoint(self, iteration: Optional[int] = None) -> bool:
        """
        Determine if checkpoint is needed.

        Args:
            iteration: Current iteration (None to use internal counter)

        Returns:
            True if checkpoint should be created
        """
        if iteration is not None:
            self._iteration = iteration
        else:
            self._iteration += 1

        # Dispatch on the configured strategy
        if self.config.strategy == CheckpointStrategy.SQRT_N:
            return self._should_checkpoint_sqrt_n()
        elif self.config.strategy == CheckpointStrategy.MEMORY_PRESSURE:
            return self._should_checkpoint_memory()
        elif self.config.strategy == CheckpointStrategy.TIME_BASED:
            return self._should_checkpoint_time()
        elif self.config.strategy == CheckpointStrategy.ADAPTIVE:
            return self._should_checkpoint_adaptive()

        return False

    def _should_checkpoint_sqrt_n(self) -> bool:
        """Check if checkpoint needed using the √n strategy."""
        if self._checkpoint_interval is None:
            if self._total_iterations:
                # √n interval, but never more often than min_interval
                self._checkpoint_interval = max(
                    self.config.min_interval,
                    int(self._total_iterations ** 0.5)
                )
            else:
                # Total unknown: fall back to the configured minimum
                self._checkpoint_interval = self.config.min_interval

        iterations_since = self._iteration - self._last_checkpoint_iter
        return iterations_since >= self._checkpoint_interval

    def _should_checkpoint_memory(self) -> bool:
        """Check if checkpoint needed due to memory pressure."""
        mem_info = monitor.get_memory_info()
        # memory_threshold is a fraction; mem_info.percent is 0-100
        return mem_info.percent > self.config.memory_threshold * 100

    def _should_checkpoint_time(self) -> bool:
        """Check if checkpoint needed based on elapsed wall-clock time."""
        elapsed = time.time() - self._last_checkpoint_time
        return elapsed >= self.config.time_interval

    def _should_checkpoint_adaptive(self) -> bool:
        """Adaptive checkpointing: checkpoint if ANY sub-strategy fires."""
        sqrt_n = self._should_checkpoint_sqrt_n()
        memory = self._should_checkpoint_memory()
        time_based = self._should_checkpoint_time()

        return sqrt_n or memory or time_based

    def save(self, state: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save checkpoint.

        Args:
            state: State dictionary to save
            metadata: Additional metadata (currently unused; accepted for
                forward compatibility)

        Returns:
            Checkpoint ID (path of the written .pkl file)
        """
        start_time = time.time()

        # One file per iteration under this manager's directory
        checkpoint_file = self.checkpoint_path / f"checkpoint_{self._iteration}.pkl"

        # Serialize state
        state_bytes = pickle.dumps(state)
        original_size = len(state_bytes)

        # Compress if enabled (and time it — previously a TODO)
        compression_start = time.time()
        if self.config.compression:
            state_bytes = zlib.compress(state_bytes, self.config.compression_level)
            compressed_size = len(state_bytes)
            compression_ratio = original_size / compressed_size if compressed_size else 1.0
            compression_time = time.time() - compression_start
        else:
            compressed_size = original_size
            compression_ratio = 1.0
            compression_time = 0.0

        # Save checkpoint
        with open(checkpoint_file, 'wb') as f:
            f.write(state_bytes)

        # Save metadata sidecar
        checkpoint_metadata = CheckpointMetadata(
            checkpoint_id=str(checkpoint_file),
            iteration=self._iteration,
            timestamp=time.time(),
            state_size=original_size,
            compressed_size=compressed_size,
            compression_ratio=compression_ratio,
            strategy_used=self.config.strategy.value,
            reason=self._get_checkpoint_reason(),
            state_vars=list(state.keys()),
            performance_impact={
                'save_time': time.time() - start_time,
                'compression_time': compression_time,
            }
        )

        metadata_file = checkpoint_file.with_suffix('.json')
        with open(metadata_file, 'w') as f:
            json.dump(checkpoint_metadata.to_dict(), f, indent=2)

        # Update stats
        self._update_stats(checkpoint_metadata)

        # Update tracking
        self._last_checkpoint_iter = self._iteration
        self._last_checkpoint_time = time.time()

        # Clean old checkpoints
        self._cleanup_old_checkpoints()

        if self.config.verbose:
            print(f"Checkpoint saved: iteration {self._iteration}, "
                  f"size {compressed_size / 1024:.1f}KB, "
                  f"compression {compression_ratio:.1f}x")

        return str(checkpoint_file)

    def load(self, checkpoint_id: Optional[str] = None) -> Tuple[Dict[str, Any], CheckpointMetadata]:
        """
        Load checkpoint.

        Args:
            checkpoint_id: Specific checkpoint to load (None for latest)

        Returns:
            Tuple of (state, metadata)

        Raises:
            ValueError: If no checkpoints exist for this manager.
        """
        if checkpoint_id:
            checkpoint_file = Path(checkpoint_id)
        else:
            # Find latest checkpoint by modification time
            checkpoints = list(self.checkpoint_path.glob("checkpoint_*.pkl"))
            if not checkpoints:
                raise ValueError("No checkpoints found")

            checkpoint_file = max(checkpoints, key=lambda p: p.stat().st_mtime)

        # Load metadata
        metadata_file = checkpoint_file.with_suffix('.json')
        with open(metadata_file, 'r') as f:
            metadata_dict = json.load(f)
        metadata = CheckpointMetadata(**metadata_dict)

        # Load state
        with open(checkpoint_file, 'rb') as f:
            state_bytes = f.read()

        # Decompress if needed.  Decide from the data rather than the
        # *current* config so checkpoints written with a different
        # compression setting still load (pickled data is never a valid
        # zlib stream, so the fallback is unambiguous).
        try:
            state_bytes = zlib.decompress(state_bytes)
        except zlib.error:
            pass  # stored uncompressed

        state = pickle.loads(state_bytes)

        # Update stats
        self.stats.recoveries += 1

        if self.config.verbose:
            print(f"Checkpoint loaded: iteration {metadata.iteration}")

        return state, metadata

    def list_checkpoints(self) -> List[CheckpointMetadata]:
        """List all available checkpoints, ordered by iteration."""
        metadata_files = self.checkpoint_path.glob("checkpoint_*.json")
        checkpoints = []

        for metadata_file in metadata_files:
            with open(metadata_file, 'r') as f:
                metadata_dict = json.load(f)
            checkpoints.append(CheckpointMetadata(**metadata_dict))

        return sorted(checkpoints, key=lambda c: c.iteration)

    def delete_checkpoint(self, checkpoint_id: str) -> None:
        """Delete a specific checkpoint and its metadata sidecar."""
        checkpoint_file = Path(checkpoint_id)
        metadata_file = checkpoint_file.with_suffix('.json')

        if checkpoint_file.exists():
            checkpoint_file.unlink()
        if metadata_file.exists():
            metadata_file.unlink()

    def cleanup(self) -> None:
        """Remove this manager's entire checkpoint directory."""
        import shutil
        if self.checkpoint_path.exists():
            shutil.rmtree(self.checkpoint_path)

    def set_total_iterations(self, total: int) -> None:
        """
        Set total iterations for optimal √n calculation.

        Args:
            total: Total number of iterations
        """
        self._total_iterations = total
        self._checkpoint_interval = max(
            self.config.min_interval,
            int(total ** 0.5)
        )

        if self.config.verbose:
            print(f"Checkpoint interval set to {self._checkpoint_interval} "
                  f"(√{total} strategy)")

    def get_stats(self) -> CheckpointStats:
        """Get checkpoint statistics (derived fields are refreshed)."""
        if self.stats.total_checkpoints > 0:
            if self.stats.compressed_size > 0:
                self.stats.average_compression = (
                    self.stats.total_size / self.stats.compressed_size
                )
            # Overhead relative to the whole run.  The previous version
            # divided by time-since-last-checkpoint, which is ~0 right
            # after a save (ZeroDivisionError) and wrong besides.
            elapsed = time.time() - self._start_time
            if elapsed > 0:
                self.stats.overhead_percent = (
                    self.stats.total_time / elapsed * 100
                )

        return self.stats

    def _get_checkpoint_reason(self) -> str:
        """Human-readable description of why a checkpoint fired."""
        if self.config.strategy == CheckpointStrategy.SQRT_N:
            return f"√n interval reached ({self._checkpoint_interval} iterations)"
        elif self.config.strategy == CheckpointStrategy.MEMORY_PRESSURE:
            mem_info = monitor.get_memory_info()
            return f"Memory pressure: {mem_info.percent:.1f}%"
        elif self.config.strategy == CheckpointStrategy.TIME_BASED:
            return f"Time interval: {self.config.time_interval}s"
        else:
            return "Adaptive strategy triggered"

    def _update_stats(self, metadata: CheckpointMetadata) -> None:
        """Fold one checkpoint's metadata into the running statistics."""
        self.stats.total_checkpoints += 1
        self.stats.total_time += metadata.performance_impact['save_time']
        self.stats.total_size += metadata.state_size
        self.stats.compressed_size += metadata.compressed_size

        # Update strategy distribution
        strategy = metadata.strategy_used
        self.stats.strategy_distribution[strategy] = (
            self.stats.strategy_distribution.get(strategy, 0) + 1
        )

    def _cleanup_old_checkpoints(self) -> None:
        """Remove oldest checkpoints to stay under max_checkpoints."""
        checkpoints = list(self.checkpoint_path.glob("checkpoint_*.pkl"))

        if len(checkpoints) > self.config.max_checkpoints:
            # Oldest first by modification time
            checkpoints.sort(key=lambda p: p.stat().st_mtime)

            for checkpoint in checkpoints[:-self.config.max_checkpoints]:
                self.delete_checkpoint(str(checkpoint))

    def create_recovery_code(self, func: Callable) -> str:
        """
        Generate recovery code for function.

        Args:
            func: Function to generate recovery for

        Returns:
            Recovery code as string
        """
        recovery_template = '''
def recover_{func_name}(checkpoint_id=None):
    """Recover {func_name} from checkpoint."""
    manager = CheckpointManager("{checkpoint_id}")

    # Load checkpoint
    state, metadata = manager.load(checkpoint_id)

    # Resume computation
    iteration = metadata.iteration

    # Restore state variables
    {state_restoration}

    # Continue from checkpoint
    # TODO: Add continuation logic

    return state
'''

        # Get function name
        func_name = func.__name__

        # Generate state restoration code from the decorator's
        # declared state vars, when present
        state_vars = []
        if hasattr(func, '_checkpoint_state'):
            state_vars = func._checkpoint_state

        state_restoration = '\n    '.join(
            f"{var} = state.get('{var}')" for var in state_vars
        )

        return recovery_template.format(
            func_name=func_name,
            checkpoint_id=self.checkpoint_id,
            state_restoration=state_restoration
        )
|
||||||
9
src/sqrtspace_spacetime/collections/__init__.py
Normal file
9
src/sqrtspace_spacetime/collections/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""Memory-efficient collections using √n space-time tradeoffs."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.collections.spacetime_array import SpaceTimeArray
|
||||||
|
from sqrtspace_spacetime.collections.spacetime_dict import SpaceTimeDict
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SpaceTimeArray",
|
||||||
|
"SpaceTimeDict",
|
||||||
|
]
|
||||||
273
src/sqrtspace_spacetime/collections/spacetime_array.py
Normal file
273
src/sqrtspace_spacetime/collections/spacetime_array.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
SpaceTimeArray: A memory-efficient array that automatically spills to disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
import weakref
|
||||||
|
from typing import Any, Iterator, Optional, Union, List
|
||||||
|
from collections.abc import MutableSequence
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.memory import monitor, MemoryPressureLevel
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceTimeArray(MutableSequence):
    """
    A list-like container that automatically manages memory usage by
    spilling to disk when threshold is reached.

    Invariant: spilled ("cold") items always occupy the logical index
    prefix ``0..len(_cold_indices)-1``, so a logical index maps into hot
    storage as ``index - len(_cold_indices)``.
    """

    # Live instances, tracked weakly so memory-pressure callbacks can
    # reach every array without keeping any of them alive.
    _instances = weakref.WeakSet()

    def __init__(self, threshold: Optional[Union[int, str]] = None, storage_path: Optional[str] = None):
        """
        Initialize SpaceTimeArray.

        Args:
            threshold: Number of items to keep in memory (None or 'auto' for automatic)
            storage_path: Path for external storage (None for temp)
        """
        if threshold == 'auto' or threshold is None:
            self.threshold = config.calculate_chunk_size(10000)
        else:
            self.threshold = int(threshold)
        self.storage_path = storage_path or config.external_storage_path

        self._hot_data: List[Any] = []          # in-memory suffix of the array
        self._cold_indices: set = set()         # logical indices stored on disk
        self._cold_storage: Optional[str] = None  # path of the spill file
        self._length = 0                        # total logical length
        self._cold_file_handle = None

        # Register for memory pressure handling
        SpaceTimeArray._instances.add(self)

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: Union[int, slice]) -> Any:
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list index out of range")

        # Hot items sit after the cold prefix (see class invariant)
        if index not in self._cold_indices:
            hot_index = index - len(self._cold_indices)
            return self._hot_data[hot_index]

        # Load from cold storage
        return self._load_from_cold(index)

    def __setitem__(self, index: Union[int, slice], value: Any) -> None:
        if isinstance(index, slice):
            for i, v in zip(range(*index.indices(len(self))), value):
                self[i] = v
            return

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list assignment index out of range")

        if index not in self._cold_indices:
            hot_index = index - len(self._cold_indices)
            self._hot_data[hot_index] = value
        else:
            # Update cold storage in place
            self._update_cold(index, value)

    def __delitem__(self, index: Union[int, slice]) -> None:
        if isinstance(index, slice):
            # Delete in reverse order to maintain indices
            for i in reversed(range(*index.indices(len(self)))):
                del self[i]
            return

        if index < 0:
            index += self._length

        if not 0 <= index < self._length:
            raise IndexError("list index out of range")

        # Deleting shifts every later index, which would invalidate the
        # cold-prefix bookkeeping — rebuild instead (O(n), rare path)
        all_data = list(self)
        del all_data[index]
        self.clear()
        self.extend(all_data)

    def insert(self, index: int, value: Any) -> None:
        """Insert *value* before *index* (clamped to the valid range)."""
        if index < 0:
            index += self._length
        index = max(0, min(index, self._length))

        # Same rebuild strategy as __delitem__: correctness over speed
        all_data = list(self)
        all_data.insert(index, value)
        self.clear()
        self.extend(all_data)

    def append(self, value: Any) -> None:
        """Append an item to the array."""
        self._hot_data.append(value)
        self._length += 1

        # Check if we need to spill
        if len(self._hot_data) > self.threshold:
            self._check_and_spill()

    def extend(self, iterable) -> None:
        """Extend array with items from iterable."""
        for item in iterable:
            self.append(item)

    def clear(self) -> None:
        """Remove all items and delete any spill file."""
        self._hot_data.clear()
        self._cold_indices.clear()
        self._length = 0

        if self._cold_storage and os.path.exists(self._cold_storage):
            os.unlink(self._cold_storage)
        self._cold_storage = None

    def __iter__(self) -> Iterator[Any]:
        """Iterate over all items in logical order (cold prefix, then hot)."""
        for idx in sorted(self._cold_indices):
            yield self._load_from_cold(idx)

        for item in self._hot_data:
            yield item

    def _check_and_spill(self) -> None:
        """Spill to disk when memory pressure or the size threshold demands it."""
        pressure = monitor.check_memory_pressure()

        if pressure >= MemoryPressureLevel.MEDIUM or len(self._hot_data) > self.threshold:
            self._spill_to_disk()

    def _spill_to_disk(self) -> None:
        """Spill the older half of hot items (at least one) to disk."""
        if not self._hot_data:
            return  # nothing to spill

        if not self._cold_storage:
            fd, self._cold_storage = tempfile.mkstemp(
                suffix='.spacetime',
                dir=self.storage_path
            )
            os.close(fd)

        # Spill half of the hot data, but always at least one item so
        # repeated calls make progress.  The previous `len // 2` was 0
        # for a single-element hot list, which made spill_to_disk()'s
        # `while self._hot_data` loop spin forever.
        spill_count = max(1, len(self._hot_data) // 2)

        # Load existing cold data (whole-file pickle dict)
        cold_data = {}
        if os.path.exists(self._cold_storage):
            with open(self._cold_storage, 'rb') as f:
                try:
                    cold_data = pickle.load(f)
                except EOFError:
                    cold_data = {}  # freshly created, still empty

        # Move items to cold storage; hot item i has logical index
        # cold_count + i, preserving the cold-prefix invariant
        current_cold_size = len(self._cold_indices)
        for i in range(spill_count):
            cold_data[current_cold_size + i] = self._hot_data[i]
            self._cold_indices.add(current_cold_size + i)

        # Remove from hot storage
        self._hot_data = self._hot_data[spill_count:]

        # Save cold data
        with open(self._cold_storage, 'wb') as f:
            pickle.dump(cold_data, f)

    def _load_from_cold(self, index: int) -> Any:
        """Load an item from cold storage."""
        if not self._cold_storage or not os.path.exists(self._cold_storage):
            raise IndexError(f"Cold storage index {index} not found")

        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        return cold_data.get(index)

    def _update_cold(self, index: int, value: Any) -> None:
        """Rewrite cold storage with *index* updated to *value*."""
        if not self._cold_storage:
            return

        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        cold_data[index] = value

        with open(self._cold_storage, 'wb') as f:
            pickle.dump(cold_data, f)

    def memory_usage(self) -> int:
        """Estimate memory usage in bytes (rough heuristic)."""
        return len(self._hot_data) * 50  # Assume 50 bytes per item average

    def spill_to_disk(self, path: Optional[str] = None) -> None:
        """Force spill all data to disk."""
        if path:
            self.storage_path = path

        while self._hot_data:
            self._spill_to_disk()

    def load_to_memory(self) -> None:
        """Load all data back to memory and remove the spill file."""
        if not self._cold_storage or not self._cold_indices:
            return

        with open(self._cold_storage, 'rb') as f:
            cold_data = pickle.load(f)

        # Rebuild array in logical order
        all_data = []
        hot_count = 0

        for i in range(self._length):
            if i in self._cold_indices:
                all_data.append(cold_data[i])
            else:
                all_data.append(self._hot_data[hot_count])
                hot_count += 1

        # Reset storage
        self._hot_data = all_data
        self._cold_indices.clear()

        if os.path.exists(self._cold_storage):
            os.unlink(self._cold_storage)
        self._cold_storage = None

    def __del__(self):
        """Clean up temporary files."""
        if self._cold_storage and os.path.exists(self._cold_storage):
            try:
                os.unlink(self._cold_storage)
            except OSError:
                # Best-effort cleanup at GC/interpreter shutdown;
                # narrowed from a bare `except:`.
                pass

    @classmethod
    def handle_memory_pressure(cls, level: "MemoryPressureLevel") -> None:
        """Class method to handle memory pressure for all instances."""
        if level >= MemoryPressureLevel.HIGH:
            for instance in cls._instances:
                if instance._hot_data:
                    instance._spill_to_disk()
|
||||||
272
src/sqrtspace_spacetime/collections/spacetime_dict.py
Normal file
272
src/sqrtspace_spacetime/collections/spacetime_dict.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
"""
|
||||||
|
SpaceTimeDict: A memory-efficient dictionary with automatic spillover.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Iterator, Optional, Tuple
|
||||||
|
from collections import OrderedDict
|
||||||
|
from collections.abc import MutableMapping
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.memory import monitor, MemoryPressureLevel
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceTimeDict(MutableMapping):
|
||||||
|
"""
|
||||||
|
A dictionary that automatically manages memory by moving least-recently-used
|
||||||
|
items to disk storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
threshold: Optional[int] = None,
|
||||||
|
storage_path: Optional[str] = None,
|
||||||
|
use_lru: bool = True):
|
||||||
|
"""
|
||||||
|
Initialize SpaceTimeDict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: Number of items to keep in memory
|
||||||
|
storage_path: Path for external storage
|
||||||
|
use_lru: Use LRU eviction policy
|
||||||
|
"""
|
||||||
|
self.threshold = threshold or config.calculate_chunk_size(10000)
|
||||||
|
self.storage_path = storage_path or config.external_storage_path
|
||||||
|
self.use_lru = use_lru
|
||||||
|
|
||||||
|
# Hot storage (in memory)
|
||||||
|
if use_lru:
|
||||||
|
self._hot_data: Dict[Any, Any] = OrderedDict()
|
||||||
|
else:
|
||||||
|
self._hot_data: Dict[Any, Any] = {}
|
||||||
|
|
||||||
|
# Cold storage tracking
|
||||||
|
self._cold_keys: set = set()
|
||||||
|
self._cold_storage: Optional[str] = None
|
||||||
|
self._cold_index: Dict[Any, Tuple[int, int]] = {} # key -> (offset, size)
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
self._hits = 0
|
||||||
|
self._misses = 0
|
||||||
|
self._last_access: Dict[Any, float] = {}
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._hot_data) + len(self._cold_keys)
|
||||||
|
|
||||||
|
def __getitem__(self, key: Any) -> Any:
|
||||||
|
# Check hot storage first
|
||||||
|
if key in self._hot_data:
|
||||||
|
self._hits += 1
|
||||||
|
if self.use_lru:
|
||||||
|
# Move to end (most recent)
|
||||||
|
self._hot_data.move_to_end(key)
|
||||||
|
self._last_access[key] = time.time()
|
||||||
|
return self._hot_data[key]
|
||||||
|
|
||||||
|
# Check cold storage
|
||||||
|
if key in self._cold_keys:
|
||||||
|
self._misses += 1
|
||||||
|
value = self._load_from_cold(key)
|
||||||
|
|
||||||
|
# Promote to hot storage
|
||||||
|
self._promote_to_hot(key, value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
raise KeyError(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: Any, value: Any) -> None:
|
||||||
|
# If key exists in cold storage, remove it
|
||||||
|
if key in self._cold_keys:
|
||||||
|
self._cold_keys.remove(key)
|
||||||
|
# Note: We don't actually remove from file to avoid rewriting
|
||||||
|
|
||||||
|
# Add to hot storage
|
||||||
|
self._hot_data[key] = value
|
||||||
|
self._last_access[key] = time.time()
|
||||||
|
|
||||||
|
# Check if we need to evict
|
||||||
|
if len(self._hot_data) > self.threshold:
|
||||||
|
self._evict_to_cold()
|
||||||
|
|
||||||
|
def __delitem__(self, key: Any) -> None:
|
||||||
|
if key in self._hot_data:
|
||||||
|
del self._hot_data[key]
|
||||||
|
self._last_access.pop(key, None)
|
||||||
|
elif key in self._cold_keys:
|
||||||
|
self._cold_keys.remove(key)
|
||||||
|
self._cold_index.pop(key, None)
|
||||||
|
else:
|
||||||
|
raise KeyError(key)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Any]:
|
||||||
|
# Iterate hot keys first
|
||||||
|
yield from self._hot_data
|
||||||
|
# Then cold keys
|
||||||
|
yield from self._cold_keys
|
||||||
|
|
||||||
|
def __contains__(self, key: Any) -> bool:
|
||||||
|
return key in self._hot_data or key in self._cold_keys
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
"""Return a view of all keys."""
|
||||||
|
return list(self._hot_data.keys()) + list(self._cold_keys)
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
"""Return a view of all values."""
|
||||||
|
for key in self:
|
||||||
|
yield self[key]
|
||||||
|
|
||||||
|
def items(self):
    """Yield (key, value) pairs across both tiers, hot tier first."""
    # Delegates ordering to __iter__ and loading to __getitem__.
    return ((key, self[key]) for key in self)
|
def clear(self) -> None:
    """Remove every item and delete the cold-storage backing file."""
    for store in (self._hot_data, self._cold_keys,
                  self._cold_index, self._last_access):
        store.clear()

    # Delete the backing file (if any) and forget its path.
    if self._cold_storage and os.path.exists(self._cold_storage):
        os.unlink(self._cold_storage)
        self._cold_storage = None
|
def get_stats(self) -> Dict[str, Any]:
    """Return a snapshot of cache statistics (tier sizes, hit rate, memory)."""
    lookups = self._hits + self._misses
    # Guard against division by zero before any lookup has happened.
    rate = self._hits / lookups if lookups > 0 else 0

    return {
        "hot_items": len(self._hot_data),
        "cold_items": len(self._cold_keys),
        "total_items": len(self),
        "hits": self._hits,
        "misses": self._misses,
        "hit_rate": rate,
        "memory_usage": self.memory_usage(),
    }
|
def _evict_to_cold(self) -> None:
    """Evict least recently used items to cold storage."""
    # Evict 25% of the hot tier (at least one item) per pass, so we do
    # not pay file-append overhead on every single insertion.
    evict_count = max(1, len(self._hot_data) // 4)  # Evict 25%

    # Lazily create the backing file on first eviction.
    if not self._cold_storage:
        fd, self._cold_storage = tempfile.mkstemp(
            suffix='.spacetime_dict',
            dir=self.storage_path
        )
        os.close(fd)

    # Select items to evict
    if self.use_lru:
        # OrderedDict: oldest items are first
        evict_keys = list(self._hot_data.keys())[:evict_count]
    else:
        # Use access time: oldest timestamps are evicted first.
        sorted_keys = sorted(
            self._hot_data.keys(),
            key=lambda k: self._last_access.get(k, 0)
        )
        evict_keys = sorted_keys[:evict_count]

    # Write to cold storage. Each record is appended as a
    # [4-byte little-endian size][pickled (key, value)] frame.
    with open(self._cold_storage, 'ab') as f:
        for key in evict_keys:
            value = self._hot_data[key]
            offset = f.tell()

            # Serialize key-value pair
            data = pickle.dumps((key, value))
            size = len(data)

            # Write size header and data
            f.write(size.to_bytes(4, 'little'))
            f.write(data)

            # Update indices: (frame start, total frame length incl. the
            # 4-byte header) lets _load_from_cold seek straight there.
            self._cold_index[key] = (offset, size + 4)
            self._cold_keys.add(key)

            # Remove from hot storage
            # NOTE(review): the _last_access entry for evicted keys is
            # left behind; it is harmless (dict.get default) but grows.
            del self._hot_data[key]
|
def _load_from_cold(self, key: Any) -> Any:
    """Load and unpickle a value from the cold-storage file.

    Args:
        key: Key previously recorded in ``self._cold_index``.

    Returns:
        The stored value.

    Raises:
        KeyError: If *key* has no cold-index entry, or the on-disk frame
            does not contain the requested key (index corruption).
    """
    if key not in self._cold_index:
        raise KeyError(key)

    # The index stores (frame offset, total frame length); only the
    # offset is needed here because each frame carries its own 4-byte
    # little-endian size header.
    offset, size = self._cold_index[key]

    with open(self._cold_storage, 'rb') as f:
        f.seek(offset)
        size_bytes = f.read(4)
        data_size = int.from_bytes(size_bytes, 'little')
        data = f.read(data_size)

    # pickle is only applied to data this process wrote itself.
    stored_key, value = pickle.loads(data)
    if stored_key != key:
        # Was a bare `assert` (stripped under `python -O`); make the
        # corruption check unconditional.
        raise KeyError(key)

    return value
|
def _promote_to_hot(self, key: Any, value: Any) -> None:
    """Promote a cold item to hot storage."""
    # Remove from cold tracking
    # NOTE(review): the _cold_index entry is left behind; the stale disk
    # frame is only reclaimed by compact().
    self._cold_keys.remove(key)

    # Add to hot storage as the most recently used entry.
    self._hot_data[key] = value
    self._last_access[key] = time.time()

    # Check if we need to evict something else: the promotion itself may
    # push the hot tier over its limit.
    if len(self._hot_data) > self.threshold:
        self._evict_to_cold()
|
def memory_usage(self) -> int:
    """Estimate memory usage in bytes (hot tier only, coarse heuristic)."""
    # Cold items live on disk and are intentionally excluded; assume a
    # flat ~100 bytes per hot item on average.
    per_item_estimate = 100
    return per_item_estimate * len(self._hot_data)
|
def compact(self) -> None:
    """Compact cold storage by removing deleted entries.

    Rewrites every still-live cold record into a fresh temp file and
    swaps in the new index; frames belonging to deleted or promoted keys
    are dropped, reclaiming disk space.
    """
    # Nothing to do without a backing file or live cold keys.
    if not self._cold_storage or not self._cold_keys:
        return

    # Create new file
    fd, new_storage = tempfile.mkstemp(
        suffix='.spacetime_dict',
        dir=self.storage_path
    )
    os.close(fd)

    new_index = {}

    # Copy only active entries.
    # NOTE(review): _load_from_cold reopens the old file once per key, so
    # this pass costs O(k) file opens — fine for modest cold sets.
    with open(new_storage, 'wb') as new_f:
        for key in self._cold_keys:
            value = self._load_from_cold(key)
            offset = new_f.tell()

            # Re-serialize as [4-byte little-endian size][pickle] frames,
            # matching the format written by _evict_to_cold.
            data = pickle.dumps((key, value))
            size = len(data)

            new_f.write(size.to_bytes(4, 'little'))
            new_f.write(data)

            new_index[key] = (offset, size + 4)

    # Replace old storage
    # NOTE(review): on Windows unlink fails if the old file is still open
    # elsewhere — assumes single-threaded use; confirm.
    os.unlink(self._cold_storage)
    self._cold_storage = new_storage
    self._cold_index = new_index
|
def __del__(self):
    """Best-effort cleanup of the temporary cold-storage file."""
    # getattr: __del__ can run even when __init__ failed before setting
    # the attribute; never raise from a finalizer.
    path = getattr(self, "_cold_storage", None)
    if path and os.path.exists(path):
        try:
            os.unlink(path)
        except Exception:
            # Deliberate best-effort: during interpreter shutdown (or if
            # the file vanished) cleanup may fail. Was a bare `except:`,
            # which also swallowed SystemExit/KeyboardInterrupt.
            pass
||||||
186
src/sqrtspace_spacetime/config.py
Normal file
186
src/sqrtspace_spacetime/config.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
Configuration management for SpaceTime operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import tempfile
|
||||||
|
from typing import Dict, Any, Optional, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkStrategy(Enum):
    """Strategy for determining chunk sizes."""

    SQRT_N = "sqrt_n"              # sqrt(n)-sized chunks
    MEMORY_BASED = "memory_based"  # sized from currently available RAM
    FIXED = "fixed"                # constant, user-configured size
    ADAPTIVE = "adaptive"          # sqrt(n) adjusted by memory pressure
|
||||||
|
class CompressionType(Enum):
    """Compression algorithms for external storage."""

    NONE = "none"      # store raw bytes, no compression module
    GZIP = "gzip"      # stdlib default
    LZ4 = "lz4"        # optional backend
    ZSTD = "zstd"      # optional backend
    SNAPPY = "snappy"  # optional backend
|
||||||
|
@dataclass
class MemoryHierarchy:
    """Memory hierarchy information.

    Cache sizes default to typical desktop values; RAM and disk capacity
    are detected via psutil at instantiation time.
    """
    l1_cache: int = field(default_factory=lambda: 32 * 1024)  # 32KB
    l2_cache: int = field(default_factory=lambda: 256 * 1024)  # 256KB
    l3_cache: int = field(default_factory=lambda: 8 * 1024 * 1024)  # 8MB
    ram: int = field(default_factory=lambda: psutil.virtual_memory().total)
    disk: int = field(default_factory=lambda: psutil.disk_usage('/').total)

    def get_optimal_buffer_size(self, total_size: int) -> int:
        """Calculate optimal buffer size based on memory hierarchy."""
        # sqrt(n) buffer per the sqrt-space strategy.
        # NOTE(review): total_size appears to be an item count while
        # l3_cache is in bytes — confirm the units of this comparison.
        sqrt_n = int(math.sqrt(total_size))

        # Try to fit in L3 cache
        if sqrt_n <= self.l3_cache:
            return sqrt_n

        # Otherwise use a fraction (10%) of currently available RAM
        available_ram = psutil.virtual_memory().available
        return min(sqrt_n, int(available_ram * 0.1))
|
||||||
|
@dataclass
class SpaceTimeConfig:
    """Global configuration for SpaceTime operations.

    Usually accessed through the ``get_instance()`` singleton; field
    defaults are derived from the host machine via psutil.
    """

    # Memory limits
    memory_limit: int = field(default_factory=lambda: int(psutil.virtual_memory().total * 0.8))
    memory_threshold: float = 0.8  # Trigger spillover at 80% usage

    # Storage
    external_storage_path: str = field(default_factory=lambda: os.path.join(tempfile.gettempdir(), "spacetime"))
    compression: CompressionType = CompressionType.GZIP
    compression_level: int = 6

    # Chunking
    chunk_strategy: ChunkStrategy = ChunkStrategy.SQRT_N
    fixed_chunk_size: int = 10000
    min_chunk_size: int = 100
    max_chunk_size: int = 10_000_000

    # Checkpointing
    enable_checkpointing: bool = True
    checkpoint_interval: int = 60  # seconds
    checkpoint_storage: str = "file"  # "file", "redis", "s3"

    # Performance
    enable_profiling: bool = False
    parallel_workers: int = field(default_factory=lambda: min(4, os.cpu_count() or 1))
    prefetch_size: int = 2  # Number of chunks to prefetch

    # Memory hierarchy
    hierarchy: MemoryHierarchy = field(default_factory=MemoryHierarchy)

    # Singleton storage. Deliberately *not* annotated: the previous
    # `_instance: Optional['SpaceTimeConfig'] = None` made it a dataclass
    # field, leaking `_instance` into __init__/__repr__/__eq__. A plain
    # (unannotated) class attribute is ignored by @dataclass.
    _instance = None

    def __post_init__(self):
        """Create the external storage directory if it does not exist."""
        os.makedirs(self.external_storage_path, exist_ok=True)

    @classmethod
    def get_instance(cls) -> 'SpaceTimeConfig':
        """Return the process-wide singleton, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_defaults(cls, **kwargs) -> None:
        """Override configuration values on the singleton.

        Unknown keys are silently ignored (only existing attributes set).
        """
        instance = cls.get_instance()
        for key, value in kwargs.items():
            if hasattr(instance, key):
                setattr(instance, key, value)

    def calculate_chunk_size(self, total_size: int) -> int:
        """Calculate optimal chunk size for *total_size* items per strategy.

        The result is always clamped to [min_chunk_size, max_chunk_size]
        except for the FIXED strategy, which is returned verbatim.
        """
        if self.chunk_strategy == ChunkStrategy.FIXED:
            return self.fixed_chunk_size

        elif self.chunk_strategy == ChunkStrategy.SQRT_N:
            sqrt_n = int(math.sqrt(total_size))
            return max(self.min_chunk_size, min(sqrt_n, self.max_chunk_size))

        elif self.chunk_strategy == ChunkStrategy.MEMORY_BASED:
            available = psutil.virtual_memory().available
            # Use 10% of available memory for chunks (assume 8 bytes/item)
            chunk_size = int(available * 0.1 / 8)
            return max(self.min_chunk_size, min(chunk_size, self.max_chunk_size))

        elif self.chunk_strategy == ChunkStrategy.ADAPTIVE:
            # Start with sqrt(n) and adjust based on memory pressure
            base_size = int(math.sqrt(total_size))
            memory_percent = psutil.virtual_memory().percent

            if memory_percent > 90:
                # Very high pressure: use minimum size
                return self.min_chunk_size
            elif memory_percent > 70:
                # High pressure: reduce chunk size
                return max(self.min_chunk_size, base_size // 2)
            elif memory_percent < 30:
                # Low pressure: increase chunk size
                return min(self.max_chunk_size, base_size * 2)
            else:
                # Normal pressure: use sqrt(n)
                return max(self.min_chunk_size, min(base_size, self.max_chunk_size))

        # Unreachable for the known strategies; safe fallback.
        return self.fixed_chunk_size

    def get_compression_module(self):
        """Return the module implementing the configured compression.

        Optional backends (lz4, zstandard, snappy) fall back to stdlib
        gzip when not installed; returns None for CompressionType.NONE.
        """
        if self.compression == CompressionType.GZIP:
            import gzip
            return gzip
        elif self.compression == CompressionType.LZ4:
            try:
                import lz4.frame
                return lz4.frame
            except ImportError:
                import gzip
                return gzip
        elif self.compression == CompressionType.ZSTD:
            try:
                import zstandard
                return zstandard
            except ImportError:
                import gzip
                return gzip
        elif self.compression == CompressionType.SNAPPY:
            try:
                import snappy
                return snappy
            except ImportError:
                import gzip
                return gzip
        else:
            return None

    def format_bytes(self, bytes: int) -> str:
        """Format a byte count as a human-readable string (e.g. '1.50 MB')."""
        # NOTE(review): the parameter name shadows the builtin `bytes`;
        # kept for backward compatibility with keyword callers.
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if bytes < 1024.0:
                return f"{bytes:.2f} {unit}"
            bytes /= 1024.0
        return f"{bytes:.2f} PB"

    def get_williams_bound(self, time_complexity: int) -> int:
        """Calculate Williams' space bound: SPACE[√(t log t)].

        Returns 1 for non-positive t; log argument is clamped at 2 so the
        bound is always at least √(t).
        """
        if time_complexity <= 0:
            return 1
        return int(math.sqrt(time_complexity * math.log2(max(2, time_complexity))))
|
|
||||||
|
# Global configuration instance
|
||||||
|
config = SpaceTimeConfig.get_instance()
|
||||||
27
src/sqrtspace_spacetime/memory/__init__.py
Normal file
27
src/sqrtspace_spacetime/memory/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""Memory monitoring and pressure handling for SpaceTime."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.memory.monitor import (
|
||||||
|
MemoryMonitor,
|
||||||
|
MemoryPressureLevel,
|
||||||
|
MemoryInfo,
|
||||||
|
MemoryPressureHandler,
|
||||||
|
monitor,
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.memory.handlers import (
|
||||||
|
LoggingHandler,
|
||||||
|
CacheEvictionHandler,
|
||||||
|
GarbageCollectionHandler,
|
||||||
|
ThrottlingHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MemoryMonitor",
|
||||||
|
"MemoryPressureLevel",
|
||||||
|
"MemoryInfo",
|
||||||
|
"MemoryPressureHandler",
|
||||||
|
"LoggingHandler",
|
||||||
|
"CacheEvictionHandler",
|
||||||
|
"GarbageCollectionHandler",
|
||||||
|
"ThrottlingHandler",
|
||||||
|
"monitor",
|
||||||
|
]
|
||||||
168
src/sqrtspace_spacetime/memory/handlers.py
Normal file
168
src/sqrtspace_spacetime/memory/handlers.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
"""Memory pressure handlers."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, List, Callable, Optional
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.memory.monitor import (
|
||||||
|
MemoryPressureHandler,
|
||||||
|
MemoryPressureLevel,
|
||||||
|
MemoryInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingHandler(MemoryPressureHandler):
    """Log memory pressure events, rate-limited to avoid log spam."""

    def __init__(self,
                 logger: Optional[logging.Logger] = None,
                 min_level: MemoryPressureLevel = MemoryPressureLevel.MEDIUM):
        self.logger = logger or logging.getLogger(__name__)
        self.min_level = min_level
        # Maps pressure level -> timestamp of the last emitted log line.
        self._last_log = {}

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        """Handle only events at or above the configured minimum level."""
        return level >= self.min_level

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Log the event unless the same level was logged within 60s."""
        # Avoid spamming logs - only log if level is new or 60s passed
        last_time = self._last_log.get(level, 0)
        if time.time() - last_time < 60 and level in self._last_log:
            return

        self._last_log[level] = time.time()

        # Lazy %-style arguments: the message is only formatted when the
        # record is actually emitted (was eager f-strings).
        if level == MemoryPressureLevel.CRITICAL:
            self.logger.critical("CRITICAL memory pressure: %s", info)
        elif level == MemoryPressureLevel.HIGH:
            self.logger.error("HIGH memory pressure: %s", info)
        elif level == MemoryPressureLevel.MEDIUM:
            self.logger.warning("MEDIUM memory pressure: %s", info)
        else:
            self.logger.info("Memory pressure: %s", info)
|
|
||||||
|
class CacheEvictionHandler(MemoryPressureHandler):
    """Evict a pressure-dependent fraction of entries from registered caches."""

    def __init__(self):
        # Direct references to registered caches. The previous
        # `WeakValueDictionary(cache)` wrapper built a *copy*, so
        # evictions never affected the caller's cache — and values that
        # cannot be weakly referenced (ints, strings) raised TypeError
        # at registration time.
        self._caches: List[Dict[Any, Any]] = []
        self._eviction_rates = {
            MemoryPressureLevel.LOW: 0.1,       # Evict 10%
            MemoryPressureLevel.MEDIUM: 0.25,   # Evict 25%
            MemoryPressureLevel.HIGH: 0.5,      # Evict 50%
            MemoryPressureLevel.CRITICAL: 0.9,  # Evict 90%
        }

    def register_cache(self, cache: Dict[Any, Any]) -> None:
        """Register a cache; it will be mutated in place under pressure."""
        self._caches.append(cache)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        # bool(...) so the annotated return type is honoured; `and` alone
        # would return the list object itself.
        return bool(level >= MemoryPressureLevel.LOW and self._caches)

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Pop the level's fraction of keys (insertion order) per cache."""
        eviction_rate = self._eviction_rates.get(level, 0)
        if eviction_rate == 0:
            return

        for cache in self._caches:
            size = len(cache)
            if size == 0:
                continue

            # Evict the first `rate` fraction of keys; list() snapshot so
            # we never mutate the dict while iterating it.
            to_evict = int(size * eviction_rate)
            for key in list(cache.keys())[:to_evict]:
                cache.pop(key, None)
|
|
||||||
|
class GarbageCollectionHandler(MemoryPressureHandler):
    """Trigger garbage collection under memory pressure."""

    def __init__(self, min_interval: float = 5.0):
        # Minimum seconds between collections, to bound GC overhead.
        self.min_interval = min_interval
        self._last_gc = 0

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        """Only engage at MEDIUM pressure or above."""
        return level >= MemoryPressureLevel.MEDIUM

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Collect, rate-limited by min_interval."""
        now = time.time()
        if now - self._last_gc < self.min_interval:
            return
        self._last_gc = now

        # Full collection (generation 2) under HIGH/CRITICAL pressure,
        # quick youngest-generation pass otherwise.
        generation = 2 if level >= MemoryPressureLevel.HIGH else 0
        gc.collect(generation)
|
|
||||||
|
class ThrottlingHandler(MemoryPressureHandler):
    """Translate memory pressure into throttle delays for registered callbacks."""

    def __init__(self):
        # Pressure level -> suggested delay in seconds.
        self._throttle_rates = {
            MemoryPressureLevel.LOW: 0,         # No throttling
            MemoryPressureLevel.MEDIUM: 0.1,    # 100ms delay
            MemoryPressureLevel.HIGH: 0.5,      # 500ms delay
            MemoryPressureLevel.CRITICAL: 2.0,  # 2s delay
        }
        self._callbacks: List[Callable[[float], None]] = []

    def register_callback(self, callback: Callable[[float], None]) -> None:
        """Register callback to be notified of throttle rates."""
        self._callbacks.append(callback)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        """Only engage at MEDIUM pressure or above."""
        return level >= MemoryPressureLevel.MEDIUM

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Notify every callback of the delay for this pressure level."""
        delay = self._throttle_rates.get(level, 0)

        # Best-effort fan-out: one failing callback must not block others.
        for notify in self._callbacks:
            try:
                notify(delay)
            except Exception:
                pass
|
|
||||||
|
class SpillToDiskHandler(MemoryPressureHandler):
    """Ask registered objects to spill their data to disk under high pressure."""

    def __init__(self, spill_path: Optional[str] = None):
        # Directory handed to each object's spill_to_disk(); None lets
        # the object pick its own location.
        self.spill_path = spill_path
        self._spillable_objects: List[Any] = []

    def register_spillable(self, obj: Any) -> None:
        """Register an object exposing a spill_to_disk() method."""
        if hasattr(obj, 'spill_to_disk'):
            self._spillable_objects.append(obj)

    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        # bool(...) so the annotated return type is honoured; previously
        # this returned the list object itself when level qualified.
        return bool(level >= MemoryPressureLevel.HIGH and self._spillable_objects)

    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Spill every registered object reporting a large (>10MB) footprint."""
        for obj in self._spillable_objects:
            try:
                if hasattr(obj, 'memory_usage'):
                    # Only spill large objects
                    if obj.memory_usage() > 10 * 1024 * 1024:  # 10MB
                        obj.spill_to_disk(self.spill_path)
            except Exception:
                # Deliberate best-effort: one failing object must not
                # prevent the others from spilling.
                pass
||||||
247
src/sqrtspace_spacetime/memory/monitor.py
Normal file
247
src/sqrtspace_spacetime/memory/monitor.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
"""Memory monitoring and pressure detection."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional, Callable, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryPressureLevel(Enum):
    """Memory pressure levels, totally ordered from NONE (0) to CRITICAL (4)."""
    NONE = 0
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4

    # Previously only __gt__/__ge__ were defined, so `<`, `<=` and
    # sorted() raised TypeError; define the full ordering explicitly.
    def __gt__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value > other.value

    def __ge__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value >= other.value

    def __lt__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value < other.value

    def __le__(self, other):
        if not isinstance(other, MemoryPressureLevel):
            return NotImplemented
        return self.value <= other.value
|
|
||||||
|
@dataclass
class MemoryInfo:
    """Point-in-time snapshot of memory usage produced by MemoryMonitor."""
    total: int  # bytes; min(system total, configured limit)
    available: int  # bytes remaining under `total`
    used: int  # bytes currently in use
    percent: float  # used / total * 100
    pressure_level: MemoryPressureLevel  # classification of `percent`
    timestamp: float  # time.time() when the snapshot was taken

    @property
    def used_gb(self) -> float:
        """Used memory in gibibytes."""
        return self.used / (1024 ** 3)

    @property
    def available_gb(self) -> float:
        """Available memory in gibibytes."""
        return self.available / (1024 ** 3)

    def __str__(self) -> str:
        # NOTE(review): the parenthesised pair is used/available GB, which
        # reads like used/total — confirm intent before changing the format.
        return (f"Memory: {self.percent:.1f}% used "
                f"({self.used_gb:.2f}/{self.available_gb:.2f} GB), "
                f"Pressure: {self.pressure_level.name}")
|
|
||||||
|
class MemoryPressureHandler(ABC):
    """Interface for components that react to memory pressure events."""

    @abstractmethod
    def can_handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> bool:
        """Check if this handler should handle the given pressure level."""

    @abstractmethod
    def handle(self, level: MemoryPressureLevel, info: MemoryInfo) -> None:
        """Handle memory pressure."""
|
|
||||||
|
class MemoryMonitor:
    """Monitor system memory and detect pressure.

    Polls psutil, classifies usage into a MemoryPressureLevel, keeps a
    bounded history of snapshots, and dispatches pressure events to
    registered MemoryPressureHandler objects — either on demand via
    check_memory_pressure() or from a background daemon thread.
    """

    def __init__(self,
                 check_interval: float = 1.0,
                 memory_limit: Optional[int] = None):
        """
        Initialize memory monitor.

        Args:
            check_interval: Seconds between checks
            memory_limit: Custom memory limit in bytes (None for system limit)
        """
        self.check_interval = check_interval
        self.memory_limit = memory_limit or config.memory_limit
        self.handlers: List[MemoryPressureHandler] = []
        self._monitoring = False  # run flag for the background thread
        self._thread: Optional[threading.Thread] = None
        self._last_check = 0.0  # timestamp of the last pressure check
        self._history: List[MemoryInfo] = []  # bounded snapshot history
        self._max_history = 100

    def add_handler(self, handler: MemoryPressureHandler) -> None:
        """Add a memory pressure handler."""
        self.handlers.append(handler)

    def remove_handler(self, handler: MemoryPressureHandler) -> None:
        """Remove a memory pressure handler."""
        if handler in self.handlers:
            self.handlers.remove(handler)

    def get_memory_info(self) -> MemoryInfo:
        """Get current memory information."""
        mem = psutil.virtual_memory()

        # Use configured limit if lower than system memory
        total = min(mem.total, self.memory_limit)
        used = mem.used
        # NOTE(review): when system-wide `used` exceeds a configured
        # limit, `available` goes negative and `percent` exceeds 100 —
        # confirm whether clamping is wanted here.
        available = total - used
        percent = (used / total) * 100

        # Determine pressure level (thresholds as percent of `total`).
        if percent >= 95:
            level = MemoryPressureLevel.CRITICAL
        elif percent >= 85:
            level = MemoryPressureLevel.HIGH
        elif percent >= 70:
            level = MemoryPressureLevel.MEDIUM
        elif percent >= 50:
            level = MemoryPressureLevel.LOW
        else:
            level = MemoryPressureLevel.NONE

        return MemoryInfo(
            total=total,
            available=available,
            used=used,
            percent=percent,
            pressure_level=level,
            timestamp=time.time()
        )

    def check_memory_pressure(self) -> MemoryPressureLevel:
        """Check current memory pressure and notify handlers."""
        info = self.get_memory_info()

        # Add to history (bounded at _max_history entries)
        self._history.append(info)
        if len(self._history) > self._max_history:
            self._history.pop(0)

        # Notify handlers
        for handler in self.handlers:
            if handler.can_handle(info.pressure_level, info):
                try:
                    handler.handle(info.pressure_level, info)
                except Exception as e:
                    # Log but don't crash on handler errors
                    # NOTE(review): uses print(); consider logging module.
                    print(f"Handler error: {e}")

        self._last_check = time.time()
        return info.pressure_level

    def should_check(self) -> bool:
        """Check if enough time has passed for next check."""
        return time.time() - self._last_check >= self.check_interval

    def start_monitoring(self) -> None:
        """Start background monitoring thread (idempotent)."""
        if self._monitoring:
            return

        self._monitoring = True
        # Daemon thread: will not keep the interpreter alive at exit.
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop_monitoring(self) -> None:
        """Stop background monitoring (waits up to 5s for the thread)."""
        self._monitoring = False
        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None

    def _monitor_loop(self) -> None:
        """Background monitoring loop; survives individual check failures."""
        while self._monitoring:
            try:
                self.check_memory_pressure()
                time.sleep(self.check_interval)
            except Exception as e:
                print(f"Monitoring error: {e}")
                time.sleep(self.check_interval)

    def get_memory_trend(self, seconds: int = 60) -> Dict[str, float]:
        """Get memory usage trend over past N seconds."""
        if not self._history:
            return {"avg_percent": 0, "max_percent": 0, "trend": 0}

        cutoff = time.time() - seconds
        recent = [h for h in self._history if h.timestamp >= cutoff]

        if not recent:
            return {"avg_percent": 0, "max_percent": 0, "trend": 0}

        percents = [h.percent for h in recent]
        avg_percent = sum(percents) / len(percents)
        max_percent = max(percents)

        # Calculate trend (positive = increasing usage): mean of the newer
        # half minus mean of the older half of the window.
        if len(recent) >= 2:
            first_half = percents[:len(percents)//2]
            second_half = percents[len(percents)//2:]
            trend = sum(second_half)/len(second_half) - sum(first_half)/len(first_half)
        else:
            trend = 0

        return {
            "avg_percent": avg_percent,
            "max_percent": max_percent,
            "trend": trend
        }

    def force_gc(self) -> int:
        """Force garbage collection and return bytes freed.

        NOTE(review): `used` is system-wide, so concurrent processes make
        this noisy; the result is clamped at zero.
        """
        before = self.get_memory_info().used
        gc.collect()
        after = self.get_memory_info().used
        return max(0, before - after)

    def wait_for_memory(self, required_bytes: int, timeout: float = 30) -> bool:
        """
        Wait for required memory to become available.

        Polls every 0.5s, forcing GC and running pressure handlers on
        each iteration to try to free memory.

        Returns:
            True if memory became available, False if timeout
        """
        start = time.time()

        while time.time() - start < timeout:
            info = self.get_memory_info()
            if info.available >= required_bytes:
                return True

            # Try to free memory
            self.force_gc()

            # Let handlers do their work
            self.check_memory_pressure()

            time.sleep(0.5)

        return False
|
|
||||||
|
# Global monitor instance
|
||||||
|
monitor = MemoryMonitor()
|
||||||
23
src/sqrtspace_spacetime/ml/__init__.py
Normal file
23
src/sqrtspace_spacetime/ml/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Machine Learning memory optimization utilities."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.ml.optimizer import (
|
||||||
|
MLMemoryOptimizer,
|
||||||
|
ModelProfile,
|
||||||
|
OptimizationPlan,
|
||||||
|
TrainingConfig,
|
||||||
|
MemoryOptimizationStrategy,
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.ml.checkpointing import (
|
||||||
|
GradientCheckpointer,
|
||||||
|
CheckpointStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MLMemoryOptimizer",
|
||||||
|
"ModelProfile",
|
||||||
|
"OptimizationPlan",
|
||||||
|
"TrainingConfig",
|
||||||
|
"MemoryOptimizationStrategy",
|
||||||
|
"GradientCheckpointer",
|
||||||
|
"CheckpointStrategy",
|
||||||
|
]
|
||||||
286
src/sqrtspace_spacetime/ml/checkpointing.py
Normal file
286
src/sqrtspace_spacetime/ml/checkpointing.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
"""
|
||||||
|
Gradient checkpointing utilities for memory-efficient training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# Framework imports
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TORCH = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
HAS_TF = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TF = False
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointStrategy(Enum):
    """Checkpointing strategies."""

    SQRT_N = "sqrt_n"        # Checkpoint every sqrt(n) layers
    UNIFORM = "uniform"      # Uniform intervals
    MEMORY_BASED = "memory"  # Based on memory usage
    SELECTIVE = "selective"  # Only expensive layers
|
|
||||||
|
class GradientCheckpointer:
    """
    Gradient checkpointing for memory-efficient training.

    Implements Williams' √n strategy for optimal space-time tradeoff.

    Selection of which layers to checkpoint is driven by the configured
    ``CheckpointStrategy``; application is framework-specific (PyTorch
    wraps ``forward`` in place, TensorFlow currently only reports the
    selection — see ``_apply_tf_checkpointing``).
    """

    def __init__(self, strategy: CheckpointStrategy = CheckpointStrategy.SQRT_N):
        # Policy used when auto-selecting layers (no explicit list given).
        self.strategy = strategy

    def apply_checkpointing(self,
                            model: Any,
                            checkpoint_layers: Optional[List[str]] = None) -> Any:
        """
        Apply gradient checkpointing to model.

        Args:
            model: Neural network model
            checkpoint_layers: Specific layers to checkpoint (None for auto)

        Returns:
            Model with checkpointing applied
        """
        # Framework dispatch: PyTorch first, then TensorFlow; anything else
        # is returned unchanged with a warning.
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._apply_torch_checkpointing(model, checkpoint_layers)
        elif HAS_TF:
            return self._apply_tf_checkpointing(model, checkpoint_layers)
        else:
            print("Warning: No supported framework found for checkpointing")
            return model

    def _apply_torch_checkpointing(self,
                                   model: nn.Module,
                                   checkpoint_layers: Optional[List[str]] = None) -> nn.Module:
        """Apply checkpointing to PyTorch model.

        Mutates the model in place (selected modules get their ``forward``
        replaced) and also returns it for convenience.
        """
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_torch(model)

        # Wrap forward methods of selected layers
        for name, module in model.named_modules():
            if name in checkpoint_layers:
                self._wrap_module_torch(module)

        return model

    def _wrap_module_torch(self, module: nn.Module) -> None:
        """Wrap PyTorch module with gradient checkpointing.

        Replaces ``module.forward`` with a closure that routes through
        ``torch.utils.checkpoint.checkpoint`` while the module is in
        training mode, and calls the original forward in eval mode.

        NOTE(review): passing **kwargs through ``checkpoint`` is only
        supported with ``use_reentrant=False`` on recent PyTorch — confirm
        against the torch version this targets.
        """
        original_forward = module.forward

        def checkpointed_forward(*args, **kwargs):
            # Use PyTorch's checkpoint function
            # (recompute-on-backward only makes sense while training;
            # eval mode runs the original forward unchanged).
            if module.training:
                return checkpoint(original_forward, *args, **kwargs)
            else:
                return original_forward(*args, **kwargs)

        module.forward = checkpointed_forward

    def _apply_tf_checkpointing(self,
                                model: Any,
                                checkpoint_layers: Optional[List[str]] = None) -> Any:
        """Apply checkpointing to TensorFlow model.

        Currently a stub: it selects layers and reports the count, but does
        not rewrite the model (TF2's recompute-grad mechanism differs).
        """
        if checkpoint_layers is None:
            checkpoint_layers = self._select_checkpoint_layers_tf(model)

        # TensorFlow implementation
        # Note: TF2 has different checkpointing mechanism
        print(f"TensorFlow checkpointing selected {len(checkpoint_layers)} layers")

        return model

    def _select_checkpoint_layers_torch(self, model: nn.Module) -> List[str]:
        """Select layers to checkpoint for PyTorch model.

        Returns the names of leaf modules chosen according to
        ``self.strategy`` (√n spacing, memory-based top-√n, or all
        eligible layers as the default).
        """
        layers = []

        # Get all layers (leaf modules only — containers are skipped).
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules
                layers.append((name, module))

        if self.strategy == CheckpointStrategy.SQRT_N:
            # Select √n evenly spaced layers
            n = len(layers)
            if n == 0:
                return []

            interval = max(1, int(math.sqrt(n)))
            selected = []

            for i in range(0, n, interval):
                name, module = layers[i]
                if self._can_checkpoint_module(module):
                    selected.append(name)

            return selected

        elif self.strategy == CheckpointStrategy.MEMORY_BASED:
            # Select layers with large activation memory
            memory_layers = []

            for name, module in layers:
                memory = self._estimate_module_memory(module)
                memory_layers.append((name, memory))

            # Sort by memory and select top √n
            memory_layers.sort(key=lambda x: x[1], reverse=True)
            n_checkpoint = max(1, int(math.sqrt(len(memory_layers))))

            return [name for name, _ in memory_layers[:n_checkpoint]]

        else:
            # Default: checkpoint all eligible layers
            # (covers UNIFORM and SELECTIVE, which have no dedicated logic yet).
            return [name for name, module in layers if self._can_checkpoint_module(module)]

    def _select_checkpoint_layers_tf(self, model: Any) -> List[str]:
        """Select layers to checkpoint for TensorFlow model.

        Uses √n spacing for SQRT_N; all layer names otherwise.
        """
        if not hasattr(model, 'layers'):
            return []

        layers = [(layer.name, layer) for layer in model.layers]

        if self.strategy == CheckpointStrategy.SQRT_N:
            n = len(layers)
            interval = max(1, int(math.sqrt(n)))

            selected = []
            for i in range(0, n, interval):
                name, layer = layers[i]
                selected.append(name)

            return selected

        return [name for name, _ in layers]

    def _can_checkpoint_module(self, module: Any) -> bool:
        """Check if module can be safely checkpointed.

        Recomputation must be deterministic, so modules that draw fresh
        randomness on every forward (dropout) are excluded.
        """
        if HAS_TORCH:
            # Avoid checkpointing modules with randomness
            no_checkpoint = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
            return not isinstance(module, no_checkpoint)
        return True

    def _estimate_module_memory(self, module: Any) -> int:
        """Estimate memory usage of module activations, in bytes.

        These are rough per-sample heuristics (4 bytes/value = fp32);
        Conv2d assumes a nominal 100x100 spatial output — TODO confirm
        against real input shapes.
        """
        if HAS_TORCH and isinstance(module, nn.Module):
            # Estimate based on output size
            if isinstance(module, nn.Linear):
                return module.out_features * 4  # FP32
            elif isinstance(module, nn.Conv2d):
                # Rough estimate
                return module.out_channels * 100 * 100 * 4
            else:
                # Default estimate
                params = sum(p.numel() for p in module.parameters())
                return params * 4
        return 0

    @staticmethod
    def create_checkpoint_segments(model: Any,
                                   n_segments: Optional[int] = None) -> List[List[str]]:
        """
        Create checkpoint segments using √n strategy.

        Args:
            model: Neural network model
            n_segments: Number of segments (None for √n)

        Returns:
            List of layer name segments
        """
        # Get all layers
        if HAS_TORCH and isinstance(model, nn.Module):
            all_layers = [name for name, _ in model.named_modules()
                          if len(list(_.children())) == 0]
        elif HAS_TF and hasattr(model, 'layers'):
            all_layers = [layer.name for layer in model.layers]
        else:
            return []

        n = len(all_layers)
        if n == 0:
            return []

        # Use √n segments by default
        if n_segments is None:
            n_segments = max(1, int(math.sqrt(n)))

        # Create segments (contiguous chunks of ~n/n_segments layers;
        # the final chunk may be shorter).
        segment_size = max(1, n // n_segments)
        segments = []

        for i in range(0, n, segment_size):
            segment = all_layers[i:i + segment_size]
            if segment:
                segments.append(segment)

        return segments
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_sequential(modules: List[Any],
                          input: Any,
                          segments: Optional[int] = None) -> Any:
    """
    Checkpoint a sequential model using √n segments.

    Args:
        modules: List of modules to execute sequentially
        input: Input tensor
        segments: Number of checkpoint segments (None for √n)

    Returns:
        Output tensor
    """
    if not HAS_TORCH:
        # Fallback to normal execution when PyTorch is unavailable.
        x = input
        for module in modules:
            x = module(x)
        return x

    n = len(modules)
    if n == 0:
        return input

    # Use √n segments by default (Williams' space-time tradeoff).
    if segments is None:
        segments = max(1, int(math.sqrt(n)))

    segment_size = max(1, n // segments)

    # Hoisted out of the loop: one helper shared by all segments instead of
    # re-defining a closure per iteration. The modules are passed as explicit
    # arguments so checkpoint() can see them.
    def run_segment(x, *segment_modules):
        for module in segment_modules:
            x = module(x)
        return x

    # Execute with checkpointing.
    x = input
    for i in range(0, n, segment_size):
        segment = modules[i:i + segment_size]

        # Bug fix: decide per segment from the segment's own module, not
        # always modules[0] — checkpointing only applies in training mode,
        # and modules may have differing training flags.
        training = segment[0].training

        if len(segment) == 1:
            # Single module in this segment.
            if training:
                x = checkpoint(segment[0], x)
            else:
                x = segment[0](x)
        else:
            # Multiple modules — run them through the shared wrapper.
            if training:
                x = checkpoint(run_segment, x, *segment)
            else:
                x = run_segment(x, *segment)

    return x
|
||||||
488
src/sqrtspace_spacetime/ml/optimizer.py
Normal file
488
src/sqrtspace_spacetime/ml/optimizer.py
Normal file
@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
ML Training Memory Optimizer: Optimize neural network training memory usage.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Layer-by-layer memory profiling
|
||||||
|
- Automatic gradient checkpointing with √n intervals
|
||||||
|
- Mixed precision configuration
|
||||||
|
- Batch size optimization
|
||||||
|
- Framework-agnostic (PyTorch/TensorFlow)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import psutil
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.memory import monitor
|
||||||
|
|
||||||
|
# Try to import ML frameworks
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TORCH = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
HAS_TF = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TF = False
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryOptimizationStrategy(Enum):
    """Memory optimization strategies for ML training."""

    # Recompute activations during backward instead of storing them.
    GRADIENT_CHECKPOINTING = "gradient_checkpointing"
    # FP16/BF16 compute with FP32 master variables.
    MIXED_PRECISION = "mixed_precision"
    # Smaller effective batch accumulated over several steps.
    GRADIENT_ACCUMULATION = "gradient_accumulation"
    # Distribute layers across devices.
    MODEL_SHARDING = "model_sharding"
    # Compress intermediate activations.
    ACTIVATION_COMPRESSION = "activation_compression"
    # Adjust batch size on the fly.
    DYNAMIC_BATCH_SIZE = "dynamic_batch_size"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LayerProfile:
    """Profile of a neural network layer.

    One record per leaf layer, produced by ``MLMemoryOptimizer.analyze_model``.
    Sizes are in bytes; per-sample fields must be multiplied by the batch
    size to get totals.
    """
    name: str                # Qualified layer name within the model
    layer_type: str          # Layer class name, e.g. 'Linear', 'Conv2d'
    parameters: int          # Number of parameters in this layer
    activation_size: int  # Per sample
    gradient_size: int  # Per sample
    computation_time: float  # Forward time in seconds (currently a placeholder)
    memory_bytes: int        # Parameters (fp32) plus activations, in bytes
    can_checkpoint: bool     # Whether gradient checkpointing is safe here
    precision: str  # 'fp32', 'fp16', 'int8'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelProfile:
    """Complete model memory profile.

    Aggregated view returned by ``MLMemoryOptimizer.analyze_model`` and
    consumed by ``MLMemoryOptimizer.optimize``.
    """
    total_parameters: int                   # Parameter count across the whole model
    total_activations: int  # Per sample
    peak_memory: int                        # Estimated peak bytes (params + activations)
    layers: List[LayerProfile]              # Per-layer profiles in traversal order
    memory_timeline: List[Tuple[str, int]]  # (operation, memory)
    bottleneck_layers: List[str]            # Heaviest ~20% of layers by memory_bytes
    framework: str  # 'pytorch', 'tensorflow', 'generic'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OptimizationPlan:
    """Optimization plan for model training.

    Built incrementally by ``MLMemoryOptimizer.optimize``: each applied
    strategy mutates the estimates below.
    """
    strategies: List[MemoryOptimizationStrategy]  # Strategies applied, in order
    checkpoint_layers: List[str]                  # Layer names selected for checkpointing
    batch_size: int                               # Per-step (micro) batch size
    gradient_accumulation_steps: int              # Steps to accumulate before an update
    mixed_precision_config: Dict[str, Any]        # AMP settings; empty if not enabled
    estimated_memory: int                         # Estimated training memory, bytes
    estimated_speedup: float                      # Relative throughput (1.0 = baseline)
    memory_savings: int                           # Bytes saved vs. unoptimized estimate
    explanation: str                              # Human-readable summary of the plan
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainingConfig:
    """Configuration for optimized training.

    Concrete, framework-agnostic settings derived from an
    ``OptimizationPlan`` via ``MLMemoryOptimizer.get_training_config``.
    """
    original_batch_size: int              # Effective batch size requested by the user
    optimized_batch_size: int             # Per-step micro-batch size actually used
    accumulation_steps: int               # Gradient-accumulation steps per update
    checkpoint_segments: List[List[str]]  # Checkpoint layer names grouped into √n segments
    precision_map: Dict[str, str]         # Layer name -> 'fp16'/'fp32' (empty if no AMP)
    memory_limit: int                     # Memory budget in bytes

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MLMemoryOptimizer:
    """Optimize memory usage for ML model training.

    Typical flow: ``analyze_model`` -> ``optimize`` -> ``get_training_config``.
    All memory figures are byte estimates built from simple fp32 heuristics,
    not measurements.
    """

    def __init__(self, memory_limit: Optional[int] = None):
        """
        Initialize optimizer.

        Args:
            memory_limit: Memory limit in bytes (None for auto-detect)
        """
        # Auto-detect: budget 80% of currently available system RAM.
        self.memory_limit = memory_limit or int(psutil.virtual_memory().available * 0.8)

    def analyze_model(self,
                      model: Any,
                      input_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
                      batch_size: int = 1) -> ModelProfile:
        """
        Analyze model memory requirements.

        Args:
            model: Neural network model
            input_shape: Input shape(s)
            batch_size: Batch size for analysis

        Returns:
            ModelProfile with memory analysis
        """
        # Framework dispatch; unknown frameworks get a crude heuristic profile.
        # NOTE(review): input_shape is currently unused by the torch/tf
        # analyzers — estimates come from layer attributes instead.
        if HAS_TORCH and isinstance(model, nn.Module):
            return self._analyze_torch_model(model, input_shape, batch_size)
        elif HAS_TF and hasattr(model, 'layers'):
            return self._analyze_tf_model(model, input_shape, batch_size)
        else:
            return self._analyze_generic_model(model, input_shape, batch_size)

    def _analyze_torch_model(self,
                             model: nn.Module,
                             input_shape: Tuple[int, ...],
                             batch_size: int) -> ModelProfile:
        """Analyze PyTorch model.

        Walks leaf modules, estimating activation sizes with fp32
        heuristics (4 bytes/value).
        """
        layers = []
        total_params = 0
        total_activations = 0
        memory_timeline = []

        # Count parameters
        for name, param in model.named_parameters():
            total_params += param.numel()

        # Analyze layers (leaf modules only)
        for name, module in model.named_modules():
            if len(list(module.children())) == 0:  # Leaf module
                layer_params = sum(p.numel() for p in module.parameters())

                # Estimate activation size (simplified)
                if isinstance(module, nn.Linear):
                    activation_size = module.out_features * batch_size * 4  # fp32
                elif isinstance(module, nn.Conv2d):
                    # Rough estimate
                    # (assumes a nominal 100x100 spatial output — TODO confirm)
                    activation_size = module.out_channels * 100 * 100 * batch_size * 4
                else:
                    activation_size = layer_params * batch_size * 4

                total_activations += activation_size

                layers.append(LayerProfile(
                    name=name,
                    layer_type=module.__class__.__name__,
                    parameters=layer_params,
                    activation_size=activation_size // batch_size,
                    gradient_size=layer_params * 4,  # fp32 gradients
                    computation_time=0.001,  # Placeholder
                    memory_bytes=layer_params * 4 + activation_size,
                    can_checkpoint=self._can_checkpoint_layer(module),
                    precision='fp32'
                ))

        # Find bottlenecks (top 20% by memory)
        sorted_layers = sorted(layers, key=lambda l: l.memory_bytes, reverse=True)
        bottleneck_count = max(1, len(layers) // 5)
        bottleneck_layers = [l.name for l in sorted_layers[:bottleneck_count]]

        return ModelProfile(
            total_parameters=total_params,
            total_activations=total_activations // batch_size,
            peak_memory=total_params * 4 + total_activations,
            layers=layers,
            memory_timeline=memory_timeline,
            bottleneck_layers=bottleneck_layers,
            framework='pytorch'
        )

    def _analyze_tf_model(self,
                          model: Any,
                          input_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
                          batch_size: int) -> ModelProfile:
        """Analyze TensorFlow model.

        Uses Keras ``output_shape`` when available; otherwise falls back to
        parameter-count heuristics.
        """
        layers = []
        total_params = model.count_params()
        total_activations = 0

        # Analyze each layer
        for layer in model.layers:
            layer_params = layer.count_params()

            # Estimate activation size
            if hasattr(layer, 'output_shape'):
                shape = layer.output_shape
                if isinstance(shape, tuple):
                    # shape[0] is the (symbolic) batch dimension — skip it.
                    activation_size = np.prod(shape[1:]) * batch_size * 4
                else:
                    activation_size = layer_params * batch_size * 4
            else:
                activation_size = layer_params * batch_size * 4

            total_activations += activation_size

            layers.append(LayerProfile(
                name=layer.name,
                layer_type=layer.__class__.__name__,
                parameters=layer_params,
                activation_size=activation_size // batch_size,
                gradient_size=layer_params * 4,
                computation_time=0.001,
                memory_bytes=layer_params * 4 + activation_size,
                can_checkpoint=True,  # Most TF layers can checkpoint
                precision='fp32'
            ))

        # Find bottlenecks
        sorted_layers = sorted(layers, key=lambda l: l.memory_bytes, reverse=True)
        bottleneck_count = max(1, len(layers) // 5)
        bottleneck_layers = [l.name for l in sorted_layers[:bottleneck_count]]

        return ModelProfile(
            total_parameters=total_params,
            total_activations=total_activations // batch_size,
            peak_memory=total_params * 4 + total_activations,
            layers=layers,
            memory_timeline=[],
            bottleneck_layers=bottleneck_layers,
            framework='tensorflow'
        )

    def _analyze_generic_model(self,
                               model: Any,
                               input_shape: Tuple[int, ...],
                               batch_size: int) -> ModelProfile:
        """Analyze generic model.

        No introspection is possible, so fixed assumptions are used.
        """
        # Basic heuristics
        estimated_params = 10_000_000  # 10M parameters
        estimated_activations = estimated_params * batch_size

        return ModelProfile(
            total_parameters=estimated_params,
            total_activations=estimated_activations,
            peak_memory=estimated_params * 4 + estimated_activations * 4,
            layers=[],
            memory_timeline=[],
            bottleneck_layers=[],
            framework='generic'
        )

    def optimize(self,
                 model_profile: ModelProfile,
                 target_batch_size: int,
                 strategies: Optional[List[MemoryOptimizationStrategy]] = None) -> OptimizationPlan:
        """
        Generate optimization plan for model.

        Args:
            model_profile: Model profile from analyze_model
            target_batch_size: Desired batch size
            strategies: Strategies to consider (None for auto)

        Returns:
            OptimizationPlan with recommendations
        """
        if strategies is None:
            strategies = self._select_strategies(model_profile, target_batch_size)

        # Calculate memory requirements (all fp32 unless optimized below).
        base_memory = model_profile.total_parameters * 4  # Parameters
        activation_memory = model_profile.total_activations * target_batch_size * 4
        gradient_memory = model_profile.total_parameters * 4  # Gradients
        optimizer_memory = model_profile.total_parameters * 8  # Adam states (m and v)

        total_memory = base_memory + activation_memory + gradient_memory + optimizer_memory

        # Initialize plan with the unoptimized baseline.
        plan = OptimizationPlan(
            strategies=strategies,
            checkpoint_layers=[],
            batch_size=target_batch_size,
            gradient_accumulation_steps=1,
            mixed_precision_config={},
            estimated_memory=total_memory,
            estimated_speedup=1.0,
            memory_savings=0,
            explanation=""
        )

        # Apply strategies — each mutates plan's estimates in place.
        for strategy in strategies:
            if strategy == MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING:
                self._apply_checkpointing(plan, model_profile)
            elif strategy == MemoryOptimizationStrategy.MIXED_PRECISION:
                self._apply_mixed_precision(plan, model_profile)
            elif strategy == MemoryOptimizationStrategy.GRADIENT_ACCUMULATION:
                self._apply_gradient_accumulation(plan, model_profile)

        # Calculate final estimates
        plan.memory_savings = total_memory - plan.estimated_memory
        plan.explanation = self._generate_explanation(plan, model_profile)

        return plan

    def _select_strategies(self,
                           model_profile: ModelProfile,
                           target_batch_size: int) -> List[MemoryOptimizationStrategy]:
        """Select appropriate optimization strategies.

        Tiered by memory pressure relative to self.memory_limit: over the
        limit -> all three; >80% -> checkpointing + AMP; >60% -> AMP only;
        otherwise none.
        """
        strategies = []

        # Calculate memory pressure
        required_memory = (model_profile.total_parameters * 4 +
                           model_profile.total_activations * target_batch_size * 4)

        if required_memory > self.memory_limit:
            # High memory pressure - use all strategies
            strategies.append(MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING)
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)
            strategies.append(MemoryOptimizationStrategy.GRADIENT_ACCUMULATION)
        elif required_memory > self.memory_limit * 0.8:
            # Medium pressure
            strategies.append(MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING)
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)
        elif required_memory > self.memory_limit * 0.6:
            # Low pressure
            strategies.append(MemoryOptimizationStrategy.MIXED_PRECISION)

        return strategies

    def _apply_checkpointing(self,
                             plan: OptimizationPlan,
                             model_profile: ModelProfile) -> None:
        """Apply gradient checkpointing using √n strategy.

        Selects every √n-th checkpointable layer and reduces the plan's
        memory estimate by ~50% of those layers' activation memory, at an
        assumed ~20% throughput cost.
        """
        n_layers = len(model_profile.layers)

        if n_layers == 0:
            return

        # Use √n checkpointing intervals
        checkpoint_interval = max(1, int(math.sqrt(n_layers)))

        # Select layers to checkpoint
        checkpoint_layers = []
        for i in range(0, n_layers, checkpoint_interval):
            if i < len(model_profile.layers):
                layer = model_profile.layers[i]
                if layer.can_checkpoint:
                    checkpoint_layers.append(layer.name)

        plan.checkpoint_layers = checkpoint_layers

        # Update memory estimate (save ~50% of activation memory)
        saved_memory = sum(l.activation_size * plan.batch_size * 4
                           for l in model_profile.layers
                           if l.name in checkpoint_layers) * 0.5

        plan.estimated_memory -= int(saved_memory)
        plan.estimated_speedup *= 0.8  # 20% slowdown from recomputation

    def _apply_mixed_precision(self,
                               plan: OptimizationPlan,
                               model_profile: ModelProfile) -> None:
        """Apply mixed precision training.

        Fills mixed_precision_config and credits ~50% activation savings
        (2 of 4 bytes per value).

        NOTE(review): the subtraction is unclamped — with large batch sizes
        the running estimate can go negative; verify downstream consumers
        tolerate that.
        """
        plan.mixed_precision_config = {
            'enabled': True,
            'loss_scale': 'dynamic',
            'compute_dtype': 'float16',
            'variable_dtype': 'float32'
        }

        # Update memory estimate (save ~50% on activations)
        activation_savings = model_profile.total_activations * plan.batch_size * 2
        plan.estimated_memory -= activation_savings
        plan.estimated_speedup *= 1.5  # Potential speedup on modern GPUs

    def _apply_gradient_accumulation(self,
                                     plan: OptimizationPlan,
                                     model_profile: ModelProfile) -> None:
        """Apply gradient accumulation.

        If the current estimate still exceeds the limit, shrinks the
        per-step batch and accumulates gradients so the effective batch
        size (batch_size * steps) stays ~unchanged.
        """
        # Calculate how many accumulation steps needed
        current_memory = plan.estimated_memory

        if current_memory > self.memory_limit:
            # Reduce effective batch size
            reduction_factor = current_memory / self.memory_limit
            accumulation_steps = int(math.ceil(reduction_factor))

            # Adjust batch size and accumulation
            effective_batch = plan.batch_size // accumulation_steps
            plan.batch_size = max(1, effective_batch)
            plan.gradient_accumulation_steps = accumulation_steps

            # Update memory estimate
            plan.estimated_memory = plan.estimated_memory // accumulation_steps

    def _can_checkpoint_layer(self, layer: Any) -> bool:
        """Check if layer can be checkpointed.

        Excludes layers whose recomputation would diverge or mutate state:
        dropout (randomness) and batch norm (running statistics).
        """
        if HAS_TORCH:
            # Most layers can be checkpointed except those with side effects
            no_checkpoint_types = (nn.Dropout, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
            return not isinstance(layer, no_checkpoint_types)
        return True

    def _generate_explanation(self,
                              plan: OptimizationPlan,
                              model_profile: ModelProfile) -> str:
        """Generate human-readable explanation.

        Summarizes the model analysis plus one section per applied strategy.
        """
        explanations = []

        explanations.append(f"Model Analysis:")
        explanations.append(f"- Total parameters: {model_profile.total_parameters:,}")
        explanations.append(f"- Peak memory estimate: {plan.estimated_memory / (1024**3):.2f} GB")
        explanations.append(f"- Memory savings: {plan.memory_savings / (1024**3):.2f} GB")

        if MemoryOptimizationStrategy.GRADIENT_CHECKPOINTING in plan.strategies:
            explanations.append(f"\nGradient Checkpointing:")
            explanations.append(f"- Checkpointing {len(plan.checkpoint_layers)} layers using √n strategy")
            explanations.append(f"- This trades ~20% compute time for ~50% activation memory")

        if MemoryOptimizationStrategy.MIXED_PRECISION in plan.strategies:
            explanations.append(f"\nMixed Precision:")
            explanations.append(f"- Using FP16 for forward pass, FP32 for gradients")
            explanations.append(f"- Reduces activation memory by ~50%")

        if plan.gradient_accumulation_steps > 1:
            explanations.append(f"\nGradient Accumulation:")
            explanations.append(f"- Accumulating over {plan.gradient_accumulation_steps} steps")
            explanations.append(f"- Effective batch size: {plan.batch_size * plan.gradient_accumulation_steps}")

        return "\n".join(explanations)

    def get_training_config(self,
                            plan: OptimizationPlan,
                            model_profile: ModelProfile) -> TrainingConfig:
        """
        Generate training configuration from optimization plan.

        Args:
            plan: Optimization plan
            model_profile: Model profile

        Returns:
            TrainingConfig ready for use
        """
        # Group checkpoint layers into segments
        checkpoint_segments = []
        if plan.checkpoint_layers:
            # Create √n segments
            n_segments = int(math.sqrt(len(plan.checkpoint_layers)))
            segment_size = max(1, len(plan.checkpoint_layers) // n_segments)

            for i in range(0, len(plan.checkpoint_layers), segment_size):
                segment = plan.checkpoint_layers[i:i + segment_size]
                if segment:
                    checkpoint_segments.append(segment)

        # Create precision map (only populated when AMP is in the plan).
        precision_map = {}
        if MemoryOptimizationStrategy.MIXED_PRECISION in plan.strategies:
            for layer in model_profile.layers:
                # Use FP16 for compute-heavy layers
                if layer.layer_type in ['Linear', 'Conv2d', 'Dense', 'Conv2D']:
                    precision_map[layer.name] = 'fp16'
                else:
                    precision_map[layer.name] = 'fp32'

        return TrainingConfig(
            original_batch_size=plan.batch_size * plan.gradient_accumulation_steps,
            optimized_batch_size=plan.batch_size,
            accumulation_steps=plan.gradient_accumulation_steps,
            checkpoint_segments=checkpoint_segments,
            precision_map=precision_map,
            memory_limit=self.memory_limit
        )
|
||||||
25
src/sqrtspace_spacetime/profiler/__init__.py
Normal file
25
src/sqrtspace_spacetime/profiler/__init__.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
"""SpaceTime Profiler for memory and performance analysis."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.profiler.profiler import (
|
||||||
|
SpaceTimeProfiler,
|
||||||
|
ProfilingReport,
|
||||||
|
Hotspot,
|
||||||
|
BottleneckAnalysis,
|
||||||
|
AccessPattern,
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.profiler.decorators import (
|
||||||
|
profile,
|
||||||
|
profile_memory,
|
||||||
|
profile_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SpaceTimeProfiler",
|
||||||
|
"ProfilingReport",
|
||||||
|
"Hotspot",
|
||||||
|
"BottleneckAnalysis",
|
||||||
|
"AccessPattern",
|
||||||
|
"profile",
|
||||||
|
"profile_memory",
|
||||||
|
"profile_time",
|
||||||
|
]
|
||||||
175
src/sqrtspace_spacetime/profiler/decorators.py
Normal file
175
src/sqrtspace_spacetime/profiler/decorators.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
"""Decorators for easy profiling."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.profiler.profiler import SpaceTimeProfiler
|
||||||
|
|
||||||
|
|
||||||
|
def profile(output_file: Optional[str] = None,
            print_summary: bool = True) -> Callable:
    """
    Decorator that runs the wrapped function under a SpaceTimeProfiler.

    Args:
        output_file: Optional file to save report
        print_summary: Print summary to console

    The most recent report is exposed as ``<wrapped>.last_report``
    (``None`` until the first call).

    Example:
        @profile(output_file="profile.json")
        def my_function():
            ...
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Profile the call and capture both the return value and report.
            result, run_report = SpaceTimeProfiler().profile(func, *args, **kwargs)

            if print_summary:
                print(run_report.summary)
            if output_file:
                run_report.save(output_file)

            # Expose the report for later inspection.
            wrapper.last_report = run_report
            return result

        wrapper.last_report = None
        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def profile_memory(threshold_mb: float = 100,
                   alert: bool = True) -> Callable:
    """
    Decorator to profile memory usage.

    Args:
        threshold_mb: Memory threshold in MB to trigger alert
        alert: Print alert if threshold exceeded

    After each call the wrapped function exposes ``memory_used`` (MB of
    RSS growth) and ``duration`` (seconds) attributes; both are ``None``
    before the first call.

    Example:
        @profile_memory(threshold_mb=500)
        def process_large_data():
            # Process data
            pass
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Imported lazily so merely importing this module does not
            # require psutil.
            import psutil
            process = psutil.Process()

            start_memory = process.memory_info().rss
            start_time = time.time()

            try:
                return func(*args, **kwargs)
            finally:
                # Bug fix: metrics were previously computed after the
                # try/finally, so an exception in func discarded the
                # measurement. Recording (and alerting) in `finally`
                # captures metrics on both success and failure.
                memory_used = (process.memory_info().rss - start_memory) / (1024 * 1024)
                duration = time.time() - start_time

                wrapper.memory_used = memory_used
                wrapper.duration = duration

                if alert and memory_used > threshold_mb:
                    print(f"⚠️ Memory Alert: {func.__name__} used {memory_used:.1f}MB "
                          f"(threshold: {threshold_mb}MB)")
                    print(f"   Consider using SpaceTime collections for memory efficiency")

                if alert:
                    print(f"Memory: {memory_used:.1f}MB, Time: {duration:.2f}s")

        wrapper.memory_used = None
        wrapper.duration = None
        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def profile_time(threshold_seconds: float = 1.0,
                 alert: bool = True) -> Callable:
    """
    Decorator that records wall-clock execution time of each call.

    Args:
        threshold_seconds: Time threshold to trigger alert
        alert: Print alert if threshold exceeded

    Example:
        @profile_time(threshold_seconds=5.0)
        def slow_operation():
            # Time-consuming operation
            pass
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            started = time.time()
            try:
                result = func(*args, **kwargs)
            finally:
                # Record timing even if the call raised.
                duration = time.time() - started
                wrapper.duration = duration

                if alert and duration > threshold_seconds:
                    print(f"⏱️ Time Alert: {func.__name__} took {duration:.2f}s "
                          f"(threshold: {threshold_seconds}s)")
                if alert:
                    print(f"Execution time: {duration:.2f}s")
            return result

        # No measurement until the first invocation.
        wrapper.duration = None
        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileContext:
    """Context manager for profiling code blocks."""

    def __init__(self, name: str = "block", print_summary: bool = True):
        # Label used in the printed summary.
        self.name = name
        self.print_summary = print_summary
        # Populated by __enter__; None until the context is entered.
        self.profiler = None
        self.report = None
        self._monitoring = False

    def __enter__(self):
        self.profiler = SpaceTimeProfiler()
        self.profiler.start_monitoring()
        self._start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.time() - self._start_time
        self.profiler.stop_monitoring()

        # Generate simple report
        if self.print_summary:
            samples = self.profiler.memory_timeline
            peak_memory = max((sample[1] for sample in samples), default=0)
            print(f"\nProfile: {self.name}")
            print(f"Duration: {duration:.2f}s")
            print(f"Peak Memory: {peak_memory / (1024*1024):.1f}MB")

            if peak_memory > 100 * 1024 * 1024:  # 100MB
                print("💡 Consider using SpaceTime collections for memory optimization")


# Convenience instance
profile_context = ProfileContext
|
||||||
475
src/sqrtspace_spacetime/profiler/profiler.py
Normal file
475
src/sqrtspace_spacetime/profiler/profiler.py
Normal file
@ -0,0 +1,475 @@
|
|||||||
|
"""
|
||||||
|
SpaceTime Profiler: Profile applications to identify optimization opportunities.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Memory pattern analysis (sequential, random, strided)
|
||||||
|
- Bottleneck detection (memory vs CPU)
|
||||||
|
- Memory hierarchy awareness (L1/L2/L3/RAM/Disk)
|
||||||
|
- Hotspot identification
|
||||||
|
- AI-generated recommendations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import psutil
|
||||||
|
import numpy as np
|
||||||
|
import tracemalloc
|
||||||
|
import cProfile
|
||||||
|
import pstats
|
||||||
|
import io
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
|
||||||
|
|
||||||
|
class AccessPattern(Enum):
    """Memory access patterns (see MemoryTracer.analyze_pattern)."""
    # Addresses strictly increase with low variance in step size.
    SEQUENTIAL = "sequential"
    # No recognizable structure in the address deltas.
    RANDOM = "random"
    # Only a handful of distinct, small step sizes between accesses.
    STRIDED = "strided"
    # Too few samples (or no analysis performed) to classify.
    UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryAccess:
    """Single memory access event."""
    timestamp: float   # when the access occurred (epoch seconds — TODO confirm producer)
    address: int       # memory address that was touched
    size: int          # number of bytes accessed
    operation: str     # 'read' or 'write'
    function: str      # name of the function performing the access
    line_number: int   # source line of the access
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Hotspot:
    """Memory hotspot information."""
    function: str                  # function name
    file_path: str                 # source file containing the function
    line_number: int               # first line of the function
    memory_allocated: int          # bytes allocated (from tracemalloc stats)
    memory_freed: int              # bytes freed (always 0 in the simple profiler)
    net_memory: int                # allocated minus freed
    allocation_count: int          # number of allocation events
    cpu_time: float                # cumulative CPU time in seconds
    access_pattern: AccessPattern  # detected memory access pattern
    recommendations: List[str]     # human-readable optimization hints
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BottleneckAnalysis:
    """Analysis of performance bottlenecks."""
    type: str                   # 'memory', 'cpu', 'io'
    severity: float             # 0.0 to 1.0
    description: str            # human-readable explanation of the finding
    evidence: Dict[str, Any]    # raw measurements backing the diagnosis
    recommendations: List[str]  # suggested mitigations
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProfilingReport:
    """Complete profiling report produced by SpaceTimeProfiler."""
    timestamp: str          # ISO-8601 creation time
    duration: float         # profiled wall-clock duration in seconds
    peak_memory: int        # peak RSS in bytes observed during the run
    total_allocations: int
    memory_timeline: List[Tuple[float, int]]    # (elapsed seconds, RSS bytes)
    cpu_timeline: List[Tuple[float, float]]     # (elapsed seconds, CPU percent)
    hotspots: List[Hotspot]
    bottlenecks: List[BottleneckAnalysis]
    access_patterns: Dict[str, AccessPattern]
    hierarchy_transitions: Dict[str, int]
    optimization_opportunities: List[Dict[str, Any]]
    summary: str

    def to_dict(self) -> Dict[str, Any]:
        """Convert report to dictionary (Enum members are kept as-is)."""
        return asdict(self)

    def save(self, path: str) -> None:
        """Save report to JSON file.

        Bug fix: asdict() preserves Enum members (AccessPattern inside
        hotspots and access_patterns), which json.dump cannot serialize
        and previously raised TypeError. The default hook converts enums
        to their string values and stringifies any other stray object.
        """
        import json
        with open(path, 'w') as f:
            json.dump(
                self.to_dict(), f, indent=2,
                default=lambda o: o.value if isinstance(o, Enum) else str(o),
            )
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracer:
    """Trace memory accesses and allocations.

    Wraps the process-global tracemalloc machinery: start() begins tracing,
    stop() captures a snapshot for later querying via get_top_allocators().
    """

    def __init__(self, max_samples: int = 100000):
        # Bounded ring buffer of MemoryAccess events (oldest dropped first).
        self.accesses = deque(maxlen=max_samples)
        # Allocation records grouped by key. NOTE(review): not filled in by
        # this class itself — presumably populated by a caller; confirm.
        self.allocations = defaultdict(list)
        self.start_time = time.time()
        # Snapshot captured by stop(); None until tracing has been stopped.
        self._tracemalloc_snapshot = None

    def start(self):
        """Start memory tracing."""
        # tracemalloc is process-global: only start it if nobody else has.
        if not tracemalloc.is_tracing():
            tracemalloc.start()

    def stop(self):
        """Stop memory tracing and keep a snapshot for later analysis."""
        if tracemalloc.is_tracing():
            self._tracemalloc_snapshot = tracemalloc.take_snapshot()
            tracemalloc.stop()

    def analyze_pattern(self, accesses: List[MemoryAccess]) -> AccessPattern:
        """Classify the access pattern of the most recent accesses.

        Heuristics (applied to the last 100 address samples):
        - SEQUENTIAL: strictly increasing addresses with low relative
          variance in the step size (std < 10% of mean step).
        - STRIDED: fewer than 5 distinct step sizes, all small (std < 100).
        - RANDOM: anything else with enough samples.
        - UNKNOWN: fewer than 10 samples.
        """
        if len(accesses) < 10:
            return AccessPattern.UNKNOWN

        # Extract addresses (most recent 100 only)
        addresses = [a.address for a in accesses[-100:]]

        # Calculate differences between consecutive addresses
        diffs = np.diff(addresses)
        if len(diffs) == 0:
            return AccessPattern.UNKNOWN

        # Check for sequential pattern
        if np.all(diffs > 0) and np.std(diffs) < np.mean(diffs) * 0.1:
            return AccessPattern.SEQUENTIAL

        # Check for strided pattern
        unique_diffs = set(diffs)
        if len(unique_diffs) < 5 and np.std(diffs) < 100:
            return AccessPattern.STRIDED

        # Otherwise random
        return AccessPattern.RANDOM

    def get_top_allocators(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get top memory allocators from the tracemalloc snapshot.

        Returns an empty list if stop() has not captured a snapshot yet.
        Each entry carries the formatted call-site line plus size/count
        statistics in bytes.
        """
        if not self._tracemalloc_snapshot:
            return []

        top_stats = self._tracemalloc_snapshot.statistics('lineno')[:limit]

        result = []
        for stat in top_stats:
            result.append({
                # First formatted traceback line identifies the call site.
                'file': stat.traceback.format()[0] if stat.traceback else 'unknown',
                'size': stat.size,
                'count': stat.count,
                'average': stat.size // stat.count if stat.count > 0 else 0
            })

        return result
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceTimeProfiler:
    """Main profiler class.

    Samples process memory/CPU/IO on a background thread while a workload
    runs, and combines those samples with cProfile and tracemalloc data
    into a ProfilingReport.
    """

    def __init__(self, sample_interval: float = 0.01):
        # Seconds between background samples.
        self.sample_interval = sample_interval
        self.memory_tracer = MemoryTracer()

        # Tracking data
        self.memory_timeline = []   # (elapsed_s, rss_bytes)
        self.cpu_timeline = []      # (elapsed_s, cpu_percent)
        self.io_timeline = []       # (elapsed_s, io counter dict)
        self.function_stats = defaultdict(lambda: {
            'calls': 0,
            'memory': 0,
            'time': 0.0,
            'allocations': []
        })

        self._monitoring = False
        self._monitor_thread = None
        self._start_time = None

    def start_monitoring(self):
        """Start background monitoring thread and memory tracing."""
        self._monitoring = True
        self._start_time = time.time()
        self.memory_tracer.start()

        # Daemon thread so a forgotten stop_monitoring() cannot block exit.
        self._monitor_thread = threading.Thread(target=self._monitor_loop)
        self._monitor_thread.daemon = True
        self._monitor_thread.start()

    def stop_monitoring(self):
        """Stop background monitoring and capture the tracemalloc snapshot."""
        self._monitoring = False
        if self._monitor_thread:
            self._monitor_thread.join()
        self.memory_tracer.stop()

    def _monitor_loop(self):
        """Background loop: sample memory/CPU/IO until _monitoring is cleared."""
        process = psutil.Process()

        while self._monitoring:
            timestamp = time.time() - self._start_time

            # Memory usage
            mem_info = process.memory_info()
            self.memory_timeline.append((timestamp, mem_info.rss))

            # CPU usage (non-blocking; the first sample may read 0.0)
            cpu_percent = process.cpu_percent(interval=None)
            self.cpu_timeline.append((timestamp, cpu_percent))

            # IO counters (if available)
            try:
                io_counters = process.io_counters()
                self.io_timeline.append((timestamp, {
                    'read_bytes': io_counters.read_bytes,
                    'write_bytes': io_counters.write_bytes,
                    'read_count': io_counters.read_count,
                    'write_count': io_counters.write_count
                }))
            except Exception:
                # io_counters() is unsupported on some platforms (e.g. macOS);
                # skip the IO sample but keep monitoring. Was a bare except,
                # which also swallowed KeyboardInterrupt/SystemExit.
                pass

            time.sleep(self.sample_interval)

    def profile(self, func: Callable, *args, **kwargs) -> Tuple[Any, ProfilingReport]:
        """Profile a function execution.

        Returns:
            Tuple of (function result, ProfilingReport). If func raises,
            monitoring is still stopped but no report is generated and the
            exception propagates.
        """
        # Start monitoring
        self.start_monitoring()

        # CPU profiling
        profiler = cProfile.Profile()
        profiler.enable()

        start_time = time.time()

        try:
            # Execute function
            result = func(*args, **kwargs)
        finally:
            # Stop profiling
            end_time = time.time()
            profiler.disable()
            self.stop_monitoring()

        # Generate report
        report = self._generate_report(
            duration=end_time - start_time,
            cpu_profile=profiler
        )

        return result, report

    def _generate_report(self, duration: float, cpu_profile: cProfile.Profile) -> ProfilingReport:
        """Generate comprehensive profiling report."""
        # Get peak memory
        peak_memory = max((m[1] for m in self.memory_timeline), default=0)

        # Analyze components. duration/peak_memory are passed explicitly:
        # _analyze_hotspots needs them for its thresholds (it previously
        # referenced them without their being defined -> NameError).
        hotspots = self._analyze_hotspots(cpu_profile, duration, peak_memory)
        bottlenecks = self._analyze_bottlenecks()
        patterns = self._analyze_access_patterns()
        transitions = self._count_hierarchy_transitions()
        opportunities = self._find_optimization_opportunities(hotspots, bottlenecks)

        # Generate summary
        summary = self._generate_summary(duration, peak_memory, hotspots, bottlenecks)

        return ProfilingReport(
            timestamp=datetime.now().isoformat(),
            duration=duration,
            peak_memory=peak_memory,
            total_allocations=len(self.memory_tracer.allocations),
            memory_timeline=self.memory_timeline,
            cpu_timeline=self.cpu_timeline,
            hotspots=hotspots,
            bottlenecks=bottlenecks,
            access_patterns=patterns,
            hierarchy_transitions=transitions,
            optimization_opportunities=opportunities,
            summary=summary
        )

    def _analyze_hotspots(self, cpu_profile: cProfile.Profile,
                          duration: float, peak_memory: int) -> List[Hotspot]:
        """Identify performance hotspots.

        Args:
            cpu_profile: completed cProfile run
            duration: total profiled wall-clock time in seconds
            peak_memory: peak RSS in bytes observed during the run

        Bug fix: duration and peak_memory were used here without being
        parameters or locals, raising NameError for any non-filtered frame.
        """
        stats = pstats.Stats(cpu_profile)
        stats.sort_stats('cumulative')

        hotspots = []
        top_allocators = self.memory_tracer.get_top_allocators()

        # Create lookup for memory stats
        memory_by_file = {stat['file']: stat for stat in top_allocators}

        # Analyze top functions
        for func_info, (cc, nc, tt, ct, callers) in list(stats.stats.items())[:20]:
            filename, line_number, function_name = func_info

            # Get memory info if available.
            # NOTE(review): tracemalloc keys are formatted traceback lines,
            # so this lookup rarely matches and mem_info is usually empty.
            mem_info = memory_by_file.get(f"(unknown):{line_number}", {})

            # Skip built-in functions and third-party code
            if filename.startswith('<') or 'site-packages' in filename:
                continue

            # Determine access pattern (simplified)
            pattern = AccessPattern.UNKNOWN

            # Generate recommendations
            recommendations = []
            if ct > duration * 0.1:  # More than 10% of time
                recommendations.append("Consider optimizing this function - it's a CPU hotspot")
            if mem_info.get('size', 0) > peak_memory * 0.1:  # More than 10% of memory
                recommendations.append("This function allocates significant memory - consider √n optimization")

            hotspots.append(Hotspot(
                function=function_name,
                file_path=filename,
                line_number=line_number,
                memory_allocated=mem_info.get('size', 0),
                memory_freed=0,  # Not tracked in simple version
                net_memory=mem_info.get('size', 0),
                allocation_count=mem_info.get('count', 0),
                cpu_time=ct,
                access_pattern=pattern,
                recommendations=recommendations
            ))

        return hotspots

    def _analyze_bottlenecks(self) -> List[BottleneckAnalysis]:
        """Analyze performance bottlenecks from the sampled timelines."""
        bottlenecks = []

        # Memory bottleneck analysis: flag large net growth over the run.
        if self.memory_timeline:
            mem_values = [m[1] for m in self.memory_timeline]
            mem_growth = mem_values[-1] - mem_values[0] if len(mem_values) > 1 else 0

            if mem_growth > 100 * 1024 * 1024:  # 100MB growth
                bottlenecks.append(BottleneckAnalysis(
                    type="memory",
                    severity=min(1.0, mem_growth / (1024 * 1024 * 1024)),  # GB scale
                    description=f"Significant memory growth detected: {mem_growth / (1024*1024):.1f}MB",
                    evidence={
                        "start_memory": mem_values[0],
                        "end_memory": mem_values[-1],
                        "growth": mem_growth
                    },
                    recommendations=[
                        "Consider using SpaceTime collections for large datasets",
                        "Implement streaming processing with √n buffering",
                        "Use external sorting/grouping algorithms"
                    ]
                ))

        # CPU bottleneck analysis: flag sustained high utilization.
        if self.cpu_timeline:
            cpu_values = [c[1] for c in self.cpu_timeline]
            avg_cpu = np.mean(cpu_values) if cpu_values else 0

            if avg_cpu > 80:  # 80% CPU usage
                bottlenecks.append(BottleneckAnalysis(
                    type="cpu",
                    severity=min(1.0, avg_cpu / 100),
                    description=f"High CPU usage detected: {avg_cpu:.1f}% average",
                    evidence={
                        "average_cpu": avg_cpu,
                        "peak_cpu": max(cpu_values) if cpu_values else 0
                    },
                    recommendations=[
                        "Profile CPU hotspots for optimization opportunities",
                        "Consider parallel processing with √n chunk size",
                        "Use more efficient algorithms"
                    ]
                ))

        return bottlenecks

    def _analyze_access_patterns(self) -> Dict[str, AccessPattern]:
        """Analyze memory access patterns by function."""
        # Simplified implementation: per-function tracing is not wired up.
        return {"overall": AccessPattern.UNKNOWN}

    def _count_hierarchy_transitions(self) -> Dict[str, int]:
        """Count memory hierarchy transitions.

        Simplified implementation: estimates transitions from peak RSS
        against typical cache-level sizes, not from actual cache events.
        """
        transitions = {
            "L1_to_L2": 0,
            "L2_to_L3": 0,
            "L3_to_RAM": 0,
            "RAM_to_Disk": 0
        }

        # Estimate based on memory growth
        if self.memory_timeline:
            mem_values = [m[1] for m in self.memory_timeline]
            max_mem = max(mem_values) if mem_values else 0

            if max_mem > 32 * 1024:  # > L1
                transitions["L1_to_L2"] += 1
            if max_mem > 256 * 1024:  # > L2
                transitions["L2_to_L3"] += 1
            if max_mem > 8 * 1024 * 1024:  # > L3
                transitions["L3_to_RAM"] += 1
            if max_mem > 1024 * 1024 * 1024:  # > 1GB
                transitions["RAM_to_Disk"] += 1

        return transitions

    def _find_optimization_opportunities(self,
                                         hotspots: List[Hotspot],
                                         bottlenecks: List[BottleneckAnalysis]) -> List[Dict[str, Any]]:
        """Find SpaceTime optimization opportunities."""
        opportunities = []

        # Check for large memory allocations
        for hotspot in hotspots:
            if hotspot.memory_allocated > 10 * 1024 * 1024:  # 10MB
                opportunities.append({
                    "type": "large_allocation",
                    "location": f"{hotspot.file_path}:{hotspot.line_number}",
                    "function": hotspot.function,
                    "memory": hotspot.memory_allocated,
                    "suggestion": "Use SpaceTimeArray or SpaceTimeDict for large collections",
                    "potential_savings": f"{hotspot.memory_allocated * 0.9 / (1024*1024):.1f}MB"
                })

        # Check for memory growth patterns
        memory_bottleneck = next((b for b in bottlenecks if b.type == "memory"), None)
        if memory_bottleneck:
            opportunities.append({
                "type": "memory_growth",
                "severity": memory_bottleneck.severity,
                "suggestion": "Implement streaming processing with Stream class",
                "example": "Stream.from_file('data.csv').map(process).chunk(√n).foreach(save)"
            })

        return opportunities

    def _generate_summary(self, duration: float, peak_memory: int,
                          hotspots: List[Hotspot],
                          bottlenecks: List[BottleneckAnalysis]) -> str:
        """Generate human-readable summary."""
        summary_parts = [
            f"Profile Summary",
            f"===============",
            f"Duration: {duration:.2f}s",
            f"Peak Memory: {peak_memory / (1024*1024):.1f}MB",
            f"Hotspots Found: {len(hotspots)}",
            f"Bottlenecks: {len(bottlenecks)}",
        ]

        if bottlenecks:
            summary_parts.append("\nMain Bottlenecks:")
            for b in bottlenecks[:3]:
                summary_parts.append(f"- {b.type.upper()}: {b.description}")

        if hotspots:
            summary_parts.append("\nTop Hotspots:")
            for h in hotspots[:3]:
                summary_parts.append(f"- {h.function} ({h.cpu_time:.2f}s, {h.memory_allocated/(1024*1024):.1f}MB)")

        # Add SpaceTime recommendation
        if peak_memory > 100 * 1024 * 1024:  # 100MB
            summary_parts.append("\nSpaceTime Optimization Potential: HIGH")
            summary_parts.append("Consider using SpaceTime collections and algorithms for √n memory reduction")

        return "\n".join(summary_parts)
|
||||||
2
src/sqrtspace_spacetime/py.typed
Normal file
2
src/sqrtspace_spacetime/py.typed
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# Marker file for PEP 561
|
||||||
|
# This package supports type hints
|
||||||
27
src/sqrtspace_spacetime/streams/__init__.py
Normal file
27
src/sqrtspace_spacetime/streams/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
"""Streaming operations with √n memory usage."""
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.streams.stream import (
|
||||||
|
Stream,
|
||||||
|
FileStream,
|
||||||
|
CSVStream,
|
||||||
|
JSONLStream,
|
||||||
|
)
|
||||||
|
from sqrtspace_spacetime.streams.operators import (
|
||||||
|
StreamOperator,
|
||||||
|
MapOperator,
|
||||||
|
FilterOperator,
|
||||||
|
FlatMapOperator,
|
||||||
|
ChunkOperator,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Stream",
|
||||||
|
"FileStream",
|
||||||
|
"CSVStream",
|
||||||
|
"JSONLStream",
|
||||||
|
"StreamOperator",
|
||||||
|
"MapOperator",
|
||||||
|
"FilterOperator",
|
||||||
|
"FlatMapOperator",
|
||||||
|
"ChunkOperator",
|
||||||
|
]
|
||||||
169
src/sqrtspace_spacetime/streams/operators.py
Normal file
169
src/sqrtspace_spacetime/streams/operators.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Stream operators for transformation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Iterable, Iterator, List, TypeVar, Optional
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
U = TypeVar('U')
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOperator(ABC):
    """Base class for stream operators.

    An operator is a reusable transformation step: apply() receives the
    upstream iterator and returns the transformed iterator (subclasses
    here implement it lazily as generators).
    """

    @abstractmethod
    def apply(self, iterator: Iterator[T]) -> Iterator[Any]:
        """Apply operator to iterator."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class MapOperator(StreamOperator):
    """Map each element to a new value."""

    def __init__(self, func: Callable[[T], U]):
        self.func = func

    def apply(self, iterator: Iterator[T]) -> Iterator[U]:
        """Lazily yield func(element) for every upstream element."""
        transform = self.func
        for element in iterator:
            yield transform(element)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterOperator(StreamOperator):
    """Filter elements by predicate."""

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Lazily yield only the elements for which the predicate holds."""
        keep = self.predicate
        for candidate in iterator:
            if not keep(candidate):
                continue
            yield candidate
|
||||||
|
|
||||||
|
|
||||||
|
class FlatMapOperator(StreamOperator):
    """Map each element to multiple elements."""

    def __init__(self, func: Callable[[T], Iterable[U]]):
        self.func = func

    def apply(self, iterator: Iterator[T]) -> Iterator[U]:
        """Yield every item of func(element); non-iterables pass through as-is.

        Note: strings are iterable, so a func returning a string yields its
        characters — this mirrors the operator's long-standing behavior.
        """
        for element in iterator:
            mapped = self.func(element)
            if not hasattr(mapped, '__iter__'):
                yield mapped
            else:
                yield from mapped
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkOperator(StreamOperator):
    """Group elements into fixed-size chunks."""

    def __init__(self, size: int):
        # A chunk always holds at least one element.
        self.size = max(1, size)

    def apply(self, iterator: Iterator[T]) -> Iterator[List[T]]:
        """Yield lists of up to `size` elements; the last list may be short."""
        pending: List[T] = []
        for element in iterator:
            pending.append(element)
            if len(pending) >= self.size:
                yield pending
                pending = []

        # Emit whatever remains as the final, possibly short, chunk.
        if pending:
            yield pending
|
||||||
|
|
||||||
|
|
||||||
|
class WindowOperator(StreamOperator):
    """Sliding window over stream."""

    def __init__(self, size: int, slide: int = 1):
        self.size = max(1, size)
        self.slide = max(1, slide)

    def apply(self, iterator: Iterator[T]) -> Iterator[List[T]]:
        """Yield a copy of each full window, then advance by `slide` elements."""
        window: List[T] = []
        for element in iterator:
            window.append(element)

            if len(window) < self.size:
                continue

            yield window[:]
            # Drop the oldest `slide` elements (bounded by current length).
            del window[:min(self.slide, len(window))]
|
||||||
|
|
||||||
|
|
||||||
|
class TakeWhileOperator(StreamOperator):
    """Take elements while predicate is true."""

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Yield elements up to (excluding) the first predicate failure."""
        test = self.predicate
        for element in iterator:
            if not test(element):
                return
            yield element
|
||||||
|
|
||||||
|
|
||||||
|
class DropWhileOperator(StreamOperator):
    """Drop elements while predicate is true.

    Bug fix: the drop state used to live on the instance (self.dropping),
    so a second apply() on the same operator — which happens whenever a
    Stream is iterated more than once, since Stream.__iter__ re-applies
    its operator list — dropped nothing. The state is now local to each
    apply() call, so every iteration drops its own leading prefix.
    """

    def __init__(self, predicate: Callable[[T], bool]):
        self.predicate = predicate
        # Kept for backward compatibility with code that may have inspected
        # this attribute; apply() no longer consults it.
        self.dropping = True

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Yield elements starting with the first one failing the predicate."""
        dropping = True
        for item in iterator:
            if dropping:
                if self.predicate(item):
                    continue
                dropping = False
            yield item
|
||||||
|
|
||||||
|
|
||||||
|
class DistinctOperator(StreamOperator):
    """Remove duplicate elements."""

    def __init__(self, key_func: Optional[Callable[[T], Any]] = None):
        # Identity keying by default: the element itself must be hashable.
        self.key_func = key_func or (lambda x: x)

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Yield each element whose key has not been seen before."""
        observed = set()
        remember = observed.add
        for element in iterator:
            marker = self.key_func(element)
            if marker in observed:
                continue
            remember(marker)
            yield element
|
||||||
|
|
||||||
|
|
||||||
|
class TakeOperator(StreamOperator):
    """Take first n elements."""

    def __init__(self, n: int):
        self.n = n

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Yield at most n leading elements, then stop pulling upstream."""
        remaining = self.n
        for element in iterator:
            if remaining <= 0:
                break
            remaining -= 1
            yield element
|
||||||
|
|
||||||
|
|
||||||
|
class SkipOperator(StreamOperator):
    """Skip first n elements."""

    def __init__(self, n: int):
        self.n = n

    def apply(self, iterator: Iterator[T]) -> Iterator[T]:
        """Discard the first n elements, then yield everything else."""
        discarded = 0
        for element in iterator:
            if discarded < self.n:
                discarded += 1
                continue
            yield element
|
||||||
298
src/sqrtspace_spacetime/streams/stream.py
Normal file
298
src/sqrtspace_spacetime/streams/stream.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
"""
|
||||||
|
Memory-efficient streaming operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (
|
||||||
|
Any, Callable, Dict, Iterable, Iterator, List, Optional,
|
||||||
|
TypeVar, Union, AsyncIterator, Tuple
|
||||||
|
)
|
||||||
|
|
||||||
|
from sqrtspace_spacetime.config import config
|
||||||
|
from sqrtspace_spacetime.streams.operators import (
|
||||||
|
MapOperator, FilterOperator, FlatMapOperator, ChunkOperator,
|
||||||
|
TakeOperator, SkipOperator
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
U = TypeVar('U')
|
||||||
|
|
||||||
|
|
||||||
|
class Stream(Iterable[T]):
|
||||||
|
"""
|
||||||
|
A lazy, memory-efficient stream for processing large datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, source: Union[Iterable[T], Iterator[T], Callable[[], Iterator[T]]]):
|
||||||
|
"""
|
||||||
|
Initialize stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: Data source (iterable, iterator, or callable returning iterator)
|
||||||
|
"""
|
||||||
|
if callable(source):
|
||||||
|
self._source = source
|
||||||
|
elif hasattr(source, '__iter__'):
|
||||||
|
self._source = lambda: iter(source)
|
||||||
|
else:
|
||||||
|
raise TypeError("Source must be iterable or callable")
|
||||||
|
|
||||||
|
self._operators: List[Any] = []
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[T]:
|
||||||
|
"""Create iterator with all operators applied."""
|
||||||
|
iterator = self._source()
|
||||||
|
|
||||||
|
# Apply operators in sequence
|
||||||
|
for op in self._operators:
|
||||||
|
iterator = op.apply(iterator)
|
||||||
|
|
||||||
|
return iterator
|
||||||
|
|
||||||
|
# Transformation operators
|
||||||
|
|
||||||
|
def map(self, func: Callable[[T], U]) -> 'Stream[U]':
|
||||||
|
"""Apply function to each element."""
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(MapOperator(func))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def filter(self, predicate: Callable[[T], bool]) -> 'Stream[T]':
|
||||||
|
"""Keep only elements matching predicate."""
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(FilterOperator(predicate))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def flat_map(self, func: Callable[[T], Iterable[U]]) -> 'Stream[U]':
|
||||||
|
"""Map each element to multiple elements."""
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(FlatMapOperator(func))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def chunk(self, size: Optional[int] = None) -> 'Stream[List[T]]':
|
||||||
|
"""Group elements into chunks."""
|
||||||
|
if size is None:
|
||||||
|
# Use √n chunking
|
||||||
|
# Since we don't know total size, use a reasonable default
|
||||||
|
size = 1000
|
||||||
|
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(ChunkOperator(size))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def take(self, n: int) -> 'Stream[T]':
|
||||||
|
"""Take first n elements."""
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(TakeOperator(n))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def skip(self, n: int) -> 'Stream[T]':
|
||||||
|
"""Skip first n elements."""
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(SkipOperator(n))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
def distinct(self) -> 'Stream[T]':
|
||||||
|
"""Remove duplicate elements."""
|
||||||
|
def distinct_op(iterator):
|
||||||
|
seen = set()
|
||||||
|
for item in iterator:
|
||||||
|
if item not in seen:
|
||||||
|
seen.add(item)
|
||||||
|
yield item
|
||||||
|
|
||||||
|
new_stream = Stream(self._source)
|
||||||
|
new_stream._operators = self._operators.copy()
|
||||||
|
new_stream._operators.append(lambda it: distinct_op(it))
|
||||||
|
return new_stream
|
||||||
|
|
||||||
|
# Terminal operators
|
||||||
|
|
||||||
|
def collect(self) -> List[T]:
|
||||||
|
"""Collect all elements into a list."""
|
||||||
|
return list(self)
|
||||||
|
|
||||||
|
def reduce(self, func: Callable[[U, T], U], initial: U) -> U:
|
||||||
|
"""Reduce stream to single value."""
|
||||||
|
result = initial
|
||||||
|
for item in self:
|
||||||
|
result = func(result, item)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Count elements."""
|
||||||
|
return sum(1 for _ in self)
|
||||||
|
|
||||||
|
def first(self) -> Optional[T]:
    """Return the first element, or ``None`` for an empty stream (terminal)."""
    return next(iter(self), None)
|
||||||
|
|
||||||
|
def foreach(self, func: Callable[[T], None]) -> None:
    """Invoke ``func`` on every element for its side effects (terminal)."""
    for element in self:
        func(element)
|
||||||
|
|
||||||
|
def group_by(self, key_func: Callable[[T], Any]) -> Dict[Any, List[T]]:
    """Group elements by key.

    Terminal operation: consumes the stream and returns a dict mapping
    each key to the list of elements that produced it, delegating to
    the library's external (disk-backed) groupby implementation.
    """
    # Imported lazily inside the method — presumably to avoid a circular
    # import between the stream and algorithms modules; confirm.
    from sqrtspace_spacetime.algorithms import external_groupby
    return external_groupby(self, key_func)
|
||||||
|
|
||||||
|
def sort(self, key: Optional[Callable[[T], Any]] = None, reverse: bool = False) -> List[T]:
    """Sort all elements and return them as a list (terminal).

    Delegates to the library's external (disk-backed) sort; a keyed
    sort goes through ``external_sort_key``, otherwise ``external_sort``.
    """
    # Lazy import, mirroring the other terminal operators.
    from sqrtspace_spacetime.algorithms import external_sort_key, external_sort

    if not key:
        return external_sort(self, reverse=reverse)
    return external_sort_key(self, key=key, reverse=reverse)
|
||||||
|
|
||||||
|
def to_file(self, path: Union[str, Path], mode: str = 'w') -> None:
    """Write the stream to a text file, one ``str(element)`` per line (terminal)."""
    target = Path(path)

    with open(target, mode) as handle:
        for element in self:
            handle.write(str(element) + '\n')
|
||||||
|
|
||||||
|
def to_csv(self, path: Union[str, Path], headers: Optional[List[str]] = None) -> None:
    """Write stream to CSV file.

    Terminal operation. The writer type is chosen from the FIRST
    element: dicts get a csv.DictWriter (fieldnames = ``headers`` or
    the first item's keys), everything else a plain csv.writer. Later
    elements are assumed to match the first element's shape — a mixed
    stream will raise inside the csv module. An empty stream produces
    an empty file (no header is ever written).
    """
    path = Path(path)

    with open(path, 'w', newline='') as f:
        writer = None

        for item in self:
            if writer is None:
                # Initialize writer based on first item
                if isinstance(item, dict):
                    writer = csv.DictWriter(f, fieldnames=headers or item.keys())
                    # NOTE(review): header is skipped only when headers is
                    # falsy AND the first dict is empty — confirm intended.
                    if headers or item:
                        writer.writeheader()
                else:
                    writer = csv.writer(f)
                    if headers:
                        writer.writerow(headers)

            if isinstance(item, dict):
                writer.writerow(item)
            elif isinstance(item, (list, tuple)):
                writer.writerow(item)
            else:
                # Scalar: wrap in a one-element row.
                writer.writerow([item])
|
||||||
|
|
||||||
|
def to_jsonl(self, path: Union[str, Path]) -> None:
    """Write the stream as JSON Lines: one JSON document per line (terminal)."""
    target = Path(path)

    with open(target, 'w') as handle:
        for record in self:
            handle.write(json.dumps(record) + '\n')
|
||||||
|
|
||||||
|
# Async support
|
||||||
|
|
||||||
|
async def async_foreach(self, func: Callable[[T], Any]) -> None:
    """Invoke ``func`` on every element, awaiting it when it is a
    coroutine function and calling it directly otherwise (terminal).

    Note: iteration over the stream itself is synchronous.
    """
    for element in self:
        if not asyncio.iscoroutinefunction(func):
            func(element)
        else:
            await func(element)
|
||||||
|
|
||||||
|
# Factory methods
|
||||||
|
|
||||||
|
@classmethod
def from_iterable(cls, iterable: Iterable[T]) -> 'Stream[T]':
    """Create stream from iterable.

    Note: a one-shot iterator source can only be consumed once.
    """
    return cls(iterable)
|
||||||
|
|
||||||
|
@classmethod
def from_file(cls, path: Union[str, Path], mode: str = 'r') -> 'Stream[str]':
    """Create stream of lines from a text file (see FileStream)."""
    return FileStream(path, mode)
|
||||||
|
|
||||||
|
@classmethod
def from_csv(cls, path: Union[str, Path], headers: bool = True, **kwargs) -> 'Stream[Dict[str, Any]]':
    """Create stream of rows from a CSV file (see CSVStream).

    With ``headers=True`` rows are dicts keyed by the header line;
    otherwise plain lists. ``kwargs`` are passed to the csv reader.
    """
    return CSVStream(path, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
def from_jsonl(cls, path: Union[str, Path]) -> 'Stream[Any]':
    """Create stream of parsed objects from a JSON Lines file (see JSONLStream)."""
    return JSONLStream(path)
|
||||||
|
|
||||||
|
@classmethod
def range(cls, *args) -> 'Stream[int]':
    """Create a stream of integers with ``range(*args)`` semantics.

    The source is a zero-argument callable, so the stream can be
    iterated more than once.
    """
    def make_iterator():
        return iter(range(*args))

    return cls(make_iterator)
|
||||||
|
|
||||||
|
@classmethod
def infinite(cls, func: Callable[[], T]) -> 'Stream[T]':
    """Create an endless stream of successive ``func()`` results.

    Must be bounded by a limiting operator (e.g. ``take``) before any
    terminal operation, or iteration never ends.
    """
    def repeat_forever():
        while True:
            yield func()

    return cls(repeat_forever)
|
||||||
|
|
||||||
|
|
||||||
|
class FileStream(Stream[str]):
    """Stream over the lines of a text file, trailing newlines stripped."""

    def __init__(self, path: Union[str, Path], mode: str = 'r', encoding: str = 'utf-8'):
        self.path = Path(path)
        self.mode = mode
        self.encoding = encoding

        def read_lines():
            # The file is opened only when iteration starts, not at
            # construction time.
            with open(self.path, self.mode, encoding=self.encoding) as handle:
                for raw_line in handle:
                    yield raw_line.rstrip('\n\r')

        super().__init__(read_lines)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVStream(Stream[Dict[str, Any]]):
    """Stream over the rows of a CSV file.

    With ``headers=True`` rows are dicts keyed by the header line;
    otherwise plain lists of strings.
    """

    def __init__(self, path: Union[str, Path], headers: bool = True, **csv_kwargs):
        self.path = Path(path)
        self.headers = headers
        self.csv_kwargs = csv_kwargs

        def read_rows():
            # Opened lazily, at iteration time.
            with open(self.path, 'r', newline='') as handle:
                if not self.headers:
                    row_source = csv.reader(handle, **self.csv_kwargs)
                else:
                    row_source = csv.DictReader(handle, **self.csv_kwargs)

                yield from row_source

        super().__init__(read_rows)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONLStream(Stream[Any]):
    """Stream over a JSON Lines file: one parsed object per non-blank line."""

    def __init__(self, path: Union[str, Path]):
        self.path = Path(path)

        def read_records():
            # Opened lazily, at iteration time; blank lines are skipped.
            with open(self.path, 'r') as handle:
                for raw in handle:
                    stripped = raw.strip()
                    if stripped:
                        yield json.loads(stripped)

        super().__init__(read_records)
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Ubiquity SpaceTime Test Suite
|
||||||
234
tests/test_external_algorithms.py
Normal file
234
tests/test_external_algorithms.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for external algorithms with memory pressure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import random
|
||||||
|
import gc
|
||||||
|
import psutil
|
||||||
|
import time
|
||||||
|
from sqrtspace_spacetime import external_sort, external_groupby, SpaceTimeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalAlgorithms(unittest.TestCase):
    """Test external algorithms under memory constraints.

    These tests print progress/statistics to stdout and measure RSS via
    psutil, so absolute memory numbers are environment-dependent.
    """

    def setUp(self):
        """Set up test environment."""
        SpaceTimeConfig.set_defaults(
            memory_limit=100 * 1024 * 1024, # 100MB limit
            chunk_strategy='sqrt_n'
        )
        self.process = psutil.Process()

    def test_external_sort_small(self):
        """Test external sort with small dataset."""
        data = [random.randint(1, 1000) for _ in range(1000)]
        sorted_data = external_sort(data)

        # Verify sorting
        self.assertEqual(len(sorted_data), len(data))
        for i in range(len(sorted_data) - 1):
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])

        # Verify all elements present
        self.assertEqual(sorted(data), sorted_data)

    def test_external_sort_large_with_memory_tracking(self):
        """Test external sort with large dataset and memory tracking."""
        n = 1_000_000 # 1 million items

        # Generate data
        print(f"\nGenerating {n:,} random integers...")
        data = [random.randint(1, 10_000_000) for _ in range(n)]

        # Track memory before sorting
        gc.collect()
        memory_before = self.process.memory_info().rss / 1024 / 1024
        peak_memory = memory_before

        # Sort with memory tracking
        print("Sorting with external_sort...")
        start_time = time.time()

        # Create a custom monitoring function
        # NOTE(review): monitor_memory is only called AFTER the sort (during
        # verification below), so peak_memory does not capture the sort's own
        # peak usage; memory_samples is collected but never asserted on.
        memory_samples = []
        def monitor_memory():
            current = self.process.memory_info().rss / 1024 / 1024
            memory_samples.append(current)
            return current

        # Sort data
        sorted_data = external_sort(data)

        # Measure final state
        gc.collect()
        memory_after = self.process.memory_info().rss / 1024 / 1024
        elapsed = time.time() - start_time

        # Sample memory during verification
        for i in range(0, len(sorted_data) - 1, 10000):
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])
            if i % 100000 == 0:
                peak_memory = max(peak_memory, monitor_memory())

        # Calculate statistics
        memory_increase = memory_after - memory_before
        theoretical_sqrt_n = int(n ** 0.5)

        print(f"\nExternal Sort Statistics:")
        print(f" Items sorted: {n:,}")
        print(f" Time taken: {elapsed:.2f} seconds")
        print(f" Memory before: {memory_before:.1f} MB")
        print(f" Memory after: {memory_after:.1f} MB")
        print(f" Peak memory: {peak_memory:.1f} MB")
        print(f" Memory increase: {memory_increase:.1f} MB")
        print(f" Theoretical √n: {theoretical_sqrt_n:,} items")
        print(f" Items per MB: {n / max(memory_increase, 0.1):,.0f}")

        # Verify memory efficiency
        # With 1M items, sqrt(n) = 1000, so memory should be much less than full dataset
        self.assertLess(memory_increase, 50, f"Memory increase {memory_increase:.1f} MB is too high")

        # Verify correctness on sample
        sample_indices = random.sample(range(len(sorted_data) - 1), min(1000, len(sorted_data) - 1))
        for i in sample_indices:
            self.assertLessEqual(sorted_data[i], sorted_data[i + 1])

    def test_external_groupby_memory_efficiency(self):
        """Test external groupby with memory tracking."""
        n = 100_000

        # Generate data with limited number of groups
        print(f"\nGenerating {n:,} items for groupby...")
        categories = [f"category_{i}" for i in range(100)]
        data = [
            {
                "id": i,
                "category": random.choice(categories),
                "value": random.randint(1, 1000),
                "data": f"data_{i}" * 10 # Make items larger
            }
            for i in range(n)
        ]

        # Track memory
        gc.collect()
        memory_before = self.process.memory_info().rss / 1024 / 1024

        # Group by category
        print("Grouping by category...")
        start_time = time.time()
        grouped = external_groupby(data, key_func=lambda x: x["category"])
        elapsed = time.time() - start_time

        # Measure memory
        gc.collect()
        memory_after = self.process.memory_info().rss / 1024 / 1024
        memory_increase = memory_after - memory_before

        print(f"\nExternal GroupBy Statistics:")
        print(f" Items grouped: {n:,}")
        print(f" Groups created: {len(grouped)}")
        print(f" Time taken: {elapsed:.2f} seconds")
        print(f" Memory increase: {memory_increase:.1f} MB")
        print(f" Items per MB: {n / max(memory_increase, 0.1):,.0f}")

        # Verify correctness
        self.assertEqual(len(grouped), len(categories))
        total_items = sum(len(group) for group in grouped.values())
        self.assertEqual(total_items, n)

        # Verify grouping
        for category, items in grouped.items():
            for item in items[:10]: # Check first 10 items in each group
                self.assertEqual(item["category"], category)

        # Memory should be reasonable
        self.assertLess(memory_increase, 100, f"Memory increase {memory_increase:.1f} MB is too high")

    def test_stress_test_combined_operations(self):
        """Stress test with combined operations."""
        n = 50_000

        print(f"\nRunning stress test with {n:,} items...")

        # Generate complex data
        data = []
        for i in range(n):
            data.append({
                "id": i,
                "group": f"group_{i % 50}",
                "value": random.randint(1, 1000),
                "score": random.random(),
                "text": f"This is item {i} with some text" * 5
            })

        # Track initial memory
        gc.collect()
        initial_memory = self.process.memory_info().rss / 1024 / 1024

        # Operation 1: Group by
        print(" 1. Grouping data...")
        grouped = external_groupby(data, key_func=lambda x: x["group"])

        # Operation 2: Sort each group
        # (replacing values of existing keys while iterating items() is safe)
        print(" 2. Sorting each group...")
        for group_key, group_items in grouped.items():
            # Sort by value
            sorted_items = external_sort(
                group_items,
                key=lambda x: x["value"]
            )
            grouped[group_key] = sorted_items

        # Operation 3: Extract top items from each group
        print(" 3. Extracting top items...")
        top_items = []
        for group_items in grouped.values():
            # Get top 10 by value
            top_items.extend(group_items[-10:])

        # Operation 4: Final sort
        print(" 4. Final sort of top items...")
        final_sorted = external_sort(
            top_items,
            key=lambda x: x["score"],
            reverse=True
        )

        # Measure final memory
        gc.collect()
        final_memory = self.process.memory_info().rss / 1024 / 1024
        total_memory_increase = final_memory - initial_memory

        print(f"\nStress Test Results:")
        print(f" Initial memory: {initial_memory:.1f} MB")
        print(f" Final memory: {final_memory:.1f} MB")
        print(f" Total increase: {total_memory_increase:.1f} MB")
        print(f" Groups processed: {len(grouped)}")
        print(f" Top items selected: {len(top_items)}")

        # Verify results
        self.assertEqual(len(grouped), 50) # 50 groups
        self.assertEqual(len(top_items), 50 * 10) # Top 10 from each
        self.assertEqual(len(final_sorted), len(top_items))

        # Verify sorting
        for i in range(len(final_sorted) - 1):
            self.assertGreaterEqual(
                final_sorted[i]["score"],
                final_sorted[i + 1]["score"]
            )

        # Memory should still be reasonable after all operations
        self.assertLess(
            total_memory_increase,
            150,
            f"Memory increase {total_memory_increase:.1f} MB is too high"
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||||
309
tests/test_memory_pressure.py
Normal file
309
tests/test_memory_pressure.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Memory pressure tests to verify √n behavior under constrained memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
import resource
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from sqrtspace_spacetime import (
|
||||||
|
SpaceTimeArray, SpaceTimeDict, external_sort,
|
||||||
|
external_groupby, SpaceTimeConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryPressure(unittest.TestCase):
    """Test √n memory behavior under real memory constraints.

    NOTE(review): these tests inspect private attributes of the library
    types (_hot_data, _cold_indices, _cold_keys) — they will break if the
    internal representation changes.
    """

    def setUp(self):
        """Set up test environment."""
        self.temp_dir = tempfile.mkdtemp()
        self.process = psutil.Process()

        # Configure strict memory limits
        SpaceTimeConfig.set_defaults(
            storage_path=self.temp_dir,
            memory_limit=50 * 1024 * 1024, # 50MB limit
            chunk_strategy='sqrt_n',
            compression='gzip'
        )

    def tearDown(self):
        """Clean up test environment."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_array_under_memory_pressure(self):
        """Test SpaceTimeArray behavior when memory is constrained."""
        print("\n=== Testing SpaceTimeArray under memory pressure ===")

        # Create large objects that will force spillover
        large_object_size = 1024 # 1KB per object
        n_objects = 100_000 # Total: ~100MB if all in memory

        array = SpaceTimeArray(threshold='auto')

        # Track metrics
        spillovers = 0
        max_memory = 0
        start_time = time.time()

        # Add objects and monitor memory
        for i in range(n_objects):
            # Create a large object
            obj = {
                'id': i,
                'data': 'x' * large_object_size,
                'timestamp': time.time()
            }
            array.append(obj)

            # Monitor every 1000 items
            if i % 1000 == 0:
                gc.collect()
                current_memory = self.process.memory_info().rss / 1024 / 1024
                max_memory = max(max_memory, current_memory)

                if i > 0:
                    hot_count = len(array._hot_data)
                    cold_count = len(array._cold_indices)
                    print(f" Items: {i:,} | Memory: {current_memory:.1f}MB | "
                          f"Hot: {hot_count} | Cold: {cold_count}")

                    # Check if spillover is happening
                    if cold_count > spillovers:
                        spillovers = cold_count

        elapsed = time.time() - start_time

        # Verify all data is accessible
        print("\nVerifying data accessibility...")
        sample_indices = random.sample(range(n_objects), min(100, n_objects))
        for idx in sample_indices:
            obj = array[idx]
            self.assertEqual(obj['id'], idx)
            self.assertEqual(len(obj['data']), large_object_size)

        # Calculate statistics
        theoretical_sqrt_n = int(n_objects ** 0.5)
        actual_hot_items = len(array._hot_data)

        print(f"\nResults:")
        print(f" Total items: {n_objects:,}")
        print(f" Time taken: {elapsed:.2f} seconds")
        print(f" Max memory used: {max_memory:.1f} MB")
        print(f" Theoretical √n: {theoretical_sqrt_n:,}")
        print(f" Actual hot items: {actual_hot_items:,}")
        print(f" Cold items: {len(array._cold_indices):,}")
        print(f" Memory efficiency: {n_objects / max_memory:.0f} items/MB")

        # Assertions
        self.assertEqual(len(array), n_objects)
        self.assertLess(max_memory, 150) # Should use much less than 100MB
        self.assertGreater(spillovers, 0) # Should have spilled to disk
        self.assertLessEqual(actual_hot_items, theoretical_sqrt_n * 2) # Within 2x of √n

    def test_dict_with_memory_limit(self):
        """Test SpaceTimeDict with strict memory limit."""
        print("\n=== Testing SpaceTimeDict under memory pressure ===")

        # Create dictionary with explicit threshold
        cache = SpaceTimeDict(threshold=1000) # Keep only 1000 items in memory

        n_items = 50_000
        value_size = 500 # 500 bytes per value

        # Track evictions
        evictions = 0
        start_time = time.time()

        # Add items
        for i in range(n_items):
            key = f"key_{i:06d}"
            value = {
                'id': i,
                'data': 'v' * value_size,
                'accessed': 0
            }
            cache[key] = value

            # Check for evictions
            if i % 1000 == 0 and i > 0:
                current_hot = len(cache._hot_data)
                current_cold = len(cache._cold_keys)
                if current_cold > evictions:
                    evictions = current_cold
                    print(f" Items: {i:,} | Hot: {current_hot} | Cold: {current_cold}")

        elapsed = time.time() - start_time

        # Test access patterns (LRU behavior)
        print("\nTesting LRU behavior...")
        # Access some old items
        for i in range(0, 100, 10):
            key = f"key_{i:06d}"
            value = cache[key]
            value['accessed'] += 1

        # Add more items to trigger eviction
        for i in range(n_items, n_items + 1000):
            cache[f"key_{i:06d}"] = {'id': i, 'data': 'x' * value_size}

        # Recent items should still be hot
        stats = cache.get_stats()

        print(f"\nResults:")
        print(f" Total items: {len(cache):,}")
        print(f" Time taken: {elapsed:.2f} seconds")
        print(f" Hot items: {len(cache._hot_data)}")
        print(f" Cold items: {len(cache._cold_keys)}")
        print(f" Stats: {stats}")

        # Verify all items accessible
        sample_keys = random.sample([f"key_{i:06d}" for i in range(n_items)], 100)
        for key in sample_keys:
            self.assertIn(key, cache)
            value = cache[key]
            self.assertIsNotNone(value)

    def test_algorithm_memory_scaling(self):
        """Test that algorithms scale with √n memory usage."""
        print("\n=== Testing algorithm memory scaling ===")

        datasets = [10_000, 40_000, 90_000, 160_000] # n, 4n, 9n, 16n
        results = []

        for n in datasets:
            print(f"\nTesting with n = {n:,}")

            # Generate data
            data = [random.randint(1, 1_000_000) for _ in range(n)]

            # Measure memory for sorting
            gc.collect()
            mem_before = self.process.memory_info().rss / 1024 / 1024

            sorted_data = external_sort(data)

            gc.collect()
            mem_after = self.process.memory_info().rss / 1024 / 1024
            mem_used = mem_after - mem_before

            # Verify correctness
            self.assertEqual(len(sorted_data), n)
            for i in range(min(1000, len(sorted_data) - 1)):
                self.assertLessEqual(sorted_data[i], sorted_data[i + 1])

            sqrt_n = int(n ** 0.5)
            results.append({
                'n': n,
                'sqrt_n': sqrt_n,
                'memory_used': mem_used,
                'ratio': mem_used / max(sqrt_n * 8 / 1024 / 1024, 0.001) # 8 bytes per int
            })

            print(f" √n = {sqrt_n:,}")
            print(f" Memory used: {mem_used:.2f} MB")
            print(f" Ratio to theoretical: {results[-1]['ratio']:.2f}x")

        # Verify √n scaling
        print("\nScaling Analysis:")
        print("n | √n | Memory (MB) | Ratio")
        print("---------|---------|-------------|-------")
        for r in results:
            print(f"{r['n']:8,} | {r['sqrt_n']:7,} | {r['memory_used']:11.2f} | {r['ratio']:6.2f}x")

        # Memory should scale roughly with √n
        # As n increases 4x, memory should increase ~2x
        for i in range(1, len(results)):
            n_ratio = results[i]['n'] / results[i-1]['n']
            mem_ratio = results[i]['memory_used'] / max(results[i-1]['memory_used'], 0.1)
            expected_ratio = n_ratio ** 0.5

            print(f"\nn increased {n_ratio:.1f}x, memory increased {mem_ratio:.1f}x "
                  f"(expected ~{expected_ratio:.1f}x)")

            # Allow some variance due to overheads
            self.assertLess(mem_ratio, expected_ratio * 3,
                            f"Memory scaling worse than √n: {mem_ratio:.1f}x vs {expected_ratio:.1f}x")

    def test_concurrent_memory_pressure(self):
        """Test behavior under concurrent access with memory pressure."""
        print("\n=== Testing concurrent access under memory pressure ===")

        import threading
        import queue

        array = SpaceTimeArray(threshold=500)
        errors = queue.Queue()
        n_threads = 4
        items_per_thread = 25_000

        def worker(thread_id, start_idx):
            # Appends its slice of items; errors are reported via the queue
            # so the main thread can assert on them.
            try:
                for i in range(items_per_thread):
                    item = {
                        'thread': thread_id,
                        'index': start_idx + i,
                        'data': f"thread_{thread_id}_item_{i}" * 50
                    }
                    array.append(item)

                    # Occasionally read random items
                    if i % 100 == 0 and len(array) > 10:
                        idx = random.randint(0, len(array) - 1)
                        _ = array[idx]
            except Exception as e:
                errors.put((thread_id, str(e)))

        # Start threads
        threads = []
        start_time = time.time()

        for i in range(n_threads):
            t = threading.Thread(
                target=worker,
                args=(i, i * items_per_thread)
            )
            threads.append(t)
            t.start()

        # Monitor memory while threads run
        max_memory = 0
        while any(t.is_alive() for t in threads):
            current_memory = self.process.memory_info().rss / 1024 / 1024
            max_memory = max(max_memory, current_memory)
            time.sleep(0.1)

        # Wait for completion
        for t in threads:
            t.join()

        elapsed = time.time() - start_time

        # Check for errors
        error_list = []
        while not errors.empty():
            error_list.append(errors.get())

        print(f"\nResults:")
        print(f" Threads: {n_threads}")
        print(f" Total items: {n_threads * items_per_thread:,}")
        print(f" Time taken: {elapsed:.2f} seconds")
        print(f" Max memory: {max_memory:.1f} MB")
        print(f" Errors: {len(error_list)}")
        print(f" Final array size: {len(array):,}")

        # Assertions
        self.assertEqual(len(error_list), 0, f"Thread errors: {error_list}")
        self.assertEqual(len(array), n_threads * items_per_thread)
        self.assertLess(max_memory, 200) # Should handle memory pressure
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||||
202
tests/test_spacetime_array.py
Normal file
202
tests/test_spacetime_array.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for SpaceTimeArray with memory pressure simulation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import psutil
|
||||||
|
from sqrtspace_spacetime import SpaceTimeArray, SpaceTimeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpaceTimeArray(unittest.TestCase):
|
||||||
|
"""Test SpaceTimeArray functionality."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test environment."""
|
||||||
|
self.temp_dir = tempfile.mkdtemp()
|
||||||
|
SpaceTimeConfig.set_defaults(
|
||||||
|
storage_path=self.temp_dir,
|
||||||
|
memory_limit=50 * 1024 * 1024, # 50MB for testing
|
||||||
|
chunk_strategy='sqrt_n'
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up test environment."""
|
||||||
|
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
def test_basic_operations(self):
|
||||||
|
"""Test basic array operations."""
|
||||||
|
array = SpaceTimeArray(threshold=100)
|
||||||
|
|
||||||
|
# Test append
|
||||||
|
for i in range(50):
|
||||||
|
array.append(f"item_{i}")
|
||||||
|
|
||||||
|
self.assertEqual(len(array), 50)
|
||||||
|
self.assertEqual(array[0], "item_0")
|
||||||
|
self.assertEqual(array[49], "item_49")
|
||||||
|
|
||||||
|
# Test negative indexing
|
||||||
|
self.assertEqual(array[-1], "item_49")
|
||||||
|
self.assertEqual(array[-50], "item_0")
|
||||||
|
|
||||||
|
# Test slice
|
||||||
|
slice_result = array[10:20]
|
||||||
|
self.assertEqual(len(slice_result), 10)
|
||||||
|
self.assertEqual(slice_result[0], "item_10")
|
||||||
|
|
||||||
|
def test_automatic_spillover(self):
|
||||||
|
"""Test automatic spillover to disk."""
|
||||||
|
# Create array with small threshold
|
||||||
|
array = SpaceTimeArray(threshold=10)
|
||||||
|
|
||||||
|
# Add more items than threshold
|
||||||
|
for i in range(100):
|
||||||
|
array.append(f"value_{i}")
|
||||||
|
|
||||||
|
# Check that spillover happened
|
||||||
|
self.assertEqual(len(array), 100)
|
||||||
|
self.assertGreater(len(array._cold_indices), 0)
|
||||||
|
self.assertLessEqual(len(array._hot_data), array.threshold)
|
||||||
|
|
||||||
|
# Verify all items are accessible
|
||||||
|
for i in range(100):
|
||||||
|
self.assertEqual(array[i], f"value_{i}")
|
||||||
|
|
||||||
|
def test_memory_pressure_handling(self):
|
||||||
|
"""Test behavior under memory pressure."""
|
||||||
|
# Create array with auto threshold
|
||||||
|
array = SpaceTimeArray()
|
||||||
|
|
||||||
|
# Generate large data items
|
||||||
|
large_item = "x" * 10000 # 10KB string
|
||||||
|
|
||||||
|
# Add items until memory pressure detected
|
||||||
|
for i in range(1000):
|
||||||
|
array.append(f"{large_item}_{i}")
|
||||||
|
|
||||||
|
# Check memory usage periodically
|
||||||
|
if i % 100 == 0:
|
||||||
|
process = psutil.Process()
|
||||||
|
memory_mb = process.memory_info().rss / 1024 / 1024
|
||||||
|
# Ensure we're not using excessive memory
|
||||||
|
self.assertLess(memory_mb, 200, f"Memory usage too high at iteration {i}")
|
||||||
|
|
||||||
|
# Verify all items still accessible
|
||||||
|
self.assertEqual(len(array), 1000)
|
||||||
|
self.assertTrue(array[0].endswith("_0"))
|
||||||
|
self.assertTrue(array[999].endswith("_999"))
|
||||||
|
|
||||||
|
def test_large_dataset_sqrt_n_memory(self):
|
||||||
|
"""Test √n memory usage with large dataset."""
|
||||||
|
# Configure for sqrt_n strategy
|
||||||
|
SpaceTimeConfig.set_defaults(chunk_strategy='sqrt_n')
|
||||||
|
|
||||||
|
n = 10000 # Total items
|
||||||
|
sqrt_n = int(n ** 0.5) # Expected memory items
|
||||||
|
|
||||||
|
array = SpaceTimeArray()
|
||||||
|
|
||||||
|
# Track initial memory
|
||||||
|
gc.collect()
|
||||||
|
process = psutil.Process()
|
||||||
|
initial_memory = process.memory_info().rss
|
||||||
|
|
||||||
|
# Add n items
|
||||||
|
for i in range(n):
|
||||||
|
array.append({"id": i, "data": f"item_{i}" * 10})
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Check memory usage
|
||||||
|
final_memory = process.memory_info().rss
|
||||||
|
memory_increase_mb = (final_memory - initial_memory) / 1024 / 1024
|
||||||
|
|
||||||
|
# Verify sqrt_n behavior
|
||||||
|
self.assertEqual(len(array), n)
|
||||||
|
self.assertLessEqual(len(array._hot_data), sqrt_n * 2) # Allow some buffer
|
||||||
|
self.assertGreater(len(array._cold_indices), n - sqrt_n * 2)
|
||||||
|
|
||||||
|
# Memory should be much less than storing all items
|
||||||
|
# Rough estimate: each item ~100 bytes, so n items = ~1MB
|
||||||
|
# With sqrt_n, should use ~10KB in memory
|
||||||
|
self.assertLess(memory_increase_mb, 10, f"Memory increase {memory_increase_mb}MB is too high")
|
||||||
|
|
||||||
|
# Verify random access still works
|
||||||
|
import random
|
||||||
|
for _ in range(100):
|
||||||
|
idx = random.randint(0, n - 1)
|
||||||
|
self.assertEqual(array[idx]["id"], idx)
|
||||||
|
|
||||||
|
def test_persistence_across_sessions(self):
    """Verify that spilled data survives recreation of the array.

    Fills an array backed by an on-disk path, forces spillover, drops
    the instance, then reopens the same path and checks every element
    is still readable.
    """
    path = os.path.join(self.temp_dir, "persist_test")
    item_count = 50

    # Build and fill the first array instance.
    original = SpaceTimeArray(threshold=10, storage_path=path)
    for i in range(item_count):
        original.append(f"persistent_{i}")

    # Push hot items to disk, then discard the instance entirely.
    original._check_and_spill()
    del original

    # A fresh instance pointed at the same path must see all the data.
    reloaded = SpaceTimeArray(threshold=10, storage_path=path)
    self.assertEqual(len(reloaded), item_count)
    for i in range(item_count):
        self.assertEqual(reloaded[i], f"persistent_{i}")
|
def test_concurrent_access(self):
    """Exercise the array from several writer and reader threads at once."""
    import threading

    array = SpaceTimeArray(threshold=100)
    errors = []

    def writer(start, count):
        # Append a disjoint range of items; record any failure.
        try:
            for i in range(start, start + count):
                array.append(f"thread_{i}")
        except Exception as exc:
            errors.append(exc)

    def reader(count):
        # Repeatedly touch the first element while writers run.
        try:
            for _ in range(count):
                if len(array) > 0:
                    _ = array[0]  # Just access, don't verify
        except Exception as exc:
            errors.append(exc)

    # Five writers covering disjoint index ranges, plus three readers.
    workers = [
        threading.Thread(target=writer, args=(i * 100, 100))
        for i in range(5)
    ]
    workers.extend(
        threading.Thread(target=reader, args=(50,)) for _ in range(3)
    )

    # Start everything, then wait for completion.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # No thread may have raised, and every append must have landed.
    self.assertEqual(len(errors), 0, f"Thread errors: {errors}")
    self.assertEqual(len(array), 500)
|
# Allow running this test module directly (e.g. `python <this file>.py`).
if __name__ == "__main__":
    unittest.main()
Loading…
Reference in New Issue
Block a user