# Autarch/core/model_router.py
"""
AUTARCH Model Router
Manages concurrent SLM/LAM/SAM model instances for autonomous operation.
Model Tiers:
SLM (Small Language Model) Fast classification, routing, yes/no decisions
SAM (Small Action Model) Quick tool execution, simple automated responses
LAM (Large Action Model) Complex multi-step agent tasks, strategic planning
"""
import json
import logging
import threading
from typing import Optional, Dict, Any
from enum import Enum
from .config import get_config
_logger = logging.getLogger('autarch.model_router')
class ModelTier(Enum):
SLM = 'slm'
SAM = 'sam'
LAM = 'lam'
# Fallback chain: if a tier fails, try the next one
_FALLBACK = {
ModelTier.SLM: [ModelTier.SAM, ModelTier.LAM],
ModelTier.SAM: [ModelTier.LAM],
ModelTier.LAM: [],
}
class _TierConfigProxy:
"""Proxies Config but overrides the backend section for a specific model tier.
When a tier says backend=local with model_path=X, this proxy makes the LLM
class (which reads [llama]) see the tier's model_path/n_ctx/etc instead.
"""
def __init__(self, base_config, tier_name: str):
self._base = base_config
self._tier = tier_name
self._overrides: Dict[str, Dict[str, str]] = {}
self._build_overrides()
def _build_overrides(self):
backend = self._base.get(self._tier, 'backend', 'local')
model_path = self._base.get(self._tier, 'model_path', '')
n_ctx = self._base.get(self._tier, 'n_ctx', '2048')
n_gpu_layers = self._base.get(self._tier, 'n_gpu_layers', '-1')
n_threads = self._base.get(self._tier, 'n_threads', '4')
if backend == 'local':
self._overrides['llama'] = {
'model_path': model_path,
'n_ctx': n_ctx,
'n_gpu_layers': n_gpu_layers,
'n_threads': n_threads,
}
elif backend == 'transformers':
self._overrides['transformers'] = {
'model_path': model_path,
}
# claude and huggingface are API-based — no path override needed
def get(self, section: str, key: str, fallback=None):
overrides = self._overrides.get(section, {})
if key in overrides:
return overrides[key]
return self._base.get(section, key, fallback)
def get_int(self, section: str, key: str, fallback: int = 0) -> int:
overrides = self._overrides.get(section, {})
if key in overrides:
try:
return int(overrides[key])
except (ValueError, TypeError):
return fallback
return self._base.get_int(section, key, fallback)
def get_float(self, section: str, key: str, fallback: float = 0.0) -> float:
overrides = self._overrides.get(section, {})
if key in overrides:
try:
return float(overrides[key])
except (ValueError, TypeError):
return fallback
return self._base.get_float(section, key, fallback)
def get_bool(self, section: str, key: str, fallback: bool = False) -> bool:
overrides = self._overrides.get(section, {})
if key in overrides:
val = str(overrides[key]).lower()
return val in ('true', '1', 'yes', 'on')
return self._base.get_bool(section, key, fallback)
# Delegate all settings getters to base (they call self.get internally)
def get_llama_settings(self) -> dict:
from .config import Config
return Config.get_llama_settings(self)
def get_transformers_settings(self) -> dict:
from .config import Config
return Config.get_transformers_settings(self)
def get_claude_settings(self) -> dict:
return self._base.get_claude_settings()
def get_huggingface_settings(self) -> dict:
return self._base.get_huggingface_settings()
class ModelRouter:
    """Manages up to 3 concurrent LLM instances (SLM, SAM, LAM).

    Each tier can use a different backend (local GGUF, transformers, Claude API,
    HuggingFace). The router handles loading, unloading, fallback, and thread-safe
    access.
    """

    def __init__(self, config=None):
        # Config object; falls back to the process-wide config when omitted.
        self.config = config or get_config()
        # Loaded LLM instances keyed by tier (absent key == not loaded).
        self._instances: Dict[ModelTier, Any] = {}
        # One lock per tier so concurrent generations on different tiers
        # don't serialize on each other.
        self._locks: Dict[ModelTier, threading.Lock] = {
            tier: threading.Lock() for tier in ModelTier
        }
        # Serializes load/unload operations across all tiers.
        self._load_lock = threading.Lock()

    @property
    def status(self) -> Dict[str, dict]:
        """Return load status of all tiers.

        Returns a mapping of tier name -> {'loaded', 'model_name', 'backend',
        'enabled', 'model_path'}.
        """
        result = {}
        for tier in ModelTier:
            inst = self._instances.get(tier)
            settings = self.config.get_tier_settings(tier.value)
            result[tier.value] = {
                'loaded': inst is not None and inst.is_loaded,
                'model_name': inst.model_name if inst and inst.is_loaded else None,
                'backend': settings['backend'],
                'enabled': settings['enabled'],
                'model_path': settings['model_path'],
            }
        return result

    def load_tier(self, tier: ModelTier, verbose: bool = False) -> bool:
        """Load a single tier's model. Thread-safe.

        Returns True on success; False when the tier is disabled, when a
        local backend has no model_path, or when instantiation/loading fails.
        """
        settings = self.config.get_tier_settings(tier.value)
        if not settings['enabled']:
            _logger.info(f"[Router] Tier {tier.value} is disabled, skipping")
            return False
        # A local (GGUF) backend is unusable without a model file on disk.
        if not settings['model_path'] and settings['backend'] == 'local':
            _logger.warning(f"[Router] No model_path configured for {tier.value}")
            return False
        with self._load_lock:
            # Unload existing if any
            if tier in self._instances:
                self.unload_tier(tier)
            try:
                inst = self._create_instance(tier, verbose)
                self._instances[tier] = inst
                _logger.info(f"[Router] Loaded {tier.value}: {inst.model_name}")
                return True
            except Exception as e:
                # Loading is best-effort per tier; callers see False, not a raise.
                _logger.error(f"[Router] Failed to load {tier.value}: {e}")
                return False

    def unload_tier(self, tier: ModelTier):
        """Unload a tier's model and free resources."""
        inst = self._instances.pop(tier, None)
        if inst:
            try:
                inst.unload_model()
                _logger.info(f"[Router] Unloaded {tier.value}")
            except Exception as e:
                # Best-effort: an unload failure should not propagate.
                _logger.error(f"[Router] Error unloading {tier.value}: {e}")

    def load_all(self, verbose: bool = False) -> Dict[str, bool]:
        """Load all enabled tiers. Returns {tier_name: success}."""
        results = {}
        for tier in ModelTier:
            results[tier.value] = self.load_tier(tier, verbose)
        return results

    def unload_all(self):
        """Unload all tiers."""
        # Snapshot the keys: unload_tier mutates self._instances as we go.
        for tier in list(self._instances.keys()):
            self.unload_tier(tier)

    def get_instance(self, tier: ModelTier):
        """Get the LLM instance for a tier (may be None if not loaded)."""
        return self._instances.get(tier)

    def is_tier_loaded(self, tier: ModelTier) -> bool:
        """Check if a tier has a loaded model."""
        inst = self._instances.get(tier)
        return inst is not None and inst.is_loaded

    def classify(self, text: str) -> Dict[str, Any]:
        """Use SLM to classify/triage an event or task.

        Returns: {'tier': 'sam'|'lam', 'category': str, 'urgency': str,
        'reasoning': str}.
        Falls back to SAM then LAM if SLM is not loaded or fails, and to a
        static default when no tier yields parseable JSON.
        """
        # NOTE: the prompt is sent to the model verbatim; the continuation
        # lines of the f-string are deliberately unindented.
        classify_prompt = f"""Classify this event/task for autonomous handling.
Respond with ONLY a JSON object, no other text:
{{"tier": "sam" or "lam", "category": "defense|offense|counter|analyze|osint|simulate", "urgency": "high|medium|low", "reasoning": "brief explanation"}}
Event: {text}"""
        # Try SLM first, then fallback
        for tier in [ModelTier.SLM, ModelTier.SAM, ModelTier.LAM]:
            inst = self._instances.get(tier)
            if inst and inst.is_loaded:
                try:
                    with self._locks[tier]:
                        response = inst.generate(classify_prompt, max_tokens=200, temperature=0.1)
                    # Parse JSON from response
                    response = response.strip()
                    # Find JSON in response (the model may wrap it in prose)
                    start = response.find('{')
                    end = response.rfind('}')
                    if start >= 0 and end > start:
                        return json.loads(response[start:end + 1])
                except Exception as e:
                    _logger.warning(f"[Router] Classification failed on {tier.value}: {e}")
                    continue
        # Default if all tiers fail
        return {'tier': 'sam', 'category': 'defense', 'urgency': 'medium',
                'reasoning': 'Default classification (no model available)'}

    def generate(self, tier: ModelTier, prompt: str, **kwargs) -> str:
        """Generate with a specific tier, falling back to higher tiers on failure.

        Fallback chain: SLM -> SAM -> LAM, SAM -> LAM.
        Raises LLMError when every tier in the chain is unloaded or fails.
        """
        chain = [tier] + _FALLBACK.get(tier, [])
        for t in chain:
            inst = self._instances.get(t)
            if inst and inst.is_loaded:
                try:
                    # Per-tier lock: one generation at a time per model.
                    with self._locks[t]:
                        return inst.generate(prompt, **kwargs)
                except Exception as e:
                    _logger.warning(f"[Router] Generate failed on {t.value}: {e}")
                    continue
        # Deferred import, consistent with _create_instance's lazy .llm import.
        from .llm import LLMError
        raise LLMError(f"All tiers exhausted for generation (started at {tier.value})")

    def _create_instance(self, tier: ModelTier, verbose: bool = False):
        """Create and load an LLM instance from tier config.

        Raises LLMError for an unknown backend; load errors propagate to the
        caller (load_tier catches and logs them).
        """
        from .llm import LLM, TransformersLLM, ClaudeLLM, HuggingFaceLLM
        section = tier.value
        backend = self.config.get(section, 'backend', 'local')
        # The proxy makes the backend class read this tier's settings instead
        # of the global [llama]/[transformers] sections.
        proxy = _TierConfigProxy(self.config, section)
        if verbose:
            model_path = self.config.get(section, 'model_path', '')
            _logger.info(f"[Router] Creating {tier.value} instance: backend={backend}, model={model_path}")
        if backend == 'local':
            inst = LLM(proxy)
        elif backend == 'transformers':
            inst = TransformersLLM(proxy)
        elif backend == 'claude':
            inst = ClaudeLLM(proxy)
        elif backend == 'huggingface':
            inst = HuggingFaceLLM(proxy)
        else:
            from .llm import LLMError
            raise LLMError(f"Unknown backend '{backend}' for tier {tier.value}")
        inst.load_model(verbose=verbose)
        return inst
# Process-wide singleton router.
_router_instance = None


def get_model_router() -> ModelRouter:
    """Get the global ModelRouter instance, creating it lazily on first use."""
    global _router_instance
    if _router_instance is None:
        _router_instance = ModelRouter()
    return _router_instance


def reset_model_router():
    """Reset the global ModelRouter (unloads all models)."""
    global _router_instance
    if _router_instance is None:
        return
    # Release every loaded model before dropping the reference.
    _router_instance.unload_all()
    _router_instance = None