Large Language Models: Space-Time Tradeoffs at Scale

Overview

Modern LLMs are a masterclass in space-time tradeoffs. With models reaching trillions of parameters, every architectural decision trades memory for computation.

1. Attention Mechanisms

Standard Attention (O(n²) Space)

import math
import torch

# Naive attention: materialize the full attention matrix
def standard_attention(Q, K, V):
    # Q, K, V: [batch, seq_len, d_model]
    d_model = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_model)  # [batch, seq_len, seq_len]
    attn = torch.softmax(scores, dim=-1)                   # Must store the entire matrix!
    return attn @ V

# Memory: O(seq_len²) - becomes prohibitive for long sequences
# For seq_len=32K at fp32: ~4GB just for one attention matrix!

Flash Attention (O(n) Space)

import math
import torch

# Block-wise attention: process queries in blocks so the full
# [seq_len, seq_len] score matrix is never materialized at once
# (real FlashAttention also tiles K/V and recomputes blocks in the backward pass)
def flash_attention(Q, K, V, block_size=256):
    d_model = Q.shape[-1]
    output = []
    for q_block in torch.split(Q, block_size, dim=-2):
        scores = q_block @ K.transpose(-2, -1) / math.sqrt(d_model)
        output.append(torch.softmax(scores, dim=-1) @ V)
    return torch.cat(output, dim=-2)

# Memory: O(seq_len) per query block instead of O(seq_len²) - linear in sequence length!
# Time: extra FLOPs from recomputation in the backward pass, but enables ~10x longer sequences
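
In practice you rarely hand-roll this tiling. A minimal usage sketch, assuming PyTorch 2.x, where F.scaled_dot_product_attention dispatches to a fused FlashAttention-style kernel when the backend supports it:

import torch
import torch.nn.functional as F

# [batch, heads, seq_len, d_head]
Q = torch.randn(1, 8, 4096, 64)
K = torch.randn(1, 8, 4096, 64)
V = torch.randn(1, 8, 4096, 64)

# Picks a fused, memory-efficient kernel where available;
# falls back to the math implementation otherwise
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)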

Real Impact

  • GPT-3: Limited to 2K tokens due to quadratic memory
  • GPT-4 with Flash: 32K tokens with same hardware
  • Claude: 100K+ tokens using similar techniques

2. KV-Cache Optimization

Standard KV-Cache

import torch

# During generation, cache keys and values for every layer, position, and head
class StandardKVCache:
    def __init__(self, max_seq_len, n_layers, n_heads, d_head):
        # Cache for all positions (batch dimension omitted)
        self.k_cache = torch.zeros(n_layers, max_seq_len, n_heads, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, n_heads, d_head)

    # Memory: O(max_seq_len × n_layers × n_heads × d_head)
    # For a 70B model: ~140GB for a 32K context!

Multi-Query Attention (MQA)

import torch

# Multi-Query Attention: all query heads share a single K/V head
class MQACache:
    def __init__(self, max_seq_len, n_layers, d_model, n_heads):
        d_head = d_model // n_heads
        # One K,V head per layer instead of n_heads of them
        self.k_cache = torch.zeros(n_layers, max_seq_len, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, d_head)

    # Memory: O(max_seq_len × n_layers × d_model / n_heads)
    # 8-32x memory reduction versus the standard cache!

Grouped-Query Attention (GQA)

A balance between quality and memory (see the sketch after this list):

  • Groups of 4-8 heads share K,V
  • 4-8x memory reduction
  • <1% quality loss
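
A minimal cache-shape sketch under GQA; the class and parameter names here are illustrative, not from a particular library:

import torch

# Grouped-Query Attention cache: n_kv_heads K/V heads are shared by n_heads
# query heads. MQA is the special case n_kv_heads == 1; standard attention
# is n_kv_heads == n_heads.
class GQACache:
    def __init__(self, max_seq_len, n_layers, n_heads, n_kv_heads, d_head):
        assert n_heads % n_kv_heads == 0
        self.group_size = n_heads // n_kv_heads  # query heads per K/V head
        self.k_cache = torch.zeros(n_layers, max_seq_len, n_kv_heads, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, n_kv_heads, d_head)

    # Cache shrinks by group_size (4-8x) relative to the standard cache

During attention, query head h simply reads the K/V head of its group (index h // group_size).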

3. Model Quantization

Full Precision (32-bit)

# Standard weights
weight = torch.randn(4096, 4096, dtype=torch.float32)
# Memory: 64MB per layer
# Computation: Fast matmul

INT8 Quantization

# 8-bit weights with a per-tensor scale factor
scale = 127.0 / weight.abs().max()
weight_int8 = (weight * scale).round().clamp(-128, 127).to(torch.int8)
# Dequantize on the fly: weight ≈ weight_int8.float() / scale
# Memory: 16MB per layer (4x reduction)
# Computation: slightly slower due to dequantization

4-bit Quantization (QLoRA)

# Extreme quantization with low-rank adapters (QLoRA)
# quantize_nf4/dequantize stand in for an NF4 codec (e.g. from bitsandbytes)
weight_4bit = quantize_nf4(weight)  # 4-bit normal-float base weights
lora_A = torch.randn(16, 4096)      # Low-rank adapter A: [r=16, d_in]
lora_B = torch.randn(4096, 16)      # Low-rank adapter B: [d_out, r=16]

def forward(x):
    # Dequantize the frozen base and add the small trainable adapter path
    base = dequantize(weight_4bit) @ x   # [d_out]
    adapter = lora_B @ (lora_A @ x)      # [d_out]
    return base + adapter

# Memory: 8MB base + ~0.5MB adapter (8x reduction vs float32)
# Time: 2-3x slower due to dequantization

4. Checkpoint Strategies

Gradient Checkpointing

from torch.utils.checkpoint import checkpoint

# Standard: autograd stores every intermediate activation for the backward pass
def transformer_layer(self, x):
    attn = self.attention(x)      # activation stored
    ff = self.feedforward(attn)   # activation stored
    return ff

# With checkpointing: only the layer input is kept; intermediate
# activations are recomputed during the backward pass
def transformer_layer_checkpointed(self, x):
    return checkpoint(lambda y: self.feedforward(self.attention(y)), x)

# Memory: O(√n_layers) instead of O(n_layers) when checkpoints are placed every ~√n layers
# Time: ~30% slower training
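
The √n_layers figure comes from checkpointing every ~√n layers rather than after every layer. A small sketch using PyTorch's checkpoint_sequential, with a stand-in stack of linear layers in place of real transformer blocks:

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

n_layers = 64
# Stand-in model: a stack of n_layers "layers" (linear here for brevity)
layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(n_layers)])

x = torch.randn(8, 1024, requires_grad=True)

# Split the stack into ~√n_layers segments; only segment boundaries are stored,
# and activations inside a segment are recomputed during the backward pass
segments = int(math.sqrt(n_layers))   # 8 checkpoints for 64 layers
out = checkpoint_sequential(layers, segments, x)
out.sum().backward()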

5. Sparse Models

Dense Model

  • Every token processed by all parameters
  • Memory: O(n_params)
  • Time: O(n_tokens × n_params)

Mixture of Experts (MoE)

# Route each token to a small subset of experts (top-2 here)
def moe_layer(self, x):
    router_logits = self.router(x)                     # [n_experts]
    weights, expert_ids = torch.topk(router_logits, k=2)
    weights = torch.softmax(weights, dim=-1)           # normalize the chosen gates

    output = 0
    for w, expert_id in zip(weights, expert_ids):
        output = output + w * self.experts[expert_id](x)

    return output

# Parameter memory: full model size
# Active memory per token: O(k × n_params / n_experts)
# Enables ~10x larger models at roughly the same compute per token

6. Real-World Examples

GPT-3 vs GPT-4

Aspect         GPT-3     GPT-4
Parameters     175B      ~1.8T (MoE)
Context        2K        32K-128K
Techniques     Dense     MoE + Flash + GQA
Memory/token   ~350MB    ~50MB (active)

Llama 2 Family

Llama-2-7B:  Full precision = 28GB
             INT8 = 7GB
             INT4 = 3.5GB
             
Llama-2-70B: Full precision = 280GB
             INT8 = 70GB
             INT4 + QLoRA = 35GB (fits on single GPU!)
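
These numbers are simply parameter count times bytes per weight; a hypothetical back-of-the-envelope helper (weights only, ignoring activations and KV-cache):

def model_size_gb(n_params: float, bits_per_weight: int) -> float:
    # Weight storage only: parameters × bits, converted to GB
    return n_params * bits_per_weight / 8 / 1e9

print(model_size_gb(7e9, 32))   # ~28 GB  (Llama-2-7B, fp32)
print(model_size_gb(7e9, 4))    # ~3.5 GB (Llama-2-7B, INT4)
print(model_size_gb(70e9, 8))   # ~70 GB  (Llama-2-70B, INT8)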

7. Serving Optimizations

Continuous Batching

Instead of fixed batches, dynamically batch requests (a simplified scheduler loop follows the list):

  • Memory: Reuse KV-cache across requests
  • Time: Higher throughput via better GPU utilization
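
A toy scheduler loop to illustrate the idea; step, is_finished, and the request queue are hypothetical stand-ins, not a real serving API:

# Toy continuous-batching loop: finished sequences leave the batch immediately
# and waiting requests take their slots, instead of waiting for a full batch
def serve(request_queue, max_batch_size, step, is_finished):
    active = []                       # requests currently being decoded
    while request_queue or active:
        # Fill free slots with waiting requests
        while request_queue and len(active) < max_batch_size:
            active.append(request_queue.pop(0))

        step(active)                  # one decode step for every active sequence

        # Finished sequences release their KV-cache slots right away
        active = [r for r in active if not is_finished(r)]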

PagedAttention (vLLM)

# Treat the KV-cache like virtual memory: fixed-size blocks, allocated on demand
class PagedKVCache:
    def __init__(self, block_size=16):
        self.block_size = block_size
        self.page_table = {}   # seq_id -> list of blocks, one per block_size positions

    def allocate(self, seq_id, position):
        # Only allocate a new block when a sequence crosses a block boundary
        pages = self.page_table.setdefault(seq_id, [])
        block_idx = position // self.block_size
        while len(pages) <= block_idx:
            pages.append(self._new_block())

    def _new_block(self):
        # Placeholder for grabbing a free [block_size, n_heads, d_head] slab
        return object()

Memory fragmentation: <5% vs 60% for naive allocation

8. Training vs Inference Tradeoffs

Training (Memory Intensive)

  • Gradients: 2x model size
  • Optimizer states: 2-3x model size
  • Activations: O(batch × seq_len × layers)
  • Total: 15-20x model parameters (rough estimate below)
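
A rough bytes-per-parameter estimate for mixed-precision Adam training (activations and framework overhead excluded), consistent with the 15-20x figure above:

# Per-parameter memory for mixed-precision Adam training (bytes)
fp16_weights  = 2
fp16_grads    = 2
fp32_master   = 4
adam_momentum = 4
adam_variance = 4

bytes_per_param = fp16_weights + fp16_grads + fp32_master + adam_momentum + adam_variance
print(bytes_per_param)              # 16 bytes/param before activations

# A 7B-parameter model needs ~112GB of training state alone,
# versus ~14GB of fp16 weights for inference
print(7e9 * bytes_per_param / 1e9)  # ~112 GB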

Inference (Can Trade Memory for Time)

  • Only model weights needed
  • Quantize aggressively
  • Recompute instead of cache (toy sketch below)
  • Stream weights from disk if needed
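
As a sketch of "recompute instead of cache": a toy greedy decoding loop that re-encodes the whole prefix each step instead of storing a KV-cache; model_forward is a hypothetical function returning logits for the last position:

# Toy "no KV-cache" decoding: re-run attention over the full prefix every step
def generate_without_cache(model_forward, prompt_tokens, n_new_tokens):
    tokens = list(prompt_tokens)
    for _ in range(n_new_tokens):
        logits = model_forward(tokens)       # O(len(tokens)²) attention, nothing cached
        tokens.append(int(logits.argmax()))  # greedy pick; memory stays O(weights)
    return tokens

# Memory: no KV-cache at all
# Time: every step pays for the whole prefix again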

Key Insights

  1. Every major LLM innovation is a space-time tradeoff:
    • Flash Attention: Recompute for linear memory
    • Quantization: Dequantize for smaller models
    • MoE: Route for sparse activation
  2. The √n pattern appears everywhere:
    • Gradient checkpointing: √n_layers memory
    • Block-wise attention: √seq_len blocks
    • Optimal batch sizes: Often √total_examples
  3. Practical systems combine multiple techniques:
    • GPT-4: MoE + Flash + INT8 + GQA
    • Llama: Quantization + RoPE + GQA
    • Claude: Flash + Constitutional training
  4. Memory is the binding constraint:
    • Not compute or data
    • Drives all architectural decisions
    • Williams' result predicts these optimizations

Connection to Theory

Williams showed TIME[t] ⊆ SPACE[√(t log t)]. In LLMs:

  • Standard attention: O(n²) space, O(n²) time
  • Flash attention: O(n) space, O(n² log n) time
  • The log factor comes from block coordination
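
In symbols, instantiating the bound at t = n² (the work of standard attention on a length-n sequence) gives a near-linear space budget, matching block-wise recomputation up to log factors:

\[
  \mathrm{TIME}[t] \subseteq \mathrm{SPACE}\!\left[\sqrt{t \log t}\right]
  \quad\xrightarrow{\;t \,=\, n^2\;}\quad
  \mathrm{SPACE}\!\left[\sqrt{n^2 \log n^2}\right]
  = \mathrm{SPACE}\!\left[n\sqrt{2 \log n}\right]
\]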

This validates that the theoretical √t space bound manifests in practice, driving the most important optimizations in modern AI systems.