#!/usr/bin/env python3
"""
Memory-Aware Query Optimizer: a database query optimizer that accounts for
the memory hierarchy.

Features:
- Cost Model: includes L3/RAM/SSD boundaries in cost calculations
- Algorithm Selection: chooses between hash/sort/nested-loop joins based on true costs
- Buffer Sizing: automatically sizes buffers to √(data_size)
- Spill Planning: optimizes when and how to spill to disk
- AI Explanations: clear reasoning for optimization decisions
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import re
import sqlite3
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Any

import numpy as np
import psutil

# Import core components
from core.spacetime_core import (
    MemoryHierarchy,
    SqrtNCalculator,
    OptimizationStrategy,
    StrategyAnalyzer
)


class JoinAlgorithm(Enum):
    """Join algorithms with different space-time tradeoffs"""
    NESTED_LOOP = "nested_loop"    # O(1) space, O(n*m) time
    SORT_MERGE = "sort_merge"      # O(n+m) space, O(n log n + m log m) time
    HASH_JOIN = "hash_join"        # O(min(n,m)) space, O(n+m) time
    BLOCK_NESTED = "block_nested"  # O(√n) space, O(n*m/√n) time


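# Note on BLOCK_NESTED (illustrative numbers, derived from the cost model
# below): the √n block size shrinks the memory footprint and the number of
# passes over the inner relation, not the comparison count. Joining two
# 1,000,000-row inputs with √n = 1,000-row blocks still performs ~10^12
# comparisons, but only needs a ~1,000-entry buffer resident at a time.

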
class ScanType(Enum):
    """Scan types for table access"""
    SEQUENTIAL = "sequential"  # Full table scan
    INDEX = "index"            # Index scan
    BITMAP = "bitmap"          # Bitmap index scan


@dataclass
class TableStats:
    """Statistics about a database table"""
    name: str
    row_count: int
    avg_row_size: int
    total_size: int
    indexes: List[str]
    cardinality: Dict[str, int]  # Column -> distinct values


@dataclass
class QueryNode:
    """Node in a query execution plan"""
    operation: str
    algorithm: Optional[str]
    estimated_rows: int
    estimated_size: int
    estimated_cost: float
    memory_required: int
    memory_level: str
    children: List['QueryNode']
    explanation: str


@dataclass
class OptimizationResult:
    """Result of query optimization"""
    original_plan: QueryNode
    optimized_plan: QueryNode
    memory_saved: int
    estimated_speedup: float
    buffer_sizes: Dict[str, int]
    spill_strategy: Dict[str, str]
    explanation: str


class CostModel:
    """Cost model that accounts for the memory hierarchy"""

    def __init__(self, hierarchy: MemoryHierarchy):
        self.hierarchy = hierarchy

        # Cost factors (relative to an L1 access)
        self.cpu_factor = 0.1
        self.l1_factor = 1.0
        self.l2_factor = 4.0
        self.l3_factor = 12.0
        self.ram_factor = 100.0
        self.disk_factor = 10000.0

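    # The factors above are order-of-magnitude ratios (L2 a few times L1,
    # RAM ~100x, disk ~10^4x). For plan comparison they only need to be
    # right relative to each other, not in absolute time units.
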
    def calculate_scan_cost(self, table_size: int, scan_type: ScanType) -> float:
        """Calculate the cost of scanning a table"""
        level, latency = self.hierarchy.get_level_for_size(table_size)

        if scan_type == ScanType.SEQUENTIAL:
            # Sequential scans benefit from prefetching
            return table_size * latency * 0.5
        elif scan_type == ScanType.INDEX:
            # Random access pattern
            return table_size * latency * 2.0
        else:  # BITMAP
            # Mixed access pattern
            return table_size * latency

    def calculate_join_cost(self, left_size: int, right_size: int,
                            algorithm: JoinAlgorithm, buffer_size: int) -> float:
        """Calculate the cost of a join operation"""
        if algorithm == JoinAlgorithm.NESTED_LOOP:
            # O(n*m) comparisons, minimal memory
            comparisons = left_size * right_size
            memory_used = buffer_size

        elif algorithm == JoinAlgorithm.SORT_MERGE:
            # Sort both sides, then merge (log2 guarded for tiny inputs)
            sort_cost = (left_size * np.log2(max(left_size, 2)) +
                         right_size * np.log2(max(right_size, 2)))
            merge_cost = left_size + right_size
            comparisons = sort_cost + merge_cost
            memory_used = left_size + right_size

        elif algorithm == JoinAlgorithm.HASH_JOIN:
            # Build a hash table on the smaller side, probe with the larger
            build_size = min(left_size, right_size)
            probe_size = max(left_size, right_size)
            comparisons = build_size + probe_size
            memory_used = build_size * 1.5  # Hash table overhead

        else:  # BLOCK_NESTED
            # Process in √n blocks (block size guarded against zero)
            block_size = max(int(np.sqrt(min(left_size, right_size))), 1)
            blocks = max(left_size // block_size, 1) * max(right_size // block_size, 1)
            comparisons = blocks * block_size * block_size
            memory_used = block_size

        # Get the memory level for this operation
        level, latency = self.hierarchy.get_level_for_size(memory_used)

        # Add a spill penalty if the operation exceeds its buffer
        spill_cost = 0
        if buffer_size > 0 and memory_used > buffer_size:
            spill_ratio = memory_used / buffer_size
            spill_cost = comparisons * self.disk_factor * 0.1 * spill_ratio

        return comparisons * latency + spill_cost

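    # Worked example (illustrative): for a 10,000 x 50,000-row join the model
    # counts ~60,000 comparisons for HASH_JOIN (build 10,000 + probe 50,000)
    # versus 5e8 for NESTED_LOOP, so hash join wins whenever its ~1.5x
    # build-side hash table fits within the buffer.
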
    def calculate_sort_cost(self, data_size: int, memory_limit: int) -> float:
        """Calculate the cost of sorting with limited memory"""
        if data_size <= memory_limit:
            # In-memory sort
            comparisons = data_size * np.log2(max(data_size, 2))
            level, latency = self.hierarchy.get_level_for_size(data_size)
            return comparisons * latency
        else:
            # External sort: build sorted runs, then merge (at least one pass)
            runs = max(data_size // max(memory_limit, 1), 2)
            merge_passes = np.log2(runs)
            total_io = data_size * merge_passes * 2  # Read + write per pass
            return total_io * self.disk_factor


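# Minimal sanity-check sketch (not part of the optimizer's API): reproduces
# the external-sort arithmetic from CostModel.calculate_sort_cost for one
# concrete case, assuming 1 GiB of data and a 32 MiB sort buffer.
def _external_sort_io_example(data_size: int = 1 << 30,
                              memory_limit: int = 32 << 20) -> float:
    runs = max(data_size // memory_limit, 2)  # 32 initial sorted runs
    merge_passes = np.log2(runs)              # 5 merge passes
    return data_size * merge_passes * 2       # ~10 GiB of read+write traffic

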
class QueryAnalyzer:
    """Analyze queries and extract operations"""

    @staticmethod
    def parse_query(sql: str) -> Dict[str, Any]:
        """Parse a SQL query with simple regexes to extract its operations"""
        sql_upper = sql.upper()

        # Extract tables
        tables = []
        from_match = re.search(r'FROM\s+(\w+)', sql_upper)
        if from_match:
            tables.append(from_match.group(1))

        join_matches = re.findall(r'JOIN\s+(\w+)', sql_upper)
        tables.extend(join_matches)

        # Extract join conditions
        joins = []
        join_pattern = r'(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)'
        for match in re.finditer(join_pattern, sql, re.IGNORECASE):
            joins.append({
                'left_table': match.group(1),
                'left_col': match.group(2),
                'right_table': match.group(3),
                'right_col': match.group(4)
            })

        # Extract filters (DOTALL so the clause can span multiple lines)
        where_match = re.search(r'WHERE\s+(.+?)(?:GROUP|ORDER|LIMIT|$)',
                                sql_upper, re.DOTALL)
        filters = where_match.group(1).strip() if where_match else None

        # Extract aggregations
        agg_functions = ['COUNT', 'SUM', 'AVG', 'MIN', 'MAX']
        aggregations = [func for func in agg_functions if func in sql_upper]

        # Extract ORDER BY
        order_match = re.search(r'ORDER\s+BY\s+(.+?)(?:LIMIT|$)',
                                sql_upper, re.DOTALL)
        order_by = order_match.group(1).strip() if order_match else None

        return {
            'tables': tables,
            'joins': joins,
            'filters': filters,
            'aggregations': aggregations,
            'order_by': order_by
        }


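# Example of what the regex parser extracts (values follow directly from the
# patterns above):
#
#   QueryAnalyzer.parse_query(
#       "SELECT c.name FROM customers c JOIN orders o ON c.id = o.customer_id"
#   )
#   -> {'tables': ['CUSTOMERS', 'ORDERS'],
#       'joins': [{'left_table': 'c', 'left_col': 'id',
#                  'right_table': 'o', 'right_col': 'customer_id'}],
#       'filters': None, 'aggregations': [], 'order_by': None}

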
class MemoryAwareOptimizer:
    """Main query optimizer with memory awareness"""

    def __init__(self, connection: sqlite3.Connection,
                 memory_limit: Optional[int] = None):
        self.conn = connection
        self.hierarchy = MemoryHierarchy.detect_system()
        self.cost_model = CostModel(self.hierarchy)
        # Default to half of the currently available RAM
        self.memory_limit = memory_limit or int(psutil.virtual_memory().available * 0.5)
        self.table_stats = {}

        # Collect table statistics
        self._collect_statistics()

    def _collect_statistics(self):
        """Collect statistics about database tables"""
        cursor = self.conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = cursor.fetchall()

        for (table_name,) in tables:
            # Get row count
            cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
            row_count = cursor.fetchone()[0]

            # Estimate row size (simplified)
            cursor.execute(f"PRAGMA table_info({table_name})")
            columns = cursor.fetchall()
            avg_row_size = len(columns) * 20  # Rough estimate

            # Get indexes
            cursor.execute(f"PRAGMA index_list({table_name})")
            indexes = [idx[1] for idx in cursor.fetchall()]

            self.table_stats[table_name] = TableStats(
                name=table_name,
                row_count=row_count,
                avg_row_size=avg_row_size,
                total_size=row_count * avg_row_size,
                indexes=indexes,
                cardinality={}
            )

    def optimize_query(self, sql: str) -> OptimizationResult:
        """Optimize a SQL query under the configured memory constraints"""
        # Parse query
        query_info = QueryAnalyzer.parse_query(sql)

        # Build original plan
        original_plan = self._build_execution_plan(query_info, optimize=False)

        # Build optimized plan
        optimized_plan = self._build_execution_plan(query_info, optimize=True)

        # Calculate buffer sizes
        buffer_sizes = self._calculate_buffer_sizes(optimized_plan)

        # Determine spill strategy
        spill_strategy = self._determine_spill_strategy(optimized_plan)

        # Calculate improvements (guard against zero-cost empty plans)
        memory_saved = original_plan.memory_required - optimized_plan.memory_required
        if optimized_plan.estimated_cost > 0:
            estimated_speedup = original_plan.estimated_cost / optimized_plan.estimated_cost
        else:
            estimated_speedup = 1.0

        # Generate explanation
        explanation = self._generate_optimization_explanation(
            original_plan, optimized_plan, buffer_sizes
        )

        return OptimizationResult(
            original_plan=original_plan,
            optimized_plan=optimized_plan,
            memory_saved=memory_saved,
            estimated_speedup=estimated_speedup,
            buffer_sizes=buffer_sizes,
            spill_strategy=spill_strategy,
            explanation=explanation
        )

    def _build_execution_plan(self, query_info: Dict[str, Any],
                              optimize: bool) -> QueryNode:
        """Build a query execution plan"""
        tables = query_info['tables']
        joins = query_info['joins']

        if not tables:
            return QueryNode(
                operation="EMPTY",
                algorithm=None,
                estimated_rows=0,
                estimated_size=0,
                estimated_cost=0,
                memory_required=0,
                memory_level="L1",
                children=[],
                explanation="Empty query"
            )

        # Start with the first table
        plan = self._create_scan_node(tables[0], query_info.get('filters'))

        # Add joins
        for i, join in enumerate(joins):
            if i + 1 < len(tables):
                right_table = tables[i + 1]
                right_scan = self._create_scan_node(right_table, None)

                # Choose join algorithm
                if optimize:
                    algorithm = self._choose_join_algorithm(
                        plan.estimated_size,
                        right_scan.estimated_size
                    )
                else:
                    algorithm = JoinAlgorithm.NESTED_LOOP

                plan = self._create_join_node(plan, right_scan, algorithm, join)

        # Add a sort if needed
        if query_info.get('order_by'):
            plan = self._create_sort_node(plan, optimize)

        # Add aggregation if needed
        if query_info.get('aggregations'):
            plan = self._create_aggregation_node(plan, query_info['aggregations'])

        return plan

    def _create_scan_node(self, table_name: str, filters: Optional[str]) -> QueryNode:
        """Create a table scan node"""
        stats = self.table_stats.get(table_name, TableStats(
            name=table_name,
            row_count=1000,
            avg_row_size=100,
            total_size=100000,
            indexes=[],
            cardinality={}
        ))

        # Estimate selectivity
        selectivity = 0.1 if filters else 1.0
        estimated_rows = int(stats.row_count * selectivity)
        estimated_size = estimated_rows * stats.avg_row_size

        # Choose scan type
        scan_type = ScanType.INDEX if stats.indexes and filters else ScanType.SEQUENTIAL

        # Calculate cost
        cost = self.cost_model.calculate_scan_cost(estimated_size, scan_type)

        level, _ = self.hierarchy.get_level_for_size(estimated_size)

        return QueryNode(
            operation=f"SCAN {table_name}",
            algorithm=scan_type.value,
            estimated_rows=estimated_rows,
            estimated_size=estimated_size,
            estimated_cost=cost,
            memory_required=estimated_size,
            memory_level=level,
            children=[],
            explanation=f"{scan_type.value} scan on {table_name}"
        )

    def _create_join_node(self, left: QueryNode, right: QueryNode,
                          algorithm: JoinAlgorithm, join_info: Dict) -> QueryNode:
        """Create a join node"""
        # Estimate join output size (row counts guarded against zero)
        join_selectivity = 0.1  # Simplified
        estimated_rows = int(left.estimated_rows * right.estimated_rows * join_selectivity)
        left_row_size = left.estimated_size // max(left.estimated_rows, 1)
        right_row_size = right.estimated_size // max(right.estimated_rows, 1)
        estimated_size = estimated_rows * (left_row_size + right_row_size)

        # Calculate the memory required
        if algorithm == JoinAlgorithm.HASH_JOIN:
            memory_required = int(min(left.estimated_size, right.estimated_size) * 1.5)
        elif algorithm == JoinAlgorithm.SORT_MERGE:
            memory_required = left.estimated_size + right.estimated_size
        elif algorithm == JoinAlgorithm.BLOCK_NESTED:
            memory_required = int(np.sqrt(min(left.estimated_size, right.estimated_size)))
        else:  # NESTED_LOOP
            memory_required = 1000  # Minimal buffer

        # Size the buffer within the memory limit
        buffer_size = min(memory_required, self.memory_limit)

        # Calculate cost
        cost = self.cost_model.calculate_join_cost(
            left.estimated_rows, right.estimated_rows, algorithm, buffer_size
        )

        level, _ = self.hierarchy.get_level_for_size(memory_required)

        return QueryNode(
            operation="JOIN",
            algorithm=algorithm.value,
            estimated_rows=estimated_rows,
            estimated_size=estimated_size,
            estimated_cost=cost + left.estimated_cost + right.estimated_cost,
            memory_required=memory_required,
            memory_level=level,
            children=[left, right],
            explanation=f"{algorithm.value} join with {buffer_size / 1024:.0f}KB buffer"
        )

    def _create_sort_node(self, child: QueryNode, optimize: bool) -> QueryNode:
        """Create a sort node"""
        if optimize:
            # Use √n memory and an external sort
            memory_limit = int(np.sqrt(child.estimated_size))
        else:
            # Try to sort entirely in memory
            memory_limit = child.estimated_size

        cost = self.cost_model.calculate_sort_cost(child.estimated_size, memory_limit)
        level, _ = self.hierarchy.get_level_for_size(memory_limit)

        return QueryNode(
            operation="SORT",
            algorithm="external_sort" if memory_limit < child.estimated_size else "quicksort",
            estimated_rows=child.estimated_rows,
            estimated_size=child.estimated_size,
            estimated_cost=cost + child.estimated_cost,
            memory_required=memory_limit,
            memory_level=level,
            children=[child],
            explanation=f"Sort with {memory_limit / 1024:.0f}KB memory"
        )

    def _create_aggregation_node(self, child: QueryNode,
                                 aggregations: List[str]) -> QueryNode:
        """Create an aggregation node"""
        # Estimate the number of groups (simplified √n heuristic)
        estimated_groups = int(np.sqrt(child.estimated_rows))
        estimated_size = estimated_groups * 100  # Rough estimate

        # Hash-based aggregation (1.5x for hash table overhead)
        memory_required = int(estimated_size * 1.5)

        level, _ = self.hierarchy.get_level_for_size(memory_required)

        return QueryNode(
            operation="AGGREGATE",
            algorithm="hash_aggregate",
            estimated_rows=estimated_groups,
            estimated_size=estimated_size,
            estimated_cost=child.estimated_cost + child.estimated_rows,
            memory_required=memory_required,
            memory_level=level,
            children=[child],
            explanation=f"Hash aggregation: {', '.join(aggregations)}"
        )

    def _choose_join_algorithm(self, left_size: int, right_size: int) -> JoinAlgorithm:
        """Choose the optimal join algorithm based on sizes and memory"""
        min_size = min(left_size, right_size)

        # Can we fit a hash table on the smaller side in memory?
        hash_memory = min_size * 1.5
        if hash_memory <= self.memory_limit:
            return JoinAlgorithm.HASH_JOIN

        # Can we fit both relations for a sort-merge join?
        sort_memory = left_size + right_size
        if sort_memory <= self.memory_limit:
            return JoinAlgorithm.SORT_MERGE

        # Use block nested loop with √n memory
        sqrt_memory = int(np.sqrt(min_size))
        if sqrt_memory <= self.memory_limit:
            return JoinAlgorithm.BLOCK_NESTED

        # Fall back to nested loop
        return JoinAlgorithm.NESTED_LOOP

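    # Decision ladder example with a 1 MB limit (illustrative): a 600 KB
    # build side needs ~900 KB of hash table and fits -> HASH_JOIN; with
    # 2 MB + 2 MB inputs, hash (3 MB) and sort-merge (4 MB) both overflow,
    # but a √(2 MB) ≈ 1.4 KB block buffer fits -> BLOCK_NESTED.
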
    def _calculate_buffer_sizes(self, plan: QueryNode) -> Dict[str, int]:
        """Calculate optimal buffer sizes for each operation"""
        buffer_sizes = {}

        def traverse(node: QueryNode, path: str = ""):
            # Scan nodes are named "SCAN <table>", so match on the prefix
            if node.operation.startswith("SCAN"):
                # √n buffer for sequential scans
                buffer_size = min(
                    int(np.sqrt(node.estimated_size)),
                    self.memory_limit // 10
                )
                buffer_sizes[f"{path}scan_buffer"] = buffer_size

            elif node.operation == "JOIN":
                # Optimal buffer based on the algorithm
                if node.algorithm == "block_nested":
                    buffer_size = int(np.sqrt(node.memory_required))
                else:
                    buffer_size = min(node.memory_required, self.memory_limit // 4)
                buffer_sizes[f"{path}join_buffer"] = buffer_size

            elif node.operation == "SORT":
                # √n buffer for external sort
                buffer_size = int(np.sqrt(node.estimated_size))
                buffer_sizes[f"{path}sort_buffer"] = buffer_size

            for i, child in enumerate(node.children):
                traverse(child, f"{path}{node.operation}_{i}_")

        traverse(plan)
        return buffer_sizes

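    # Scale check for the √n buffers (illustrative): a 100 MiB scan gets a
    # √(100 MiB) ≈ 10 KiB buffer. Memory grows with the square root of the
    # data size, so 100x more data needs only 10x more buffer.
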
    def _determine_spill_strategy(self, plan: QueryNode) -> Dict[str, str]:
        """Determine when and how to spill to disk"""
        spill_strategy = {}

        def traverse(node: QueryNode, path: str = ""):
            if node.memory_required > self.memory_limit:
                if node.operation == "JOIN":
                    if node.algorithm == "hash_join":
                        spill_strategy[path] = "grace_hash_join"
                    elif node.algorithm == "sort_merge":
                        spill_strategy[path] = "external_sort_both_inputs"
                    else:
                        spill_strategy[path] = "block_nested_with_spill"

                elif node.operation == "SORT":
                    spill_strategy[path] = "multi_pass_external_sort"

                elif node.operation == "AGGREGATE":
                    spill_strategy[path] = "spill_partial_aggregates"

            for i, child in enumerate(node.children):
                traverse(child, f"{path}{node.operation}_{i}_")

        traverse(plan)
        return spill_strategy

    def _generate_optimization_explanation(self, original: QueryNode,
                                           optimized: QueryNode,
                                           buffer_sizes: Dict[str, int]) -> str:
        """Generate an AI-style explanation of the optimizations"""
        explanations = []

        # Overall improvement (guard against zero-cost/zero-memory plans)
        memory_reduction = (1 - optimized.memory_required /
                            max(original.memory_required, 1)) * 100
        speedup = original.estimated_cost / max(optimized.estimated_cost, 1e-9)

        explanations.append(
            f"Optimized query plan reduces memory usage by {memory_reduction:.1f}% "
            f"with {speedup:.1f}x estimated speedup."
        )

        # Specific optimizations
        def compare_nodes(orig: QueryNode, opt: QueryNode, path: str = ""):
            if orig.algorithm != opt.algorithm:
                if orig.operation == "JOIN":
                    explanations.append(
                        f"Changed {path} from {orig.algorithm} to {opt.algorithm}, "
                        f"saving {(orig.memory_required - opt.memory_required) / 1024:.0f}KB."
                    )
                elif orig.operation == "SORT":
                    explanations.append(
                        f"Using external sort at {path} with √n memory "
                        f"({opt.memory_required / 1024:.0f}KB instead of "
                        f"{orig.memory_required / 1024:.0f}KB)."
                    )

            for i, (orig_child, opt_child) in enumerate(zip(orig.children, opt.children)):
                compare_nodes(orig_child, opt_child, f"{path}{orig.operation}_{i}_")

        compare_nodes(original, optimized)

        # Buffer recommendations
        total_buffers = sum(buffer_sizes.values())
        explanations.append(
            f"Allocated {len(buffer_sizes)} buffers totaling "
            f"{total_buffers / 1024:.0f}KB for optimal performance."
        )

        # Memory hierarchy awareness
        if optimized.memory_level != original.memory_level:
            explanations.append(
                f"Optimized plan fits in {optimized.memory_level} "
                f"instead of {original.memory_level}, reducing latency."
            )

        return " ".join(explanations)

    def explain_plan(self, plan: QueryNode, indent: int = 0) -> str:
        """Generate a text representation of a query plan"""
        lines = []
        prefix = "  " * indent

        lines.append(f"{prefix}{plan.operation} ({plan.algorithm})")
        lines.append(f"{prefix}  Rows: {plan.estimated_rows:,}")
        lines.append(f"{prefix}  Size: {plan.estimated_size / 1024:.1f}KB")
        lines.append(f"{prefix}  Memory: {plan.memory_required / 1024:.1f}KB ({plan.memory_level})")
        lines.append(f"{prefix}  Cost: {plan.estimated_cost:.0f}")

        for child in plan.children:
            lines.append(self.explain_plan(child, indent + 1))

        return "\n".join(lines)

    def apply_hints(self, sql: str, target: str = 'latency',
                    memory_limit: Optional[str] = None) -> str:
        """Apply optimizer hints to a SQL query.

        `target` is accepted for API compatibility but not yet used;
        `memory_limit` accepts strings like '512KB', '64MB', or '2GB'.
        """
        # Parse the memory limit if provided
        if memory_limit:
            limit_match = re.match(r'(\d+)\s*(KB|MB|GB)?', memory_limit, re.IGNORECASE)
            if limit_match:
                value = int(limit_match.group(1))
                unit = (limit_match.group(2) or 'MB').upper()
                multiplier = {'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3}[unit]
                self.memory_limit = value * multiplier

        # Optimize the query
        result = self.optimize_query(sql)

        # Prepend the optimizer's reasoning as a hint comment
        hint = f"/* SpaceTime Optimizer: {result.explanation} */\n"
        return hint + sql


# Example usage and testing
if __name__ == "__main__":
    # Create a test database
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()

    # Create test tables
    cursor.execute("""
        CREATE TABLE customers (
            id INTEGER PRIMARY KEY,
            name TEXT,
            country TEXT
        )
    """)

    cursor.execute("""
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            customer_id INTEGER,
            amount REAL,
            date TEXT
        )
    """)

    cursor.execute("""
        CREATE TABLE products (
            id INTEGER PRIMARY KEY,
            name TEXT,
            price REAL
        )
    """)

    # Insert test data
    for i in range(10000):
        cursor.execute("INSERT INTO customers VALUES (?, ?, ?)",
                       (i, f"Customer {i}", f"Country {i % 100}"))

    for i in range(50000):
        cursor.execute("INSERT INTO orders VALUES (?, ?, ?, ?)",
                       (i, i % 10000, i * 10.0, '2024-01-01'))

    for i in range(1000):
        cursor.execute("INSERT INTO products VALUES (?, ?, ?)",
                       (i, f"Product {i}", i * 5.0))

    conn.commit()

    # Create the optimizer with a deliberately small limit
    optimizer = MemoryAwareOptimizer(conn, memory_limit=1024 * 1024)  # 1MB limit

    # Test queries
    queries = [
        """
        SELECT c.name, SUM(o.amount)
        FROM customers c
        JOIN orders o ON c.id = o.customer_id
        WHERE c.country = 'Country 1'
        GROUP BY c.name
        ORDER BY SUM(o.amount) DESC
        """,

        """
        SELECT *
        FROM orders o1
        JOIN orders o2 ON o1.customer_id = o2.customer_id
        WHERE o1.amount > 1000
        """
    ]

    for i, query in enumerate(queries, 1):
        print(f"\n{'=' * 60}")
        print(f"Query {i}:")
        print(query.strip())
        print("=" * 60)

        # Optimize the query
        result = optimizer.optimize_query(query)

        print("\nOriginal Plan:")
        print(optimizer.explain_plan(result.original_plan))

        print("\nOptimized Plan:")
        print(optimizer.explain_plan(result.optimized_plan))

        print("\nOptimization Results:")
        print(f"  Memory Saved: {result.memory_saved / 1024:.1f}KB")
        print(f"  Estimated Speedup: {result.estimated_speedup:.1f}x")

        print("\nBuffer Sizes:")
        for name, size in result.buffer_sizes.items():
            print(f"  {name}: {size / 1024:.1f}KB")

        if result.spill_strategy:
            print("\nSpill Strategy:")
            for op, strategy in result.spill_strategy.items():
                print(f"  {op}: {strategy}")

        print(f"\nExplanation: {result.explanation}")

    # Test hint application
    print("\n" + "=" * 60)
    print("Query with hints:")
    print("=" * 60)

    hinted_sql = optimizer.apply_hints(
        "SELECT * FROM customers c JOIN orders o ON c.id = o.customer_id",
        target='memory',
        memory_limit='512KB'
    )
    print(hinted_sql)

    conn.close()