""" LLM KV-Cache Space-Time Tradeoff Experiment Demonstrates how KV-cache size affects transformer inference time, showing Williams' √n pattern in modern AI systems. This simulates the core attention mechanism where: - Full KV-cache (O(n)): Store all past tokens' keys/values - Sliding window (O(√n)): Keep only recent √n tokens - Minimal cache (O(1)): Recompute everything Based on Flash Attention and similar optimizations used in production LLMs. """ import numpy as np import time import matplotlib.pyplot as plt from typing import Dict, List, Tuple import json from dataclasses import dataclass @dataclass class AttentionConfig: """Configuration for attention mechanism""" seq_length: int # Total sequence length hidden_dim: int # Model dimension (d_model) num_heads: int # Number of attention heads head_dim: int # Dimension per head batch_size: int = 1 # Batch size def __post_init__(self): assert self.hidden_dim == self.num_heads * self.head_dim class TransformerAttention: """Simplified transformer attention with configurable KV-cache""" def __init__(self, config: AttentionConfig): self.config = config # Initialize weights (random for simulation) self.W_q = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02 self.W_k = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02 self.W_v = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02 self.W_o = np.random.randn(config.hidden_dim, config.hidden_dim) * 0.02 def compute_attention(self, query_pos: int, hidden_states: np.ndarray, kv_cache_size: int) -> Tuple[np.ndarray, Dict]: """ Compute attention for position query_pos with limited KV-cache Args: query_pos: Current token position hidden_states: All hidden states up to query_pos kv_cache_size: Maximum number of past tokens to cache Returns: attention_output: Output for the query position stats: Performance statistics """ stats = { 'cache_size': kv_cache_size, 'recompute_steps': 0, 'cache_hits': 0, 'memory_used': 0 } # Get query vector for current position query = hidden_states[query_pos:query_pos+1] # [1, hidden_dim] Q = query @ self.W_q # [1, hidden_dim] # Reshape for multi-head attention Q = Q.reshape(1, self.config.num_heads, self.config.head_dim) # Determine which positions to attend to if kv_cache_size >= query_pos: # Full cache - use all previous positions start_pos = 0 cached_positions = query_pos stats['cache_hits'] = query_pos else: # Limited cache - use only recent positions start_pos = max(0, query_pos - kv_cache_size) cached_positions = min(kv_cache_size, query_pos) stats['cache_hits'] = cached_positions stats['recompute_steps'] = query_pos - cached_positions # Get relevant hidden states relevant_hidden = hidden_states[start_pos:query_pos+1] # Compute keys and values (this is what we cache/recompute) start_time = time.time() K = relevant_hidden @ self.W_k # [seq_len, hidden_dim] V = relevant_hidden @ self.W_v compute_time = time.time() - start_time # Reshape for multi-head seq_len = K.shape[0] K = K.reshape(seq_len, self.config.num_heads, self.config.head_dim) V = V.reshape(seq_len, self.config.num_heads, self.config.head_dim) # Compute attention scores scores = np.einsum('qhd,khd->hqk', Q, K) / np.sqrt(self.config.head_dim) # Apply causal mask if needed if start_pos > 0: # Mask out positions we can't see due to limited cache mask = np.ones_like(scores) scores = scores * mask # Softmax attn_weights = self._softmax(scores, axis=-1) # Apply attention to values attn_output = np.einsum('hqk,khd->qhd', attn_weights, V) # Reshape and project attn_output = attn_output.reshape(1, self.config.hidden_dim) output = attn_output @ self.W_o # Calculate memory usage stats['memory_used'] = ( 2 * cached_positions * self.config.hidden_dim * 4 # K and V cache in bytes ) stats['compute_time'] = compute_time return output, stats def _softmax(self, x, axis=-1): """Numerically stable softmax""" e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) return e_x / np.sum(e_x, axis=axis, keepdims=True) def generate_sequence(self, prompt_length: int, generation_length: int, kv_cache_size: int) -> Dict: """ Simulate autoregressive generation with limited KV-cache This mimics how LLMs generate text token by token """ total_length = prompt_length + generation_length hidden_dim = self.config.hidden_dim # Initialize with random hidden states (simulating embeddings) hidden_states = np.random.randn(total_length, hidden_dim) * 0.1 total_stats = { 'total_time': 0, 'total_memory': 0, 'total_recomputes': 0, 'per_token_times': [] } # Process prompt (can use full attention) start_time = time.time() for pos in range(prompt_length): _, stats = self.compute_attention(pos, hidden_states, kv_cache_size) prompt_time = time.time() - start_time # Generate new tokens generation_times = [] for pos in range(prompt_length, total_length): start = time.time() output, stats = self.compute_attention(pos, hidden_states, kv_cache_size) token_time = time.time() - start generation_times.append(token_time) total_stats['total_recomputes'] += stats['recompute_steps'] total_stats['total_memory'] = max(total_stats['total_memory'], stats['memory_used']) # Simulate token generation (would normally sample from logits) hidden_states[pos] = output[0] total_stats['total_time'] = sum(generation_times) + prompt_time total_stats['avg_token_time'] = np.mean(generation_times) if generation_times else 0 total_stats['prompt_time'] = prompt_time total_stats['generation_time'] = sum(generation_times) total_stats['tokens_per_second'] = generation_length / sum(generation_times) if generation_times else 0 return total_stats def run_llm_experiment(): """Run comprehensive LLM KV-cache experiment""" print("="*60) print("LLM KV-Cache Space-Time Tradeoff Experiment") print("Simulating transformer attention with different cache sizes") print("="*60) # Model configuration (similar to GPT-2 small) config = AttentionConfig( seq_length=2048, # Max sequence length hidden_dim=768, # Model dimension num_heads=12, # Attention heads head_dim=64, # Dimension per head batch_size=1 ) model = TransformerAttention(config) # Test different sequence lengths test_lengths = [512, 1024, 2048] results = {} for seq_len in test_lengths: print(f"\n{'='*40}") print(f"Testing sequence length: {seq_len}") print(f"{'='*40}") # Different KV-cache configurations cache_configs = [ ('Full O(n)', seq_len), # Full attention ('Flash O(√n)', int(np.sqrt(seq_len) * 4)), # Flash Attention-like ('Minimal O(1)', 8), # Almost no cache ] seq_results = [] for label, cache_size in cache_configs: print(f"\n{label}: {cache_size} tokens cached") # Run multiple trials trials = [] num_trials = 5 for trial in range(num_trials): stats = model.generate_sequence( prompt_length=seq_len // 2, generation_length=seq_len // 2, kv_cache_size=cache_size ) trials.append(stats) # Average results avg_stats = { 'label': label, 'cache_size': cache_size, 'avg_token_time': np.mean([t['avg_token_time'] for t in trials]), 'tokens_per_second': np.mean([t['tokens_per_second'] for t in trials]), 'max_memory_mb': np.mean([t['total_memory'] for t in trials]) / 1024 / 1024, 'total_recomputes': np.mean([t['total_recomputes'] for t in trials]) } seq_results.append(avg_stats) print(f" Avg token time: {avg_stats['avg_token_time']*1000:.2f} ms") print(f" Tokens/second: {avg_stats['tokens_per_second']:.1f}") print(f" Memory used: {avg_stats['max_memory_mb']:.1f} MB") print(f" Recomputations: {avg_stats['total_recomputes']:.0f}") results[seq_len] = seq_results # Create visualizations create_llm_plots(results) # Save results save_data = { 'model_config': { 'hidden_dim': config.hidden_dim, 'num_heads': config.num_heads, 'head_dim': config.head_dim }, 'results': results } with open('llm_kv_cache_results.json', 'w') as f: json.dump(save_data, f, indent=2) print("\n" + "="*60) print("EXPERIMENT COMPLETE") print("Generated files:") print(" - llm_attention_tradeoff.png") print(" - llm_kv_cache_results.json") print("="*60) def create_llm_plots(results): """Create publication-quality plots for LLM experiment""" fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10)) # Plot 1: Token generation time vs cache size seq_lengths = sorted(results.keys()) colors = ['green', 'orange', 'red'] for seq_len in seq_lengths: cache_sizes = [r['cache_size'] for r in results[seq_len]] token_times = [r['avg_token_time'] * 1000 for r in results[seq_len]] ax1.plot(cache_sizes, token_times, 'o-', label=f'Seq {seq_len}', linewidth=2, markersize=8) ax1.set_xlabel('KV-Cache Size (tokens)', fontsize=12) ax1.set_ylabel('Avg Token Time (ms)', fontsize=12) ax1.set_title('Token Generation Time vs Cache Size', fontsize=14) ax1.set_xscale('log') ax1.legend() ax1.grid(True, alpha=0.3) # Plot 2: Memory usage for i, seq_len in enumerate(seq_lengths): labels = [r['label'].replace(' O', '\nO') for r in results[seq_len]] memory = [r['max_memory_mb'] for r in results[seq_len]] x = np.arange(len(labels)) + i * 0.25 ax2.bar(x, memory, 0.25, label=f'Seq {seq_len}', alpha=0.8) ax2.set_xticks(np.arange(len(labels)) + 0.25) ax2.set_xticklabels(labels) ax2.set_ylabel('Memory Usage (MB)', fontsize=12) ax2.set_title('KV-Cache Memory Requirements', fontsize=14) ax2.legend() ax2.grid(True, alpha=0.3, axis='y') # Plot 3: Throughput (tokens/second) seq_len = 2048 # Focus on largest data = results[seq_len] labels = [r['label'] for r in data] throughput = [r['tokens_per_second'] for r in data] bars = ax3.bar(labels, throughput, color=colors, edgecolor='black', linewidth=1.5) ax3.set_ylabel('Tokens per Second', fontsize=12) ax3.set_title(f'Generation Throughput (seq_len={seq_len})', fontsize=14) ax3.grid(True, alpha=0.3, axis='y') # Add value labels for bar, val in zip(bars, throughput): ax3.text(bar.get_x() + bar.get_width()/2., bar.get_height(), f'{val:.0f}', ha='center', va='bottom', fontsize=11) # Plot 4: Space-time tradeoff curve for seq_len in seq_lengths: cache_pct = [r['cache_size'] / seq_len * 100 for r in results[seq_len]] speedup = [results[seq_len][0]['tokens_per_second'] / r['tokens_per_second'] for r in results[seq_len]] ax4.plot(cache_pct, speedup, 's-', label=f'Seq {seq_len}', linewidth=2, markersize=8) # Add theoretical √n curve x_theory = np.linspace(1, 100, 100) y_theory = np.sqrt(100 / x_theory) ax4.plot(x_theory, y_theory, 'k--', alpha=0.5, label='Theoretical √n') ax4.set_xlabel('Cache Size (% of Sequence)', fontsize=12) ax4.set_ylabel('Slowdown Factor', fontsize=12) ax4.set_title('Space-Time Tradeoff in Attention', fontsize=14) ax4.set_xscale('log') ax4.set_yscale('log') ax4.legend() ax4.grid(True, alpha=0.3) plt.suptitle('LLM Attention: KV-Cache Space-Time Tradeoffs', fontsize=16) plt.tight_layout() plt.savefig('llm_attention_tradeoff.png', dpi=300, bbox_inches='tight') plt.close() if __name__ == "__main__": run_llm_experiment()