# Large Language Models: Space-Time Tradeoffs at Scale
## Overview
Modern LLMs are a masterclass in space-time tradeoffs. With models reaching trillions of parameters, every architectural decision trades memory for computation.
## 1. Attention Mechanisms
### Standard Attention (O(n²) Space)
```python
import math
import torch
import torch.nn.functional as F

# Naive attention: materializes the full attention matrix
def standard_attention(Q, K, V):
    # Q, K, V: [batch, seq_len, d_model]
    d_model = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_model)  # [batch, seq_len, seq_len]
    attn = F.softmax(scores, dim=-1)  # must store the entire matrix
    output = attn @ V
    return output

# Memory: O(seq_len²) - prohibitive for long sequences
# For seq_len = 32K: ~4 GB for a single fp32 attention matrix (per head, per batch element)
```
### Flash Attention (O(n) Space)
```python
import math
import torch
import torch.nn.functional as F

# Simplified sketch of the Flash-Attention idea: process queries in blocks so the
# full [seq_len, seq_len] score matrix is never materialized at once
# (the real kernel also tiles K/V and recomputes blocks during the backward pass)
def flash_attention(Q, K, V, block_size=256):
    d_model = Q.shape[-1]
    output = []
    for q_block in Q.split(block_size, dim=-2):
        scores = q_block @ K.transpose(-2, -1) / math.sqrt(d_model)  # [batch, block, seq_len]
        output.append(F.softmax(scores, dim=-1) @ V)
    return torch.cat(output, dim=-2)

# Memory: O(block_size × seq_len) at a time instead of O(seq_len²)
# Time: extra FLOPs from recomputation in the backward pass (wall-clock is often
# faster in practice thanks to reduced memory traffic); enables much longer sequences
```
### Real Impact
- GPT-3: limited to 2K tokens, with quadratic attention memory as a major constraint
- GPT-4 with Flash-style attention: reportedly 32K tokens on comparable hardware
- Claude: 100K+ tokens, reportedly using similar techniques
## 2. KV-Cache Optimization
### Standard KV-Cache
```python
import torch

# During generation, cache the keys and values of every past position
class StandardKVCache:
    def __init__(self, max_seq_len, n_layers, n_heads, d_head):
        # One K and V entry per layer, position, and head
        self.k_cache = torch.zeros(n_layers, max_seq_len, n_heads, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, n_heads, d_head)

# Memory: O(max_seq_len × n_layers × n_heads × d_head)
# For a 70B-class model without K/V sharing: ~140 GB for a 32K context
```
### Multi-Query Attention (MQA)
```python
import torch

# Multi-query attention: all query heads share a single K/V head per layer
class MQACache:
    def __init__(self, max_seq_len, n_layers, d_head):
        # One K,V head per layer instead of one per attention head
        self.k_cache = torch.zeros(n_layers, max_seq_len, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, d_head)

# Memory: O(max_seq_len × n_layers × d_head), i.e. the standard cache divided by n_heads
# 8-32x memory reduction (one shared K/V head instead of 8-32 of them)
```
### Grouped-Query Attention (GQA)
A balance between quality and memory (a cache sketch follows the list):

- Groups of 4-8 query heads share one K/V head
- 4-8x memory reduction
- Typically reported as <1% quality loss

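A minimal cache sketch in the same style as the MQA example above; the class name and the `n_kv_heads` parameter are illustrative, not a specific library API:

```python
import torch

# Grouped-query attention cache: n_kv_heads K/V heads are shared by the
# (larger) number of query heads, e.g. 32 query heads sharing 8 K/V heads
class GQACache:
    def __init__(self, max_seq_len, n_layers, n_kv_heads, d_head):
        self.k_cache = torch.zeros(n_layers, max_seq_len, n_kv_heads, d_head)
        self.v_cache = torch.zeros(n_layers, max_seq_len, n_kv_heads, d_head)

# Memory: O(max_seq_len × n_layers × n_kv_heads × d_head)
# Reduction vs the standard cache: n_heads / n_kv_heads (typically 4-8x)
```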
## 3. Model Quantization
### Full Precision (32-bit)
```python
import torch

# Standard full-precision weights
weight = torch.randn(4096, 4096, dtype=torch.float32)
# Memory: ~64 MB per 4096x4096 layer
# Computation: fast native matmul
```
### INT8 Quantization
```python
# 8-bit weights with a per-tensor scale factor
scale = 127.0 / weight.abs().max()
weight_int8 = (weight * scale).round().clamp(-128, 127).to(torch.int8)
# Dequantize on the fly at matmul time: weight ≈ weight_int8 / scale
# Memory: ~16 MB per layer (4x reduction)
# Computation: slightly slower due to dequantization
```
### 4-bit Quantization (QLoRA)
```python
import torch

# Extreme quantization of the frozen base weight plus small trainable adapters.
# quantize_nf4 / dequantize stand in for a 4-bit NormalFloat kernel (e.g. bitsandbytes).
weight_4bit = quantize_nf4(weight)      # frozen 4-bit base weight, logically [4096, 4096]
lora_A = torch.randn(4096, 16) * 0.01   # trainable low-rank adapter (rank 16)
lora_B = torch.zeros(16, 4096)

def forward(x):
    # x: [..., 4096]; dequantize the base weight and add the low-rank update
    base = x @ dequantize(weight_4bit)
    adapter = (x @ lora_A) @ lora_B
    return base + adapter

# Memory: ~8 MB base + ~0.5 MB adapter per layer (8x reduction vs fp32)
# Time: 2-3x slower due to dequantization
```
## 4. Checkpoint Strategies
### Gradient Checkpointing
```python
from torch.utils.checkpoint import checkpoint

# Standard: every intermediate activation is stored for the backward pass
def transformer_layer(self, x):
    attn = self.attention(x)      # activation stored
    ff = self.feedforward(attn)   # activation stored
    return ff

# With checkpointing: intermediates are discarded and recomputed during backward
def transformer_layer_checkpointed(self, x):
    def block(y):
        return self.feedforward(self.attention(y))
    return checkpoint(block, x)   # stores only the block input, recomputes the rest

# Activation memory: O(√n_layers) with optimal segment sizes instead of O(n_layers)
# Time: roughly 30% slower training due to the extra forward recomputation
```
## 5. Sparse Models
### Dense Model
- Every token is processed by all parameters
- Memory: O(n_params)
- Time: O(n_tokens × n_params)
### Mixture of Experts (MoE)
```python
import torch
import torch.nn.functional as F

# Route each token to a small subset of experts (top-2 here); x is one token's hidden state
def moe_layer(self, x):
    router_logits = self.router(x)                        # [n_experts]
    weights, expert_ids = torch.topk(router_logits, k=2)
    weights = F.softmax(weights, dim=-1)

    output = torch.zeros_like(x)
    for weight, expert_id in zip(weights.tolist(), expert_ids.tolist()):
        output = output + weight * self.experts[expert_id](x)

    return output

# Memory: the full model (all experts) must be resident
# Active compute per token: O(k × n_params / n_experts)
# Enables ~10x larger models at roughly the same per-token compute
```
## 6. Real-World Examples
### GPT-3 vs GPT-4
| Aspect | GPT-3 | GPT-4 (reported) |
|--------|-------|------------------|
| Parameters | 175B | ~1.8T (MoE) |
| Context | 2K | 32K-128K |
| Techniques | Dense | MoE + Flash + GQA |
| Memory/token | ~350MB | ~50MB (active) |
### Llama 2 Family
```
Llama-2-7B:  Full precision (fp32) = 28 GB
             INT8                  = 7 GB
             INT4                  = 3.5 GB

Llama-2-70B: Full precision (fp32) = 280 GB
             INT8                  = 70 GB
             INT4 + QLoRA          = 35 GB (fits on a single 80 GB GPU)
```
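A quick back-of-the-envelope check of these figures (weights only, ignoring the KV-cache and activations):

```python
# Weight memory = parameter count × bits per parameter / 8
def weight_memory_gb(n_params, bits_per_param):
    return n_params * bits_per_param / 8 / 1e9

for name, n_params in [("Llama-2-7B", 7e9), ("Llama-2-70B", 70e9)]:
    for label, bits in [("fp32", 32), ("int8", 8), ("int4", 4)]:
        print(f"{name} {label}: {weight_memory_gb(n_params, bits):.1f} GB")
```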
## 7. Serving Optimizations
### Continuous Batching
Instead of waiting to fill fixed-size batches, dynamically add and remove requests from the running batch (a scheduler sketch follows the list):

- Memory: KV-cache slots are freed and reused as soon as a request finishes
- Time: higher throughput through better GPU utilization

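A minimal scheduler sketch; `pending`, `active`, `model.decode_step`, and the request fields are hypothetical names rather than a real serving API:

```python
def serve_step(pending, active, model, max_batch=32):
    # Admit new requests as soon as slots free up, instead of waiting for a full batch
    while pending and len(active) < max_batch:
        active.append(pending.pop(0))

    # One fused decode step for all active sequences (each keeps its own KV-cache)
    new_tokens = model.decode_step(active)
    finished = []
    for request, token in zip(active, new_tokens):
        request.tokens.append(token)
        if token == model.eos_token or len(request.tokens) >= request.max_tokens:
            finished.append(request)

    # Retire finished sequences immediately so their KV-cache memory can be reused
    for request in finished:
        active.remove(request)
    return active
```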
### PagedAttention (vLLM)
```python
# Treat the KV-cache like virtual memory: fixed-size blocks mapped through a page table
class PagedKVCache:
    def __init__(self, block_size=16):
        self.block_size = block_size
        self.page_table = {}   # seq_id -> {block_number: physical block}

    def allocate(self, seq_id, position):
        # Only allocate a physical block when a sequence first touches it
        table = self.page_table.setdefault(seq_id, {})
        block_number = position // self.block_size
        if block_number not in table:
            table[block_number] = self.new_block()  # grab a free block from the pool
        return table[block_number]
```
Reported KV-cache memory waste from fragmentation: under 5%, versus 60% or more for naive contiguous allocation.
## 8. Training vs Inference Tradeoffs
### Training (Memory Intensive)
- Gradients: roughly 1-2x the weight memory, depending on precision
- Optimizer states (Adam): roughly 2-3x the weight memory, including fp32 master copies in mixed precision
- Activations: O(batch × seq_len × layers)
- Total: often 15-20 bytes of state per parameter before activations (a rough accounting follows this list)

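A rough per-parameter accounting under common assumptions (Adam optimizer, mixed-precision training); exact numbers vary by framework and configuration:

```python
# Bytes per parameter for model and optimizer state, before activations
bytes_per_param = {
    "fp16 weights": 2,
    "fp16 gradients": 2,
    "fp32 master weights": 4,
    "Adam momentum (fp32)": 4,
    "Adam variance (fp32)": 4,
}
total = sum(bytes_per_param.values())  # 16 bytes/param
print(f"{total} bytes/param -> {total * 70e9 / 1e9:.0f} GB of state for a 70B model, plus activations")
```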
### Inference (Can Trade Memory for Time)
- Only the model weights are needed (no gradients or optimizer states)
- Quantize aggressively
- Recompute instead of caching when memory is tight
- Stream weights from disk if needed (a sketch follows this list)

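A hypothetical layer-by-layer streaming sketch, assuming the model's layers were saved as separate files; it trades disk I/O time for resident memory and is illustrative rather than a specific library's API:

```python
import torch

def stream_forward(x, layer_files):
    # Keep only one layer's weights resident at a time
    for path in layer_files:                           # e.g. ["layer_00.pt", "layer_01.pt", ...]
        layer = torch.load(path, map_location="cuda")  # load just this layer's module
        x = layer(x)
        del layer                                      # release it before loading the next layer
        torch.cuda.empty_cache()
    return x
```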
## Key Insights
1. **Every major LLM innovation** is a space-time tradeoff:
   - Flash Attention: recompute for (near-)linear memory
   - Quantization: dequantize on the fly for smaller models
   - MoE: route tokens for sparse activation
2. **The √n pattern appears everywhere** (a small numeric illustration follows this list):
   - Gradient checkpointing: √n_layers activation memory
   - Block-wise attention: √seq_len-sized blocks
   - Batch sizes: heuristically often near √total_examples
3. **Practical systems combine multiple techniques**:
   - GPT-4 (reported): MoE + Flash + INT8 + GQA
   - Llama: quantization + RoPE + GQA
   - Claude (reported): Flash-style attention + Constitutional training
4. **Memory is the binding constraint**:
   - Not compute or data
   - It drives most architectural decisions
   - Williams' result predicts these optimizations

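As referenced in insight 2, a small numeric illustration of the √n segmenting heuristic behind gradient checkpointing; the segment-count choice is a standard heuristic, not a specific library default:

```python
import math

# Split n_layers into ~√n segments: peak activation memory is roughly the stored
# segment boundaries plus one fully recomputed segment, i.e. about 2·√n layers' worth
n_layers = 64
n_segments = round(math.sqrt(n_layers))                # 8 segments
layers_per_segment = math.ceil(n_layers / n_segments)  # 8 layers each
peak_live_layers = n_segments + layers_per_segment     # ~16 vs 64 without checkpointing
print(n_segments, layers_per_segment, peak_live_layers)
```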
## Connection to Theory
Williams showed TIME[t] ⊆ SPACE[√(t log t)]. In LLMs:

- Standard attention: O(n²) space, O(n²) time
- Flash attention: O(n) extra space, still O(n²) time plus recomputation overhead
- That recomputation overhead is the price of the smaller footprint, the same time-for-space exchange that Williams' bound formalizes
This illustrates how the tradeoff behind the theoretical √t space bound manifests in practice, driving many of the most important optimizations in modern AI systems.