#!/usr/bin/env python3
"""
LLM Space-Time Tradeoff Experiments using Ollama
Demonstrates real-world space-time tradeoffs in LLM inference:
1. Context window chunking (√n chunks)
2. Streaming vs full generation
3. Checkpointing for long generations
"""
import json
import time
import psutil
import requests
import numpy as np
from typing import List, Dict, Tuple
import argparse
import sys
import os

# Ollama API endpoint
OLLAMA_API = "http://localhost:11434/api"


def get_process_memory():
    """Get current process memory usage in MB"""
    return psutil.Process().memory_info().rss / 1024 / 1024


def generate_with_ollama(model: str, prompt: str, stream: bool = False) -> Tuple[str, float]:
    """Generate text using Ollama API"""
    url = f"{OLLAMA_API}/generate"
    data = {
        "model": model,
        "prompt": prompt,
        "stream": stream
    }
    start_time = time.time()
    response = requests.post(url, json=data, stream=stream)
    if stream:
        full_response = ""
        for line in response.iter_lines():
            if line:
                chunk = json.loads(line)
                if "response" in chunk:
                    full_response += chunk["response"]
        result = full_response
    else:
        result = response.json()["response"]
    elapsed = time.time() - start_time
    return result, elapsed
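
# Example call (assumes a local Ollama server and an already-pulled model):
#   text, secs = generate_with_ollama("llama3.2:latest", "Say hello", stream=True)
#   print(f"{secs:.1f}s, {len(text)} chars")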


def chunked_context_processing(model: str, long_text: str, chunk_size: int) -> Dict:
    """Process long context in chunks vs all at once"""
    print("\n=== Chunked Context Processing ===")
    print(f"Total context length: {len(long_text)} chars")
    print(f"Chunk size: {chunk_size} chars")
    results = {}

    # Method 1: Process entire context at once
    print("\nMethod 1: Full context (O(n) memory)")
    prompt_full = f"Summarize the following text:\n\n{long_text}\n\nSummary:"
    mem_before = get_process_memory()
    summary_full, time_full = generate_with_ollama(model, prompt_full)
    mem_after = get_process_memory()
    results["full_context"] = {
        "time": time_full,
        "memory_delta": mem_after - mem_before,
        "summary_length": len(summary_full)
    }
    print(f"Time: {time_full:.2f}s, Memory delta: {mem_after - mem_before:.2f}MB")

    # Method 2: Process in √n chunks
    print("\nMethod 2: Chunked processing (O(√n) memory)")
    chunks = [long_text[i:i+chunk_size] for i in range(0, len(long_text), chunk_size)]
    chunk_summaries = []
    mem_before = get_process_memory()
    time_start = time.time()
    for i, chunk in enumerate(chunks):
        prompt_chunk = f"Summarize this text fragment:\n\n{chunk}\n\nSummary:"
        summary, _ = generate_with_ollama(model, prompt_chunk)
        chunk_summaries.append(summary)
        print(f" Processed chunk {i+1}/{len(chunks)}")

    # Combine chunk summaries
    combined_prompt = "Combine these summaries into one:\n\n" + "\n\n".join(chunk_summaries) + "\n\nCombined summary:"
    final_summary, _ = generate_with_ollama(model, combined_prompt)
    time_chunked = time.time() - time_start
    mem_after = get_process_memory()
    results["chunked_context"] = {
        "time": time_chunked,
        "memory_delta": mem_after - mem_before,
        "summary_length": len(final_summary),
        "num_chunks": len(chunks),
        "chunk_size": chunk_size
    }
    print(f"Time: {time_chunked:.2f}s, Memory delta: {mem_after - mem_before:.2f}MB")
    print(f"Slowdown: {time_chunked/time_full:.2f}x")
    return results
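
# Cost of the chunked path above: with chunk_size = √n the loop issues ~√n
# summarization calls plus one combine call instead of a single full-context
# call, so total time grows roughly with the number of chunks while each
# individual request only carries O(√n) characters of context.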


def streaming_vs_full_generation(model: str, prompt: str, num_tokens: int = 200) -> Dict:
    """Compare streaming vs full generation"""
    print("\n=== Streaming vs Full Generation ===")
    print(f"Generating ~{num_tokens} tokens")
    results = {}

    # Create a prompt that generates substantial output
    generation_prompt = prompt + "\n\nWrite a detailed explanation (at least 200 words):"

    # Method 1: Full generation (O(n) memory for response)
    print("\nMethod 1: Full generation")
    mem_before = get_process_memory()
    response_full, time_full = generate_with_ollama(model, generation_prompt, stream=False)
    mem_after = get_process_memory()
    results["full_generation"] = {
        "time": time_full,
        "memory_delta": mem_after - mem_before,
        "response_length": len(response_full),
        "estimated_tokens": len(response_full.split())
    }
    print(f"Time: {time_full:.2f}s, Memory delta: {mem_after - mem_before:.2f}MB")

    # Method 2: Streaming generation (O(1) memory)
    print("\nMethod 2: Streaming generation")
    mem_before = get_process_memory()
    response_stream, time_stream = generate_with_ollama(model, generation_prompt, stream=True)
    mem_after = get_process_memory()
    results["streaming_generation"] = {
        "time": time_stream,
        "memory_delta": mem_after - mem_before,
        "response_length": len(response_stream),
        "estimated_tokens": len(response_stream.split())
    }
    print(f"Time: {time_stream:.2f}s, Memory delta: {mem_after - mem_before:.2f}MB")
    return results
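
# Measurement caveat: generate_with_ollama() still concatenates every streamed
# chunk into one string, so the memory delta reported for "streaming" above is
# not truly O(1); it mainly reflects the consumer pattern. A genuinely
# constant-memory consumer would handle each chunk and drop it, e.g. (sketch,
# same Ollama /api/generate assumptions as above):
#
#   with requests.post(f"{OLLAMA_API}/generate",
#                      json={"model": model, "prompt": prompt, "stream": True},
#                      stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               print(json.loads(line).get("response", ""), end="", flush=True)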


def checkpointed_generation(model: str, prompts: List[str], checkpoint_interval: int) -> Dict:
    """Simulate checkpointed generation for multiple prompts"""
    print("\n=== Checkpointed Generation ===")
    print(f"Processing {len(prompts)} prompts")
    print(f"Checkpoint interval: {checkpoint_interval}")
    results = {}

    # Method 1: Process all prompts without checkpointing
    print("\nMethod 1: No checkpointing")
    responses_full = []
    mem_before = get_process_memory()
    time_start = time.time()
    for i, prompt in enumerate(prompts):
        response, _ = generate_with_ollama(model, prompt)
        responses_full.append(response)
        print(f" Processed prompt {i+1}/{len(prompts)}")
    time_full = time.time() - time_start
    mem_after = get_process_memory()
    results["no_checkpoint"] = {
        "time": time_full,
        "memory_delta": mem_after - mem_before,
        "total_responses": len(responses_full),
        "avg_response_length": np.mean([len(r) for r in responses_full])
    }

    # Method 2: Process with checkpointing (simulate by clearing responses)
    print(f"\nMethod 2: Checkpointing every {checkpoint_interval} prompts")
    responses_checkpoint = []
    checkpoint_data = []
    mem_before = get_process_memory()
    time_start = time.time()
    for i, prompt in enumerate(prompts):
        response, _ = generate_with_ollama(model, prompt)
        responses_checkpoint.append(response)
        # Simulate checkpoint: save and clear memory
        if (i + 1) % checkpoint_interval == 0:
            checkpoint_data.append({
                "index": i,
                "responses": responses_checkpoint.copy()
            })
            responses_checkpoint = []  # Clear to save memory
            print(f" Checkpoint at prompt {i+1}")
        else:
            print(f" Processed prompt {i+1}/{len(prompts)}")

    # Final checkpoint for any remaining responses
    if responses_checkpoint:
        checkpoint_data.append({
            "index": len(prompts) - 1,
            "responses": responses_checkpoint
        })
    time_checkpoint = time.time() - time_start
    mem_after = get_process_memory()

    # Reconstruct all responses from checkpoints
    all_responses = []
    for checkpoint in checkpoint_data:
        all_responses.extend(checkpoint["responses"])
    results["with_checkpoint"] = {
        "time": time_checkpoint,
        "memory_delta": mem_after - mem_before,
        "total_responses": len(all_responses),
        "avg_response_length": np.mean([len(r) for r in all_responses]),
        "num_checkpoints": len(checkpoint_data),
        "checkpoint_interval": checkpoint_interval
    }

    print("\nTime comparison:")
    print(f" No checkpoint: {time_full:.2f}s")
    print(f" With checkpoint: {time_checkpoint:.2f}s")
    print(f" Overhead: {(time_checkpoint/time_full - 1)*100:.1f}%")
    return results
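
# In this simulation the "checkpoints" stay in checkpoint_data, so the process
# still ends up holding every response; the point is that responses_checkpoint
# never grows past checkpoint_interval entries between checkpoints. A real
# pipeline would persist each checkpoint instead of keeping it, e.g.
# (hypothetical file names, sketch only):
#   with open(f"checkpoint_{len(checkpoint_data)}.json", "w") as f:
#       json.dump(checkpoint_data[-1], f)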


def run_all_experiments(model: str = "llama3.2:latest"):
    """Run all space-time tradeoff experiments"""
    print(f"Using model: {model}")

    # Check that Ollama is reachable and the model is available
    try:
        test_response = requests.post(f"{OLLAMA_API}/generate",
                                      json={"model": model, "prompt": "test", "stream": False})
        if test_response.status_code != 200:
            print(f"Error: Model {model} not available. Please pull it first with: ollama pull {model}")
            return
    except requests.exceptions.RequestException:
        print("Error: Cannot connect to Ollama. Make sure it's running with: ollama serve")
        return

    all_results = {
        "model": model,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "experiments": {}
    }

    # Experiment 1: Context chunking
    # Create a long text by repeating a passage
    base_text = """The quick brown fox jumps over the lazy dog. This pangram contains every letter of the alphabet.
It has been used for decades to test typewriters and computer keyboards. The sentence is memorable and
helps identify any malfunctioning keys. Many variations exist in different languages."""
    long_text = (base_text + " ") * 50  # ~15KB of text
    chunk_size = int(np.sqrt(len(long_text)))  # √n chunk size
    context_results = chunked_context_processing(model, long_text, chunk_size)
    all_results["experiments"]["context_chunking"] = context_results

    # Experiment 2: Streaming vs full generation
    prompt = "Explain the concept of space-time tradeoffs in computer science."
    streaming_results = streaming_vs_full_generation(model, prompt)
    all_results["experiments"]["streaming"] = streaming_results

    # Experiment 3: Checkpointed generation
    prompts = [
        "What is machine learning?",
        "Explain neural networks.",
        "What is deep learning?",
        "Describe transformer models.",
        "What is attention mechanism?",
        "Explain BERT architecture.",
        "What is GPT?",
        "Describe fine-tuning.",
        "What is transfer learning?",
        "Explain few-shot learning."
    ]
    checkpoint_interval = int(np.sqrt(len(prompts)))  # √n checkpoint interval
    checkpoint_results = checkpointed_generation(model, prompts, checkpoint_interval)
    all_results["experiments"]["checkpointing"] = checkpoint_results

    # Save results
    with open("ollama_experiment_results.json", "w") as f:
        json.dump(all_results, f, indent=2)

    print("\n=== Summary ===")
    print("Results saved to ollama_experiment_results.json")

    # Print summary
    print("\n1. Context Chunking:")
    if "context_chunking" in all_results["experiments"]:
        full = all_results["experiments"]["context_chunking"]["full_context"]
        chunked = all_results["experiments"]["context_chunking"]["chunked_context"]
        print(f" Full context: {full['time']:.2f}s, {full['memory_delta']:.2f}MB")
        print(f" Chunked (√n): {chunked['time']:.2f}s, {chunked['memory_delta']:.2f}MB")
        print(f" Slowdown: {chunked['time']/full['time']:.2f}x")
        print(f" Memory reduction: {(1 - chunked['memory_delta']/max(full['memory_delta'], 0.1))*100:.1f}%")

    print("\n2. Streaming Generation:")
    if "streaming" in all_results["experiments"]:
        full = all_results["experiments"]["streaming"]["full_generation"]
        stream = all_results["experiments"]["streaming"]["streaming_generation"]
        print(f" Full generation: {full['time']:.2f}s, {full['memory_delta']:.2f}MB")
        print(f" Streaming: {stream['time']:.2f}s, {stream['memory_delta']:.2f}MB")

    print("\n3. Checkpointing:")
    if "checkpointing" in all_results["experiments"]:
        no_ckpt = all_results["experiments"]["checkpointing"]["no_checkpoint"]
        with_ckpt = all_results["experiments"]["checkpointing"]["with_checkpoint"]
        print(f" No checkpoint: {no_ckpt['time']:.2f}s, {no_ckpt['memory_delta']:.2f}MB")
        print(f" With checkpoint: {with_ckpt['time']:.2f}s, {with_ckpt['memory_delta']:.2f}MB")
        print(f" Time overhead: {(with_ckpt['time']/no_ckpt['time'] - 1)*100:.1f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLM Space-Time Tradeoff Experiments")
    parser.add_argument("--model", default="llama3.2:latest", help="Ollama model to use")
    parser.add_argument("--experiment", choices=["all", "context", "streaming", "checkpoint"],
                        default="all", help="Which experiment to run")
    args = parser.parse_args()

    if args.experiment == "all":
        run_all_experiments(args.model)
    else:
        print(f"Running {args.experiment} experiment with {args.model}")
        # Run specific experiment
        if args.experiment == "context":
            base_text = "The quick brown fox jumps over the lazy dog. " * 100
            results = chunked_context_processing(args.model, base_text, int(np.sqrt(len(base_text))))
        elif args.experiment == "streaming":
            results = streaming_vs_full_generation(args.model, "Explain AI in detail.")
        elif args.experiment == "checkpoint":
            prompts = [f"Explain concept {i}" for i in range(10)]
            results = checkpointed_generation(args.model, prompts, 3)
        print(f"\nResults: {json.dumps(results, indent=2)}")