"""
|
|
LLM KV-Cache Space-Time Tradeoff Experiment
|
|
|
|
Demonstrates how KV-cache size affects transformer inference time,
|
|
showing Williams' √n pattern in modern AI systems.
|
|
|
|
This simulates the core attention mechanism where:
|
|
- Full KV-cache (O(n)): Store all past tokens' keys/values
|
|
- Sliding window (O(√n)): Keep only recent √n tokens
|
|
- Minimal cache (O(1)): Recompute everything
|
|
|
|
Based on Flash Attention and similar optimizations used in production LLMs.
|
|
"""
|
|
|
|
import numpy as np
|
|
import time
|
|
import matplotlib.pyplot as plt
|
|
from typing import Dict, List, Tuple
|
|
import json
|
|
from dataclasses import dataclass
|
|
|
|


@dataclass
class AttentionConfig:
    """Configuration for attention mechanism"""
    seq_length: int       # Total sequence length
    hidden_dim: int       # Model dimension (d_model)
    num_heads: int        # Number of attention heads
    head_dim: int         # Dimension per head
    batch_size: int = 1   # Batch size

    def __post_init__(self):
        assert self.hidden_dim == self.num_heads * self.head_dim
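

# Illustrative sketch (added for exposition, not used by the experiment below):
# a rough estimate of KV-cache memory for a given number of cached tokens,
# assuming one K and one V vector per token and 4 bytes (fp32) per value. The
# helper name and the fp32 assumption are ours, not part of the original code.
def estimate_kv_cache_bytes(config: AttentionConfig, cached_tokens: int,
                            bytes_per_value: int = 4) -> int:
    """Approximate bytes needed to hold K and V for `cached_tokens` positions."""
    return 2 * cached_tokens * config.hidden_dim * bytes_per_value

# For the GPT-2-small-like config used below (hidden_dim=768, n=2048):
#   full O(n) cache:      2048 tokens -> ~12.0 MiB
#   sliding O(√n) cache:  ~181 tokens -> ~1.1 MiB   (int(4 * sqrt(2048)) = 181)
#   minimal O(1) cache:   8 tokens    -> ~0.05 MiB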


class TransformerAttention:
    """Simplified transformer attention with configurable KV-cache"""

    def __init__(self, config: AttentionConfig):
        self.config = config

        # Initialize projection weights (random; this is a timing simulation,
        # not a trained model)
        self.W_q = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02
        self.W_k = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02
        self.W_v = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02
        self.W_o = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02

    def compute_attention(self,
                          query_pos: int,
                          hidden_states: np.ndarray,
                          kv_cache_size: int) -> Tuple[np.ndarray, Dict]:
        """
        Compute attention for position query_pos with a limited KV-cache.

        Args:
            query_pos: Current token position
            hidden_states: All hidden states up to query_pos
            kv_cache_size: Maximum number of past tokens to cache

        Returns:
            attention_output: Output for the query position
            stats: Performance statistics
        """
        stats = {
            'cache_size': kv_cache_size,
            'recompute_steps': 0,
            'cache_hits': 0,
            'memory_used': 0,
        }

        # Get query vector for the current position
        query = hidden_states[query_pos:query_pos + 1]  # [1, hidden_dim]
        Q = query @ self.W_q                            # [1, hidden_dim]

        # Reshape for multi-head attention: [1, num_heads, head_dim]
        Q = Q.reshape(1, self.config.num_heads, self.config.head_dim)

        # Determine which positions to attend to
        if kv_cache_size >= query_pos:
            # Full cache - every previous position is available
            start_pos = 0
            cached_positions = query_pos
            stats['cache_hits'] = query_pos
        else:
            # Limited cache - only the most recent kv_cache_size positions
            start_pos = max(0, query_pos - kv_cache_size)
            cached_positions = min(kv_cache_size, query_pos)
            stats['cache_hits'] = cached_positions
            stats['recompute_steps'] = query_pos - cached_positions

        # Hidden states for the attended window (including the current token)
        relevant_hidden = hidden_states[start_pos:query_pos + 1]

        # Compute keys and values -- this is the work a KV-cache would save
        start_time = time.time()
        K = relevant_hidden @ self.W_k  # [window_len, hidden_dim]
        V = relevant_hidden @ self.W_v  # [window_len, hidden_dim]
        compute_time = time.time() - start_time

        # Reshape for multi-head: [window_len, num_heads, head_dim]
        seq_len = K.shape[0]
        K = K.reshape(seq_len, self.config.num_heads, self.config.head_dim)
        V = V.reshape(seq_len, self.config.num_heads, self.config.head_dim)

        # Attention scores: [num_heads, 1, window_len]
        scores = np.einsum('qhd,khd->hqk', Q, K) / np.sqrt(self.config.head_dim)

        # No explicit causal mask is needed: the window only ever contains
        # positions <= query_pos, so every attended position is valid.

        # Softmax over the key dimension
        attn_weights = self._softmax(scores, axis=-1)

        # Apply attention to values: [1, num_heads, head_dim]
        attn_output = np.einsum('hqk,khd->qhd', attn_weights, V)

        # Merge heads and project back to the model dimension
        attn_output = attn_output.reshape(1, self.config.hidden_dim)
        output = attn_output @ self.W_o

        # Approximate KV-cache footprint: K and V vectors for the cached
        # positions, assuming 4 bytes (fp32) per value
        stats['memory_used'] = 2 * cached_positions * self.config.hidden_dim * 4
        stats['compute_time'] = compute_time

        return output, stats

    def _softmax(self, x, axis=-1):
        """Numerically stable softmax"""
        e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return e_x / np.sum(e_x, axis=axis, keepdims=True)

    def generate_sequence(self,
                          prompt_length: int,
                          generation_length: int,
                          kv_cache_size: int) -> Dict:
        """
        Simulate autoregressive generation with a limited KV-cache.

        This mimics how LLMs generate text token by token.
        """
        total_length = prompt_length + generation_length
        hidden_dim = self.config.hidden_dim

        # Initialize with random hidden states (standing in for embeddings)
        hidden_states = np.random.randn(total_length, hidden_dim) * 0.1

        total_stats = {
            'total_time': 0,
            'total_memory': 0,
            'total_recomputes': 0,
            'per_token_times': [],
        }

        # Process the prompt (prefill)
        start_time = time.time()
        for pos in range(prompt_length):
            _, stats = self.compute_attention(pos, hidden_states, kv_cache_size)
        prompt_time = time.time() - start_time

        # Generate new tokens one at a time (decode)
        generation_times = []
        for pos in range(prompt_length, total_length):
            start = time.time()
            output, stats = self.compute_attention(pos, hidden_states, kv_cache_size)
            token_time = time.time() - start

            generation_times.append(token_time)
            total_stats['total_recomputes'] += stats['recompute_steps']
            total_stats['total_memory'] = max(total_stats['total_memory'],
                                              stats['memory_used'])

            # Simulate token generation (a real model would sample from logits)
            hidden_states[pos] = output[0]

        generation_time = sum(generation_times)
        total_stats['per_token_times'] = generation_times
        total_stats['total_time'] = generation_time + prompt_time
        total_stats['avg_token_time'] = np.mean(generation_times) if generation_times else 0
        total_stats['prompt_time'] = prompt_time
        total_stats['generation_time'] = generation_time
        total_stats['tokens_per_second'] = (generation_length / generation_time
                                            if generation_time > 0 else 0)

        return total_stats
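

# Minimal usage sketch (added for exposition; not called by the experiment):
# run one short generation with a √n-scaled sliding-window cache and print the
# resulting throughput. The tiny config values here are illustrative only.
def _example_single_run():
    config = AttentionConfig(seq_length=256, hidden_dim=64,
                             num_heads=4, head_dim=16)
    model = TransformerAttention(config)
    window = int(np.sqrt(config.seq_length) * 4)  # ~√n-scaled cache (64 tokens)
    stats = model.generate_sequence(prompt_length=128,
                                    generation_length=128,
                                    kv_cache_size=window)
    print(f"tokens/s with a {window}-token cache: "
          f"{stats['tokens_per_second']:.1f}")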


def run_llm_experiment():
    """Run the full LLM KV-cache experiment."""

    print("=" * 60)
    print("LLM KV-Cache Space-Time Tradeoff Experiment")
    print("Simulating transformer attention with different cache sizes")
    print("=" * 60)

    # Model configuration (similar to GPT-2 small)
    config = AttentionConfig(
        seq_length=2048,   # Max sequence length
        hidden_dim=768,    # Model dimension
        num_heads=12,      # Attention heads
        head_dim=64,       # Dimension per head
        batch_size=1
    )

    model = TransformerAttention(config)

    # Test different sequence lengths
    test_lengths = [512, 1024, 2048]
    results = {}

    for seq_len in test_lengths:
        print(f"\n{'=' * 40}")
        print(f"Testing sequence length: {seq_len}")
        print(f"{'=' * 40}")

        # Different KV-cache configurations
        cache_configs = [
            ('Full O(n)', seq_len),                       # Full attention
            ('Flash O(√n)', int(np.sqrt(seq_len) * 4)),   # Sliding window, √n-scaled
            ('Minimal O(1)', 8),                          # Almost no cache
        ]

        seq_results = []

        for label, cache_size in cache_configs:
            print(f"\n{label}: {cache_size} tokens cached")

            # Run multiple trials and average
            trials = []
            num_trials = 5

            for trial in range(num_trials):
                stats = model.generate_sequence(
                    prompt_length=seq_len // 2,
                    generation_length=seq_len // 2,
                    kv_cache_size=cache_size
                )
                trials.append(stats)

            # Average results (cast to float so the dict stays JSON-serializable)
            avg_stats = {
                'label': label,
                'cache_size': cache_size,
                'avg_token_time': float(np.mean([t['avg_token_time'] for t in trials])),
                'tokens_per_second': float(np.mean([t['tokens_per_second'] for t in trials])),
                'max_memory_mb': float(np.mean([t['total_memory'] for t in trials])) / 1024 / 1024,
                'total_recomputes': float(np.mean([t['total_recomputes'] for t in trials])),
            }

            seq_results.append(avg_stats)

            print(f"  Avg token time: {avg_stats['avg_token_time'] * 1000:.2f} ms")
            print(f"  Tokens/second: {avg_stats['tokens_per_second']:.1f}")
            print(f"  Memory used: {avg_stats['max_memory_mb']:.1f} MB")
            print(f"  Recomputations: {avg_stats['total_recomputes']:.0f}")

        results[seq_len] = seq_results

    # Create visualizations
    create_llm_plots(results)

    # Save results (JSON converts the integer sequence-length keys to strings)
    save_data = {
        'model_config': {
            'hidden_dim': config.hidden_dim,
            'num_heads': config.num_heads,
            'head_dim': config.head_dim
        },
        'results': results
    }

    with open('llm_kv_cache_results.json', 'w') as f:
        json.dump(save_data, f, indent=2)

    print("\n" + "=" * 60)
    print("EXPERIMENT COMPLETE")
    print("Generated files:")
    print("  - llm_attention_tradeoff.png")
    print("  - llm_kv_cache_results.json")
    print("=" * 60)
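

# Illustrative sketch (added for exposition): reload the saved results file
# for further analysis. Note that json stores the integer sequence-length keys
# as strings. The helper name is ours, not part of the original experiment.
def _load_results(path: str = 'llm_kv_cache_results.json') -> Dict:
    with open(path) as f:
        data = json.load(f)
    for seq_len, rows in data['results'].items():
        for row in rows:
            print(f"seq {seq_len:>5}  {row['label']:<14}"
                  f"{row['tokens_per_second']:8.1f} tok/s"
                  f"{row['max_memory_mb']:8.2f} MB")
    return data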


def create_llm_plots(results):
    """Create publication-quality plots for the LLM experiment."""

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

    seq_lengths = sorted(results.keys())
    colors = ['green', 'orange', 'red']

    # Plot 1: Token generation time vs cache size
    for seq_len in seq_lengths:
        cache_sizes = [r['cache_size'] for r in results[seq_len]]
        token_times = [r['avg_token_time'] * 1000 for r in results[seq_len]]

        ax1.plot(cache_sizes, token_times, 'o-', label=f'Seq {seq_len}',
                 linewidth=2, markersize=8)

    ax1.set_xlabel('KV-Cache Size (tokens)', fontsize=12)
    ax1.set_ylabel('Avg Token Time (ms)', fontsize=12)
    ax1.set_title('Token Generation Time vs Cache Size', fontsize=14)
    ax1.set_xscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Memory usage per cache configuration
    for i, seq_len in enumerate(seq_lengths):
        labels = [r['label'].replace(' O', '\nO') for r in results[seq_len]]
        memory = [r['max_memory_mb'] for r in results[seq_len]]

        x = np.arange(len(labels)) + i * 0.25
        ax2.bar(x, memory, 0.25, label=f'Seq {seq_len}', alpha=0.8)

    ax2.set_xticks(np.arange(len(labels)) + 0.25)
    ax2.set_xticklabels(labels)
    ax2.set_ylabel('Memory Usage (MB)', fontsize=12)
    ax2.set_title('KV-Cache Memory Requirements', fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')

    # Plot 3: Throughput (tokens/second) for the longest sequence
    seq_len = max(seq_lengths)
    data = results[seq_len]

    labels = [r['label'] for r in data]
    throughput = [r['tokens_per_second'] for r in data]

    bars = ax3.bar(labels, throughput, color=colors, edgecolor='black', linewidth=1.5)
    ax3.set_ylabel('Tokens per Second', fontsize=12)
    ax3.set_title(f'Generation Throughput (seq_len={seq_len})', fontsize=14)
    ax3.grid(True, alpha=0.3, axis='y')

    # Add value labels above each bar
    for bar, val in zip(bars, throughput):
        ax3.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{val:.0f}', ha='center', va='bottom', fontsize=11)

    # Plot 4: Space-time tradeoff curve (slowdown relative to the full cache)
    for seq_len in seq_lengths:
        cache_pct = [r['cache_size'] / seq_len * 100 for r in results[seq_len]]
        slowdown = [results[seq_len][0]['tokens_per_second'] / r['tokens_per_second']
                    for r in results[seq_len]]

        ax4.plot(cache_pct, slowdown, 's-', label=f'Seq {seq_len}',
                 linewidth=2, markersize=8)

    # Theoretical √n reference curve: slowdown ~ sqrt(100 / cache_pct)
    x_theory = np.linspace(1, 100, 100)
    y_theory = np.sqrt(100 / x_theory)
    ax4.plot(x_theory, y_theory, 'k--', alpha=0.5, label='Theoretical √n')

    ax4.set_xlabel('Cache Size (% of Sequence)', fontsize=12)
    ax4.set_ylabel('Slowdown Factor', fontsize=12)
    ax4.set_title('Space-Time Tradeoff in Attention', fontsize=14)
    ax4.set_xscale('log')
    ax4.set_yscale('log')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.suptitle('LLM Attention: KV-Cache Space-Time Tradeoffs', fontsize=16)
    plt.tight_layout()
    plt.savefig('llm_attention_tradeoff.png', dpi=300, bbox_inches='tight')
    plt.close()
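

# Worked example for the "Theoretical √n" reference curve in plot 4 (our
# reading of the curve plotted above, not an authoritative model): with a
# cache holding x% of the sequence, the assumed slowdown is sqrt(100 / x).
# For seq_len = 2048 and the √n-scaled cache of 181 tokens, x ≈ 8.8%, so the
# curve predicts a slowdown of sqrt(100 / 8.8) ≈ 3.4x versus the full cache.
def _theoretical_slowdown(cache_pct: float) -> float:
    return float(np.sqrt(100.0 / cache_pct))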
if __name__ == "__main__":
|
|
run_llm_experiment() |