""" AUTARCH Penetration Testing Tree (PTT) Hierarchical task tracker for structured penetration testing workflows. Based on PentestGPT's USENIX paper methodology. """ import uuid from enum import Enum from dataclasses import dataclass, field from datetime import datetime from typing import Optional, List, Dict, Any class NodeStatus(Enum): TODO = "todo" IN_PROGRESS = "in_progress" COMPLETED = "completed" NOT_APPLICABLE = "not_applicable" class PTTNodeType(Enum): RECONNAISSANCE = "reconnaissance" INITIAL_ACCESS = "initial_access" PRIVILEGE_ESCALATION = "privilege_escalation" LATERAL_MOVEMENT = "lateral_movement" PERSISTENCE = "persistence" CREDENTIAL_ACCESS = "credential_access" EXFILTRATION = "exfiltration" CUSTOM = "custom" @dataclass class PTTNode: """A single node in the Penetration Testing Tree.""" id: str label: str node_type: PTTNodeType status: NodeStatus = NodeStatus.TODO parent_id: Optional[str] = None children: List[str] = field(default_factory=list) details: str = "" tool_output: Optional[str] = None findings: List[str] = field(default_factory=list) priority: int = 3 created_at: str = "" updated_at: str = "" def __post_init__(self): now = datetime.now().isoformat() if not self.created_at: self.created_at = now if not self.updated_at: self.updated_at = now def to_dict(self) -> dict: return { 'id': self.id, 'label': self.label, 'node_type': self.node_type.value, 'status': self.status.value, 'parent_id': self.parent_id, 'children': self.children.copy(), 'details': self.details, 'tool_output': self.tool_output, 'findings': self.findings.copy(), 'priority': self.priority, 'created_at': self.created_at, 'updated_at': self.updated_at, } @classmethod def from_dict(cls, data: dict) -> 'PTTNode': return cls( id=data['id'], label=data['label'], node_type=PTTNodeType(data['node_type']), status=NodeStatus(data['status']), parent_id=data.get('parent_id'), children=data.get('children', []), details=data.get('details', ''), tool_output=data.get('tool_output'), findings=data.get('findings', []), priority=data.get('priority', 3), created_at=data.get('created_at', ''), updated_at=data.get('updated_at', ''), ) # Status display symbols _STATUS_SYMBOLS = { NodeStatus.TODO: '[ ]', NodeStatus.IN_PROGRESS: '[~]', NodeStatus.COMPLETED: '[x]', NodeStatus.NOT_APPLICABLE: '[-]', } class PentestTree: """Penetration Testing Tree - hierarchical task tracker.""" def __init__(self, target: str): self.target = target self.nodes: Dict[str, PTTNode] = {} self.root_nodes: List[str] = [] now = datetime.now().isoformat() self.created_at = now self.updated_at = now def add_node( self, label: str, node_type: PTTNodeType, parent_id: Optional[str] = None, details: str = "", priority: int = 3, status: NodeStatus = NodeStatus.TODO, ) -> str: """Add a node to the tree. Returns the new node's ID.""" node_id = str(uuid.uuid4())[:8] node = PTTNode( id=node_id, label=label, node_type=node_type, status=status, parent_id=parent_id, details=details, priority=priority, ) self.nodes[node_id] = node if parent_id and parent_id in self.nodes: self.nodes[parent_id].children.append(node_id) elif parent_id is None: self.root_nodes.append(node_id) self.updated_at = datetime.now().isoformat() return node_id def update_node( self, node_id: str, status: Optional[NodeStatus] = None, details: Optional[str] = None, tool_output: Optional[str] = None, findings: Optional[List[str]] = None, priority: Optional[int] = None, label: Optional[str] = None, ) -> bool: """Update a node's properties. Returns True if found and updated.""" node = self.nodes.get(node_id) if not node: return False if status is not None: node.status = status if details is not None: node.details = details if tool_output is not None: node.tool_output = tool_output if findings is not None: node.findings.extend(findings) if priority is not None: node.priority = priority if label is not None: node.label = label node.updated_at = datetime.now().isoformat() self.updated_at = node.updated_at return True def delete_node(self, node_id: str) -> bool: """Delete a node and all its children recursively.""" node = self.nodes.get(node_id) if not node: return False # Recursively delete children for child_id in node.children.copy(): self.delete_node(child_id) # Remove from parent's children list if node.parent_id and node.parent_id in self.nodes: parent = self.nodes[node.parent_id] if node_id in parent.children: parent.children.remove(node_id) # Remove from root nodes if applicable if node_id in self.root_nodes: self.root_nodes.remove(node_id) del self.nodes[node_id] self.updated_at = datetime.now().isoformat() return True def get_node(self, node_id: str) -> Optional[PTTNode]: return self.nodes.get(node_id) def get_next_todo(self) -> Optional[PTTNode]: """Get the highest priority TODO node.""" todos = [n for n in self.nodes.values() if n.status == NodeStatus.TODO] if not todos: return None return min(todos, key=lambda n: n.priority) def get_all_by_status(self, status: NodeStatus) -> List[PTTNode]: return [n for n in self.nodes.values() if n.status == status] def get_subtree(self, node_id: str) -> List[PTTNode]: """Get all nodes in a subtree (including the root).""" node = self.nodes.get(node_id) if not node: return [] result = [node] for child_id in node.children: result.extend(self.get_subtree(child_id)) return result def find_node_by_label(self, label: str) -> Optional[PTTNode]: """Find a node by label (case-insensitive partial match).""" label_lower = label.lower() for node in self.nodes.values(): if label_lower in node.label.lower(): return node return None def get_stats(self) -> Dict[str, int]: """Get tree statistics.""" stats = {'total': len(self.nodes)} for status in NodeStatus: stats[status.value] = len(self.get_all_by_status(status)) return stats def render_text(self) -> str: """Render full tree as indented text for terminal display.""" if not self.root_nodes: return " (empty tree)" lines = [f"Target: {self.target}"] lines.append("") for root_id in self.root_nodes: self._render_node(root_id, lines, indent=0) return "\n".join(lines) def _render_node(self, node_id: str, lines: List[str], indent: int): node = self.nodes.get(node_id) if not node: return prefix = " " * indent symbol = _STATUS_SYMBOLS.get(node.status, '[ ]') priority_str = f" P{node.priority}" if node.priority != 3 else "" lines.append(f"{prefix}{symbol} {node.label}{priority_str}") if node.findings: for finding in node.findings[:3]: lines.append(f"{prefix} -> {finding}") for child_id in node.children: self._render_node(child_id, lines, indent + 1) def render_summary(self) -> str: """Render compact summary for LLM context injection. Designed to fit within tight token budgets (4096 ctx). Only shows TODO and IN_PROGRESS nodes with minimal detail. """ stats = self.get_stats() lines = [ f"Target: {self.target}", f"Nodes: {stats['total']} total, {stats['todo']} todo, " f"{stats['completed']} done, {stats['in_progress']} active", ] # Show active and todo nodes only active = self.get_all_by_status(NodeStatus.IN_PROGRESS) todos = sorted( self.get_all_by_status(NodeStatus.TODO), key=lambda n: n.priority ) if active: lines.append("Active:") for n in active: lines.append(f" [{n.id}] {n.label}") if todos: lines.append("Todo:") for n in todos[:5]: lines.append(f" [{n.id}] P{n.priority} {n.label}") if len(todos) > 5: lines.append(f" ... and {len(todos) - 5} more") # Show recent findings (last 5) all_findings = [] for node in self.nodes.values(): if node.findings: for f in node.findings: all_findings.append(f) if all_findings: lines.append("Key findings:") for f in all_findings[-5:]: lines.append(f" - {f}") return "\n".join(lines) def initialize_standard_branches(self): """Create standard MITRE ATT&CK-aligned top-level branches.""" branches = [ ("Reconnaissance", PTTNodeType.RECONNAISSANCE, 1, "Information gathering and target enumeration"), ("Initial Access", PTTNodeType.INITIAL_ACCESS, 2, "Gaining initial foothold on target"), ("Privilege Escalation", PTTNodeType.PRIVILEGE_ESCALATION, 3, "Escalating from initial access to higher privileges"), ("Lateral Movement", PTTNodeType.LATERAL_MOVEMENT, 4, "Moving to other systems in the network"), ("Credential Access", PTTNodeType.CREDENTIAL_ACCESS, 3, "Obtaining credentials and secrets"), ("Persistence", PTTNodeType.PERSISTENCE, 5, "Maintaining access to compromised systems"), ] for label, ntype, priority, details in branches: self.add_node( label=label, node_type=ntype, priority=priority, details=details, ) def to_dict(self) -> dict: return { 'target': self.target, 'created_at': self.created_at, 'updated_at': self.updated_at, 'root_nodes': self.root_nodes.copy(), 'nodes': {nid: n.to_dict() for nid, n in self.nodes.items()}, } @classmethod def from_dict(cls, data: dict) -> 'PentestTree': tree = cls(target=data['target']) tree.created_at = data.get('created_at', '') tree.updated_at = data.get('updated_at', '') tree.root_nodes = data.get('root_nodes', []) for nid, ndata in data.get('nodes', {}).items(): tree.nodes[nid] = PTTNode.from_dict(ndata) return tree