Initial
This commit is contained in:
760
db_optimizer/memory_aware_optimizer.py
Normal file
760
db_optimizer/memory_aware_optimizer.py
Normal file
@@ -0,0 +1,760 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Memory-Aware Query Optimizer: Database query optimizer considering memory hierarchies
|
||||
|
||||
Features:
|
||||
- Cost Model: Include L3/RAM/SSD boundaries in cost calculations
|
||||
- Algorithm Selection: Choose between hash/sort/nested-loop based on true costs
|
||||
- Buffer Sizing: Automatically size buffers to √(data_size)
|
||||
- Spill Planning: Optimize when and how to spill to disk
|
||||
- AI Explanations: Clear reasoning for optimization decisions
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import sqlite3
|
||||
import psutil
|
||||
import numpy as np
|
||||
import time
|
||||
import json
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, List, Tuple, Optional, Any, Union
|
||||
from enum import Enum
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
# Import core components
|
||||
from core.spacetime_core import (
|
||||
MemoryHierarchy,
|
||||
SqrtNCalculator,
|
||||
OptimizationStrategy,
|
||||
StrategyAnalyzer
|
||||
)
|
||||
|
||||
|
||||
class JoinAlgorithm(Enum):
    """Join algorithms with different space-time tradeoffs.

    Each member's comment gives (space, time) complexity for joining
    relations of n and m rows; the cost model prices these in
    CostModel.calculate_join_cost.
    """
    NESTED_LOOP = "nested_loop"    # O(1) space, O(n*m) time
    SORT_MERGE = "sort_merge"      # O(n+m) space, O(n log n + m log m) time
    HASH_JOIN = "hash_join"        # O(min(n,m)) space, O(n+m) time
    BLOCK_NESTED = "block_nested"  # O(√n) space, O(n*m/√n) time
|
||||
|
||||
|
||||
class ScanType(Enum):
    """Scan types for table access.

    Cost weights per type live in CostModel.calculate_scan_cost:
    sequential scans are discounted for prefetching, index scans are
    penalized for random access, bitmap scans sit in between.
    """
    SEQUENTIAL = "sequential"  # Full table scan
    INDEX = "index"            # Index scan
    BITMAP = "bitmap"          # Bitmap index scan
|
||||
|
||||
|
||||
@dataclass
class TableStats:
    """Statistics about a database table, used for cost estimation."""
    name: str                    # Table name as reported by sqlite_master
    row_count: int               # Exact COUNT(*) at collection time
    avg_row_size: int            # Rough estimate: 20 bytes per column
    total_size: int              # row_count * avg_row_size, in bytes
    indexes: List[str]           # Index names from PRAGMA index_list
    cardinality: Dict[str, int]  # Column -> distinct values
|
||||
|
||||
|
||||
@dataclass
class QueryNode:
    """Node in a query execution plan (a tree of operations)."""
    operation: str               # e.g. "SCAN <table>", "JOIN", "SORT", "AGGREGATE"
    algorithm: Optional[str]     # Algorithm chosen for this operation, if any
    estimated_rows: int          # Estimated number of output rows
    estimated_size: int          # Estimated output size in bytes
    estimated_cost: float        # Cumulative cost, including children
    memory_required: int         # Working memory this operation needs, bytes
    memory_level: str            # Hierarchy level that fits memory_required
    children: List['QueryNode']  # Input operations feeding this node
    explanation: str             # Human-readable description of the node
|
||||
|
||||
|
||||
@dataclass
class OptimizationResult:
    """Result of query optimization (naive vs. memory-aware plan)."""
    original_plan: QueryNode        # Naive plan: nested-loop joins, in-memory sort
    optimized_plan: QueryNode       # Memory-aware plan
    memory_saved: int               # original minus optimized memory_required, bytes
    estimated_speedup: float        # original cost / optimized cost
    buffer_sizes: Dict[str, int]    # Per-operation buffer allocations, bytes
    spill_strategy: Dict[str, str]  # Per-operation spill technique, if needed
    explanation: str                # Human-readable summary of the changes
|
||||
|
||||
|
||||
class CostModel:
    """Cost model considering memory hierarchy.

    Costs are abstract units: (number of item touches) x (latency of
    the memory level the working set fits in), plus an explicit
    penalty when an operation's working set exceeds its buffer and
    must spill to disk.
    """

    def __init__(self, hierarchy: MemoryHierarchy):
        self.hierarchy = hierarchy

        # Cost factors (relative to L1 access)
        self.cpu_factor = 0.1
        self.l1_factor = 1.0
        self.l2_factor = 4.0
        self.l3_factor = 12.0
        self.ram_factor = 100.0
        self.disk_factor = 10000.0

    def calculate_scan_cost(self, table_size: int, scan_type: ScanType) -> float:
        """Return the cost of scanning `table_size` bytes with `scan_type`."""
        level, latency = self.hierarchy.get_level_for_size(table_size)

        if scan_type == ScanType.SEQUENTIAL:
            # Sequential scan benefits from prefetching
            return table_size * latency * 0.5
        elif scan_type == ScanType.INDEX:
            # Random access pattern
            return table_size * latency * 2.0
        else:  # BITMAP
            # Mixed pattern
            return table_size * latency

    def calculate_join_cost(self, left_size: int, right_size: int,
                            algorithm: JoinAlgorithm, buffer_size: int) -> float:
        """Return the cost of joining inputs of `left_size`/`right_size` rows.

        `buffer_size` is the memory budget; when an algorithm's working
        set exceeds it, a disk-spill penalty proportional to the
        overflow ratio is added.
        """
        if algorithm == JoinAlgorithm.NESTED_LOOP:
            # O(n*m) comparisons, minimal memory
            comparisons = left_size * right_size
            memory_used = buffer_size

        elif algorithm == JoinAlgorithm.SORT_MERGE:
            # Sort both sides then merge.
            # BUG FIX: clamp log2 arguments to >= 1 so empty inputs do
            # not produce -inf (np.log2(0)).
            sort_cost = (left_size * np.log2(max(left_size, 1)) +
                         right_size * np.log2(max(right_size, 1)))
            merge_cost = left_size + right_size
            comparisons = sort_cost + merge_cost
            memory_used = left_size + right_size

        elif algorithm == JoinAlgorithm.HASH_JOIN:
            # Build hash table on smaller side, probe with the larger
            build_size = min(left_size, right_size)
            probe_size = max(left_size, right_size)
            comparisons = build_size + probe_size
            memory_used = build_size * 1.5  # Hash table overhead

        else:  # BLOCK_NESTED
            # Process in √n blocks.
            # BUG FIX: clamp the block size to >= 1 so inputs smaller
            # than one row cannot trigger ZeroDivisionError below.
            block_size = max(1, int(np.sqrt(min(left_size, right_size))))
            blocks = (left_size // block_size) * (right_size // block_size)
            comparisons = blocks * block_size * block_size
            memory_used = block_size

        # Get memory level for this operation
        level, latency = self.hierarchy.get_level_for_size(memory_used)

        # Add spill cost if memory exceeded (guard against a zero-size buffer)
        spill_cost = 0
        if memory_used > buffer_size:
            spill_ratio = memory_used / max(buffer_size, 1)
            spill_cost = comparisons * self.disk_factor * 0.1 * spill_ratio

        return comparisons * latency + spill_cost

    def calculate_sort_cost(self, data_size: int, memory_limit: int) -> float:
        """Return the cost of sorting `data_size` bytes within `memory_limit`."""
        if data_size <= memory_limit:
            # In-memory sort; clamp log2 argument so data_size == 0 is safe.
            comparisons = data_size * np.log2(max(data_size, 1))
            level, latency = self.hierarchy.get_level_for_size(data_size)
            return comparisons * latency
        else:
            # External sort: create sorted runs, then merge passes.
            # BUG FIX: use ceiling division (and guard memory_limit == 0).
            # The old floor division could yield runs == 1 for data only
            # slightly larger than memory, making log2(runs) == 0 and the
            # total cost 0 even though the data did not fit in memory.
            limit = max(memory_limit, 1)
            runs = (data_size + limit - 1) // limit
            merge_passes = np.log2(max(runs, 2))
            total_io = data_size * merge_passes * 2  # Read + write
            return total_io * self.disk_factor
|
||||
|
||||
|
||||
class QueryAnalyzer:
    """Analyze SQL text and extract the operations a plan needs."""

    @staticmethod
    def parse_query(sql: str) -> Dict[str, Any]:
        """Parse a SQL query (regex-based, not a full SQL parser).

        Returns a dict with keys:
          'tables'       - FROM table followed by each JOINed table (upper-cased)
          'joins'        - equi-join conditions `a.x = b.y` (original casing)
          'filters'      - raw WHERE clause text (upper-cased), or None
          'aggregations' - aggregate functions actually called
          'order_by'     - raw ORDER BY expression list (upper-cased), or None
        """
        sql_upper = sql.upper()

        # Extract tables: the FROM table plus every JOINed table
        tables = []
        from_match = re.search(r'FROM\s+(\w+)', sql_upper)
        if from_match:
            tables.append(from_match.group(1))

        join_matches = re.findall(r'JOIN\s+(\w+)', sql_upper)
        tables.extend(join_matches)

        # Extract equi-join conditions of the form alias.col = alias.col
        joins = []
        join_pattern = r'(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)'
        for match in re.finditer(join_pattern, sql, re.IGNORECASE):
            joins.append({
                'left_table': match.group(1),
                'left_col': match.group(2),
                'right_table': match.group(3),
                'right_col': match.group(4)
            })

        # Extract filters (WHERE clause text up to the next major clause)
        where_match = re.search(r'WHERE\s+(.+?)(?:GROUP|ORDER|LIMIT|$)', sql_upper)
        filters = where_match.group(1) if where_match else None

        # Extract aggregations.
        # BUG FIX: require an actual function call ("SUM(" etc.) instead
        # of a bare substring test -- `'COUNT' in sql_upper` also matched
        # identifiers such as "country" or "discount", producing phantom
        # aggregation nodes in the plan.
        agg_functions = ['COUNT', 'SUM', 'AVG', 'MIN', 'MAX']
        aggregations = []
        for func in agg_functions:
            if re.search(r'\b' + func + r'\s*\(', sql_upper):
                aggregations.append(func)

        # Extract order by (expression list up to LIMIT or end of query)
        order_match = re.search(r'ORDER\s+BY\s+(.+?)(?:LIMIT|$)', sql_upper)
        order_by = order_match.group(1) if order_match else None

        return {
            'tables': tables,
            'joins': joins,
            'filters': filters,
            'aggregations': aggregations,
            'order_by': order_by
        }
|
||||
|
||||
|
||||
class MemoryAwareOptimizer:
    """Main query optimizer with memory awareness.

    Builds two execution plans for a SQL query -- a naive one and a
    memory-aware one -- prices both with CostModel, sizes buffers
    (typically to ~sqrt(data size)), and plans disk spills for
    operations whose working set exceeds the memory budget.
    """

    def __init__(self, connection: sqlite3.Connection,
                 memory_limit: Optional[int] = None):
        """Create an optimizer bound to `connection`.

        memory_limit: working-memory budget in bytes; defaults to half
        of the system RAM available at construction time.
        """
        self.conn = connection
        self.hierarchy = MemoryHierarchy.detect_system()
        self.cost_model = CostModel(self.hierarchy)
        self.memory_limit = memory_limit or int(psutil.virtual_memory().available * 0.5)
        self.table_stats = {}  # table name -> TableStats

        # Collect table statistics
        self._collect_statistics()

    def _collect_statistics(self):
        """Collect row counts, rough row sizes and index lists for all tables."""
        cursor = self.conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = cursor.fetchall()

        for (table_name,) in tables:
            # Get row count (exact COUNT(*); can be slow on huge tables)
            cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
            row_count = cursor.fetchone()[0]

            # Estimate row size (simplified: 20 bytes per column)
            cursor.execute(f"PRAGMA table_info({table_name})")
            columns = cursor.fetchall()
            avg_row_size = len(columns) * 20  # Rough estimate

            # Get indexes (index name is column 1 of PRAGMA index_list rows)
            cursor.execute(f"PRAGMA index_list({table_name})")
            indexes = [idx[1] for idx in cursor.fetchall()]

            self.table_stats[table_name] = TableStats(
                name=table_name,
                row_count=row_count,
                avg_row_size=avg_row_size,
                total_size=row_count * avg_row_size,
                indexes=indexes,
                cardinality={}  # per-column distinct counts are never collected
            )

    def optimize_query(self, sql: str) -> OptimizationResult:
        """Optimize a SQL query considering memory constraints.

        Builds naive and memory-aware plans, sizes buffers, plans
        spills, and packages the comparison into an OptimizationResult.
        """
        # Parse query
        query_info = QueryAnalyzer.parse_query(sql)

        # Build original (naive) plan
        original_plan = self._build_execution_plan(query_info, optimize=False)

        # Build optimized plan
        optimized_plan = self._build_execution_plan(query_info, optimize=True)

        # Calculate buffer sizes
        buffer_sizes = self._calculate_buffer_sizes(optimized_plan)

        # Determine spill strategy
        spill_strategy = self._determine_spill_strategy(optimized_plan)

        # Calculate improvements.
        # NOTE(review): divides by optimized_plan.estimated_cost; an empty
        # query produces a zero-cost plan and would raise ZeroDivisionError.
        memory_saved = original_plan.memory_required - optimized_plan.memory_required
        estimated_speedup = original_plan.estimated_cost / optimized_plan.estimated_cost

        # Generate explanation
        explanation = self._generate_optimization_explanation(
            original_plan, optimized_plan, buffer_sizes
        )

        return OptimizationResult(
            original_plan=original_plan,
            optimized_plan=optimized_plan,
            memory_saved=memory_saved,
            estimated_speedup=estimated_speedup,
            buffer_sizes=buffer_sizes,
            spill_strategy=spill_strategy,
            explanation=explanation
        )

    def _build_execution_plan(self, query_info: Dict[str, Any],
                              optimize: bool) -> QueryNode:
        """Build a left-deep plan: scans, joins, then optional sort/aggregate.

        optimize=False forces nested-loop joins and in-memory sorts (the
        naive baseline); optimize=True picks algorithms from the memory
        budget via _choose_join_algorithm.
        """
        tables = query_info['tables']
        joins = query_info['joins']

        if not tables:
            # Degenerate query with no FROM clause
            return QueryNode(
                operation="EMPTY",
                algorithm=None,
                estimated_rows=0,
                estimated_size=0,
                estimated_cost=0,
                memory_required=0,
                memory_level="L1",
                children=[],
                explanation="Empty query"
            )

        # Start with first table
        plan = self._create_scan_node(tables[0], query_info.get('filters'))

        # Add joins; the i-th join condition is paired with the (i+1)-th table
        for i, join in enumerate(joins):
            if i + 1 < len(tables):
                right_table = tables[i + 1]
                right_scan = self._create_scan_node(right_table, None)

                # Choose join algorithm
                if optimize:
                    algorithm = self._choose_join_algorithm(
                        plan.estimated_size,
                        right_scan.estimated_size
                    )
                else:
                    algorithm = JoinAlgorithm.NESTED_LOOP

                plan = self._create_join_node(plan, right_scan, algorithm, join)

        # Add sort if needed
        if query_info.get('order_by'):
            plan = self._create_sort_node(plan, optimize)

        # Add aggregation if needed
        if query_info.get('aggregations'):
            plan = self._create_aggregation_node(plan, query_info['aggregations'])

        return plan

    def _create_scan_node(self, table_name: str, filters: Optional[str]) -> QueryNode:
        """Create a table scan node.

        Unknown tables (e.g. upper-cased names that don't match collected
        stats) fall back to a default 1000-row / 100-byte-row estimate.
        """
        stats = self.table_stats.get(table_name, TableStats(
            name=table_name,
            row_count=1000,
            avg_row_size=100,
            total_size=100000,
            indexes=[],
            cardinality={}
        ))

        # Estimate selectivity: fixed 10% when any filter is present
        selectivity = 0.1 if filters else 1.0
        estimated_rows = int(stats.row_count * selectivity)
        estimated_size = estimated_rows * stats.avg_row_size

        # Choose scan type: use an index only when filtering and one exists
        scan_type = ScanType.INDEX if stats.indexes and filters else ScanType.SEQUENTIAL

        # Calculate cost
        cost = self.cost_model.calculate_scan_cost(estimated_size, scan_type)

        level, _ = self.hierarchy.get_level_for_size(estimated_size)

        return QueryNode(
            operation=f"SCAN {table_name}",
            algorithm=scan_type.value,
            estimated_rows=estimated_rows,
            estimated_size=estimated_size,
            estimated_cost=cost,
            memory_required=estimated_size,
            memory_level=level,
            children=[],
            explanation=f"{scan_type.value} scan on {table_name}"
        )

    def _create_join_node(self, left: QueryNode, right: QueryNode,
                          algorithm: JoinAlgorithm, join_info: Dict) -> QueryNode:
        """Create a join node over two child plans.

        join_info (the parsed join condition) is currently unused by the
        estimate; selectivity is a fixed 10%.
        """
        # Estimate join output size.
        # NOTE(review): the // divisions below raise ZeroDivisionError if
        # either child estimates zero rows (e.g. an empty filtered scan).
        join_selectivity = 0.1  # Simplified
        estimated_rows = int(left.estimated_rows * right.estimated_rows * join_selectivity)
        estimated_size = estimated_rows * (left.estimated_size // left.estimated_rows +
                                           right.estimated_size // right.estimated_rows)

        # Calculate memory required per algorithm (mirrors CostModel)
        if algorithm == JoinAlgorithm.HASH_JOIN:
            memory_required = min(left.estimated_size, right.estimated_size) * 1.5
        elif algorithm == JoinAlgorithm.SORT_MERGE:
            memory_required = left.estimated_size + right.estimated_size
        elif algorithm == JoinAlgorithm.BLOCK_NESTED:
            memory_required = int(np.sqrt(min(left.estimated_size, right.estimated_size)))
        else:  # NESTED_LOOP
            memory_required = 1000  # Minimal buffer

        # Calculate buffer size considering memory limit
        buffer_size = min(memory_required, self.memory_limit)

        # Calculate cost (note: sized in rows here, not bytes)
        cost = self.cost_model.calculate_join_cost(
            left.estimated_rows, right.estimated_rows, algorithm, buffer_size
        )

        level, _ = self.hierarchy.get_level_for_size(memory_required)

        return QueryNode(
            operation="JOIN",
            algorithm=algorithm.value,
            estimated_rows=estimated_rows,
            estimated_size=estimated_size,
            estimated_cost=cost + left.estimated_cost + right.estimated_cost,
            memory_required=memory_required,
            memory_level=level,
            children=[left, right],
            explanation=f"{algorithm.value} join with {buffer_size / 1024:.0f}KB buffer"
        )

    def _create_sort_node(self, child: QueryNode, optimize: bool) -> QueryNode:
        """Create a sort node; optimized plans cap memory at √(data size)."""
        if optimize:
            # Use √n memory for external sort
            memory_limit = int(np.sqrt(child.estimated_size))
        else:
            # Try to sort in memory
            memory_limit = child.estimated_size

        cost = self.cost_model.calculate_sort_cost(child.estimated_size, memory_limit)
        level, _ = self.hierarchy.get_level_for_size(memory_limit)

        return QueryNode(
            operation="SORT",
            algorithm="external_sort" if memory_limit < child.estimated_size else "quicksort",
            estimated_rows=child.estimated_rows,
            estimated_size=child.estimated_size,
            estimated_cost=cost + child.estimated_cost,
            memory_required=memory_limit,
            memory_level=level,
            children=[child],
            explanation=f"Sort with {memory_limit / 1024:.0f}KB memory"
        )

    def _create_aggregation_node(self, child: QueryNode,
                                 aggregations: List[str]) -> QueryNode:
        """Create a hash-aggregation node.

        Group count is guessed as √(input rows); each group is assumed to
        take ~100 bytes of output.
        """
        # Estimate groups (simplified)
        estimated_groups = int(np.sqrt(child.estimated_rows))
        estimated_size = estimated_groups * 100  # Rough estimate

        # Hash-based aggregation: hash-table overhead factor of 1.5
        memory_required = estimated_size * 1.5

        level, _ = self.hierarchy.get_level_for_size(memory_required)

        return QueryNode(
            operation="AGGREGATE",
            algorithm="hash_aggregate",
            estimated_rows=estimated_groups,
            estimated_size=estimated_size,
            estimated_cost=child.estimated_cost + child.estimated_rows,
            memory_required=memory_required,
            memory_level=level,
            children=[child],
            explanation=f"Hash aggregation: {', '.join(aggregations)}"
        )

    def _choose_join_algorithm(self, left_size: int, right_size: int) -> JoinAlgorithm:
        """Choose a join algorithm by descending memory appetite.

        Preference order: hash join (if the build side fits), sort-merge
        (if both sides fit), block-nested (√n memory), else nested loop.
        """
        min_size = min(left_size, right_size)
        max_size = max(left_size, right_size)

        # Can we fit hash table in memory?
        hash_memory = min_size * 1.5
        if hash_memory <= self.memory_limit:
            return JoinAlgorithm.HASH_JOIN

        # Can we fit both relations for sort-merge?
        sort_memory = left_size + right_size
        if sort_memory <= self.memory_limit:
            return JoinAlgorithm.SORT_MERGE

        # Use block nested loop with √n memory
        sqrt_memory = int(np.sqrt(min_size))
        if sqrt_memory <= self.memory_limit:
            return JoinAlgorithm.BLOCK_NESTED

        # Fall back to nested loop
        return JoinAlgorithm.NESTED_LOOP

    def _calculate_buffer_sizes(self, plan: QueryNode) -> Dict[str, int]:
        """Walk the plan and size a buffer for every scan/join/sort node.

        Keys encode the node's path in the tree (e.g. "JOIN_0_scan_buffer").
        NOTE(review): only nodes whose operation is exactly "SCAN" match the
        first branch; scan nodes are named "SCAN <table>", so they are sized
        by neither branch -- confirm intended.
        """
        buffer_sizes = {}

        def traverse(node: QueryNode, path: str = ""):
            if node.operation == "SCAN":
                # √n buffer for sequential scans, capped at 10% of budget
                buffer_size = min(
                    int(np.sqrt(node.estimated_size)),
                    self.memory_limit // 10
                )
                buffer_sizes[f"{path}scan_buffer"] = buffer_size

            elif node.operation == "JOIN":
                # Optimal buffer based on algorithm
                if node.algorithm == "block_nested":
                    buffer_size = int(np.sqrt(node.memory_required))
                else:
                    buffer_size = min(node.memory_required, self.memory_limit // 4)
                buffer_sizes[f"{path}join_buffer"] = buffer_size

            elif node.operation == "SORT":
                # √n buffer for external sort
                buffer_size = int(np.sqrt(node.estimated_size))
                buffer_sizes[f"{path}sort_buffer"] = buffer_size

            for i, child in enumerate(node.children):
                traverse(child, f"{path}{node.operation}_{i}_")

        traverse(plan)
        return buffer_sizes

    def _determine_spill_strategy(self, plan: QueryNode) -> Dict[str, str]:
        """Assign a spill technique to every node exceeding the memory budget."""
        spill_strategy = {}

        def traverse(node: QueryNode, path: str = ""):
            if node.memory_required > self.memory_limit:
                if node.operation == "JOIN":
                    # Pick a spill variant matching the join algorithm
                    if node.algorithm == "hash_join":
                        spill_strategy[path] = "grace_hash_join"
                    elif node.algorithm == "sort_merge":
                        spill_strategy[path] = "external_sort_both_inputs"
                    else:
                        spill_strategy[path] = "block_nested_with_spill"

                elif node.operation == "SORT":
                    spill_strategy[path] = "multi_pass_external_sort"

                elif node.operation == "AGGREGATE":
                    spill_strategy[path] = "spill_partial_aggregates"

            for i, child in enumerate(node.children):
                traverse(child, f"{path}{node.operation}_{i}_")

        traverse(plan)
        return spill_strategy

    def _generate_optimization_explanation(self, original: QueryNode,
                                           optimized: QueryNode,
                                           buffer_sizes: Dict[str, int]) -> str:
        """Generate an AI-style natural-language summary of the optimizations.

        NOTE(review): divides by original.memory_required and
        optimized.estimated_cost; zero values (empty plans) would raise
        ZeroDivisionError here.
        """
        explanations = []

        # Overall improvement
        memory_reduction = (1 - optimized.memory_required / original.memory_required) * 100
        speedup = original.estimated_cost / optimized.estimated_cost

        explanations.append(
            f"Optimized query plan reduces memory usage by {memory_reduction:.1f}% "
            f"with {speedup:.1f}x estimated speedup."
        )

        # Specific optimizations: walk both trees in lockstep and report
        # every node whose algorithm changed
        def compare_nodes(orig: QueryNode, opt: QueryNode, path: str = ""):
            if orig.algorithm != opt.algorithm:
                if orig.operation == "JOIN":
                    explanations.append(
                        f"Changed {path} from {orig.algorithm} to {opt.algorithm} "
                        f"saving {(orig.memory_required - opt.memory_required) / 1024:.0f}KB"
                    )
                elif orig.operation == "SORT":
                    explanations.append(
                        f"Using external sort at {path} with √n memory "
                        f"({opt.memory_required / 1024:.0f}KB instead of "
                        f"{orig.memory_required / 1024:.0f}KB)"
                    )

            for i, (orig_child, opt_child) in enumerate(zip(orig.children, opt.children)):
                compare_nodes(orig_child, opt_child, f"{path}{orig.operation}_{i}_")

        compare_nodes(original, optimized)

        # Buffer recommendations
        total_buffers = sum(buffer_sizes.values())
        explanations.append(
            f"Allocated {len(buffer_sizes)} buffers totaling "
            f"{total_buffers / 1024:.0f}KB for optimal performance."
        )

        # Memory hierarchy awareness
        if optimized.memory_level != original.memory_level:
            explanations.append(
                f"Optimized plan fits in {optimized.memory_level} "
                f"instead of {original.memory_level}, reducing latency."
            )

        return " ".join(explanations)

    def explain_plan(self, plan: QueryNode, indent: int = 0) -> str:
        """Render the plan tree as indented text (recursive, one node per level)."""
        lines = []
        prefix = "  " * indent

        lines.append(f"{prefix}{plan.operation} ({plan.algorithm})")
        lines.append(f"{prefix}  Rows: {plan.estimated_rows:,}")
        lines.append(f"{prefix}  Size: {plan.estimated_size / 1024:.1f}KB")
        lines.append(f"{prefix}  Memory: {plan.memory_required / 1024:.1f}KB ({plan.memory_level})")
        lines.append(f"{prefix}  Cost: {plan.estimated_cost:.0f}")

        for child in plan.children:
            lines.append(self.explain_plan(child, indent + 1))

        return "\n".join(lines)

    def apply_hints(self, sql: str, target: str = 'latency',
                    memory_limit: Optional[str] = None) -> str:
        """Prepend an optimizer-hint comment to `sql`.

        NOTE(review): `target` is currently unused. The memory_limit
        regex only recognizes MB/GB suffixes (defaulting to MB), so a
        value like '512KB' is parsed as 512 MB -- confirm intended.
        Side effect: permanently updates self.memory_limit.
        """
        # Parse memory limit if provided
        if memory_limit:
            limit_match = re.match(r'(\d+)(MB|GB)?', memory_limit, re.IGNORECASE)
            if limit_match:
                value = int(limit_match.group(1))
                unit = limit_match.group(2) or 'MB'
                if unit.upper() == 'GB':
                    value *= 1024
                self.memory_limit = value * 1024 * 1024

        # Optimize query
        result = self.optimize_query(sql)

        # Generate hint comment
        hint = f"/* SpaceTime Optimizer: {result.explanation} */\n"

        return hint + sql
|
||||
|
||||
|
||||
# Example usage and testing
|
||||
# Example usage: build a small in-memory database, optimize two queries,
# and print both plans plus buffer/spill recommendations.
if __name__ == "__main__":
    # Create test database
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()

    # Create test tables
    cursor.execute("""
        CREATE TABLE customers (
            id INTEGER PRIMARY KEY,
            name TEXT,
            country TEXT
        )
    """)

    cursor.execute("""
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            customer_id INTEGER,
            amount REAL,
            date TEXT
        )
    """)

    cursor.execute("""
        CREATE TABLE products (
            id INTEGER PRIMARY KEY,
            name TEXT,
            price REAL
        )
    """)

    # Insert test data: 10k customers across 100 countries,
    # 50k orders (5 per customer), 1k products
    for i in range(10000):
        cursor.execute("INSERT INTO customers VALUES (?, ?, ?)",
                       (i, f"Customer {i}", f"Country {i % 100}"))

    for i in range(50000):
        cursor.execute("INSERT INTO orders VALUES (?, ?, ?, ?)",
                       (i, i % 10000, i * 10.0, '2024-01-01'))

    for i in range(1000):
        cursor.execute("INSERT INTO products VALUES (?, ?, ?)",
                       (i, f"Product {i}", i * 5.0))

    conn.commit()

    # Create optimizer with a deliberately tiny budget so the
    # memory-aware choices differ from the naive plan
    optimizer = MemoryAwareOptimizer(conn, memory_limit=1024*1024)  # 1MB limit

    # Test queries: an aggregate-join-sort query and a self-join
    queries = [
        """
        SELECT c.name, SUM(o.amount)
        FROM customers c
        JOIN orders o ON c.id = o.customer_id
        WHERE c.country = 'Country 1'
        GROUP BY c.name
        ORDER BY SUM(o.amount) DESC
        """,

        """
        SELECT *
        FROM orders o1
        JOIN orders o2 ON o1.customer_id = o2.customer_id
        WHERE o1.amount > 1000
        """
    ]

    for i, query in enumerate(queries, 1):
        print(f"\n{'='*60}")
        print(f"Query {i}:")
        print(query.strip())
        print("="*60)

        # Optimize query
        result = optimizer.optimize_query(query)

        print("\nOriginal Plan:")
        print(optimizer.explain_plan(result.original_plan))

        print("\nOptimized Plan:")
        print(optimizer.explain_plan(result.optimized_plan))

        print(f"\nOptimization Results:")
        print(f"  Memory Saved: {result.memory_saved / 1024:.1f}KB")
        print(f"  Estimated Speedup: {result.estimated_speedup:.1f}x")
        print(f"\nBuffer Sizes:")
        for name, size in result.buffer_sizes.items():
            print(f"  {name}: {size / 1024:.1f}KB")

        if result.spill_strategy:
            print(f"\nSpill Strategy:")
            for op, strategy in result.spill_strategy.items():
                print(f"  {op}: {strategy}")

        print(f"\nExplanation: {result.explanation}")

    # Test hint application.
    # NOTE(review): apply_hints only parses MB/GB suffixes, so '512KB'
    # below is interpreted as 512 MB -- confirm intended.
    print("\n" + "="*60)
    print("Query with hints:")
    print("="*60)

    hinted_sql = optimizer.apply_hints(
        "SELECT * FROM customers c JOIN orders o ON c.id = o.customer_id",
        target='memory',
        memory_limit='512KB'
    )
    print(hinted_sql)

    conn.close()
|
||||
Reference in New Issue
Block a user