Autarch/core/pentest_tree.py

351 lines
11 KiB
Python
Raw Normal View History

"""
AUTARCH Penetration Testing Tree (PTT)
Hierarchical task tracker for structured penetration testing workflows.
Based on PentestGPT's USENIX paper methodology.
"""
import uuid
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, List, Dict, Any
class NodeStatus(Enum):
    """Lifecycle state of a node in the penetration testing tree.

    The string values are what gets written out when a node is
    serialized, so they must stay stable.
    """

    TODO = "todo"                      # not started yet
    IN_PROGRESS = "in_progress"        # currently being worked on
    COMPLETED = "completed"            # finished
    NOT_APPLICABLE = "not_applicable"  # ruled out for this target
class PTTNodeType(Enum):
    """Category of work a tree node represents.

    The string values are what gets written out when a node is
    serialized, so they must stay stable.
    """

    RECONNAISSANCE = "reconnaissance"
    INITIAL_ACCESS = "initial_access"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    LATERAL_MOVEMENT = "lateral_movement"
    PERSISTENCE = "persistence"
    CREDENTIAL_ACCESS = "credential_access"
    EXFILTRATION = "exfiltration"
    CUSTOM = "custom"  # anything that does not fit the categories above
@dataclass
class PTTNode:
    """A single node in the Penetration Testing Tree.

    Nodes reference each other by ID only (``parent_id`` / ``children``);
    the owning tree holds the actual ID -> node mapping.
    """

    id: str                                   # identifier used as the key in the tree
    label: str                                # short human-readable task name
    node_type: PTTNodeType                    # category of the task
    status: NodeStatus = NodeStatus.TODO
    parent_id: Optional[str] = None           # None for a top-level node
    children: List[str] = field(default_factory=list)   # IDs of child nodes
    details: str = ""
    tool_output: Optional[str] = None
    findings: List[str] = field(default_factory=list)
    priority: int = 3                         # lower number is picked first by the tree
    created_at: str = ""                      # ISO-8601 timestamp, filled in if empty
    updated_at: str = ""                      # ISO-8601 timestamp, filled in if empty

    def __post_init__(self):
        """Fill in missing timestamps with the current local time."""
        now = datetime.now().isoformat()
        if not self.created_at:
            self.created_at = now
        if not self.updated_at:
            self.updated_at = now

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of this node.

        Enum fields are stored by their string value and list fields are
        copied, so the result does not share mutable state with the node.
        """
        return {
            'id': self.id,
            'label': self.label,
            'node_type': self.node_type.value,
            'status': self.status.value,
            'parent_id': self.parent_id,
            'children': self.children.copy(),
            'details': self.details,
            'tool_output': self.tool_output,
            'findings': self.findings.copy(),
            'priority': self.priority,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'PTTNode':
        """Build a node from a dict shaped like the output of ``to_dict``.

        ``id``, ``label``, ``node_type`` and ``status`` are required; the
        remaining keys fall back to the dataclass defaults when missing or
        ``None``.

        Raises:
            KeyError: if a required key is absent.
            ValueError: if ``node_type`` or ``status`` is not a known value.
        """
        priority = data.get('priority')
        return cls(
            id=data['id'],
            label=data['label'],
            node_type=PTTNodeType(data['node_type']),
            status=NodeStatus(data['status']),
            parent_id=data.get('parent_id'),
            # Copy the lists so later mutations of the node never leak back
            # into the caller's dict (mirrors the copies made in to_dict).
            children=list(data.get('children') or []),
            details=data.get('details') or '',
            tool_output=data.get('tool_output'),
            findings=list(data.get('findings') or []),
            # An explicit null priority would break priority comparisons.
            priority=3 if priority is None else priority,
            created_at=data.get('created_at') or '',
            updated_at=data.get('updated_at') or '',
        )
# Checkbox-style marker shown in front of each node label when the tree
# is rendered as text, keyed by node status.
_STATUS_SYMBOLS: Dict[NodeStatus, str] = {
    NodeStatus.TODO: '[ ]',
    NodeStatus.IN_PROGRESS: '[~]',
    NodeStatus.COMPLETED: '[x]',
    NodeStatus.NOT_APPLICABLE: '[-]',
}
class PentestTree:
    """Penetration Testing Tree - hierarchical task tracker.

    Nodes live in a flat ``nodes`` mapping (ID -> PTTNode); the hierarchy
    is expressed through each node's ``parent_id`` / ``children`` IDs and
    the ordered ``root_nodes`` list of top-level node IDs.
    """

    def __init__(self, target: str):
        """Create an empty tree for ``target``."""
        self.target = target
        self.nodes: Dict[str, PTTNode] = {}
        self.root_nodes: List[str] = []
        now = datetime.now().isoformat()
        self.created_at = now
        self.updated_at = now

    def add_node(
        self,
        label: str,
        node_type: PTTNodeType,
        parent_id: Optional[str] = None,
        details: str = "",
        priority: int = 3,
        status: NodeStatus = NodeStatus.TODO,
    ) -> str:
        """Add a node to the tree. Returns the new node's ID.

        If ``parent_id`` is ``None`` or does not name an existing node, the
        new node becomes a top-level (root) node with ``parent_id`` set to
        ``None``, so it is always reachable when the tree is rendered.
        """
        # Short IDs can collide; regenerate instead of silently overwriting
        # an existing node.
        node_id = uuid.uuid4().hex[:8]
        while node_id in self.nodes:
            node_id = uuid.uuid4().hex[:8]

        parent = self.nodes.get(parent_id) if parent_id is not None else None
        node = PTTNode(
            id=node_id,
            label=label,
            node_type=node_type,
            status=status,
            # Never store a dangling reference to a parent that is not there.
            parent_id=parent_id if parent is not None else None,
            details=details,
            priority=priority,
        )
        self.nodes[node_id] = node
        if parent is not None:
            parent.children.append(node_id)
        else:
            self.root_nodes.append(node_id)
        self.updated_at = datetime.now().isoformat()
        return node_id

    def update_node(
        self,
        node_id: str,
        status: Optional[NodeStatus] = None,
        details: Optional[str] = None,
        tool_output: Optional[str] = None,
        findings: Optional[List[str]] = None,
        priority: Optional[int] = None,
        label: Optional[str] = None,
    ) -> bool:
        """Update a node's properties. Returns True if found and updated.

        Arguments left as ``None`` are not touched. ``findings`` are
        appended to the node's existing findings; every other argument
        replaces the current value.
        """
        node = self.nodes.get(node_id)
        if node is None:
            return False
        if status is not None:
            node.status = status
        if details is not None:
            node.details = details
        if tool_output is not None:
            node.tool_output = tool_output
        if findings is not None:
            node.findings.extend(findings)
        if priority is not None:
            node.priority = priority
        if label is not None:
            node.label = label
        node.updated_at = datetime.now().isoformat()
        self.updated_at = node.updated_at
        return True

    def delete_node(self, node_id: str) -> bool:
        """Delete a node and all its children recursively.

        Returns False if ``node_id`` is unknown, True otherwise.
        """
        node = self.nodes.get(node_id)
        if node is None:
            return False
        # Iterate over a copy: each recursive call removes the child from
        # this node's children list.
        for child_id in list(node.children):
            self.delete_node(child_id)
        # Detach from the parent, if it is still present.
        parent = self.nodes.get(node.parent_id) if node.parent_id else None
        if parent is not None and node_id in parent.children:
            parent.children.remove(node_id)
        # Remove from root nodes if applicable.
        if node_id in self.root_nodes:
            self.root_nodes.remove(node_id)
        del self.nodes[node_id]
        self.updated_at = datetime.now().isoformat()
        return True

    def get_node(self, node_id: str) -> Optional[PTTNode]:
        """Return the node with ``node_id``, or None if there is none."""
        return self.nodes.get(node_id)

    def get_next_todo(self) -> Optional[PTTNode]:
        """Get the highest priority TODO node.

        The lowest ``priority`` number wins; ties go to the node that was
        added first. Returns None when no node is in TODO state.
        """
        todos = (n for n in self.nodes.values() if n.status == NodeStatus.TODO)
        return min(todos, key=lambda n: n.priority, default=None)

    def get_all_by_status(self, status: NodeStatus) -> List[PTTNode]:
        """Return every node whose status equals ``status``."""
        return [n for n in self.nodes.values() if n.status == status]

    def get_subtree(self, node_id: str) -> List[PTTNode]:
        """Get all nodes in a subtree (including the root).

        Nodes are returned in depth-first pre-order. Child IDs that do not
        resolve to a node are skipped. Each node is visited at most once,
        so malformed (cyclic) data cannot cause unbounded recursion.
        """
        result: List[PTTNode] = []
        seen = set()
        stack = [node_id]
        while stack:
            current_id = stack.pop()
            if current_id in seen:
                continue
            current = self.nodes.get(current_id)
            if current is None:
                continue
            seen.add(current_id)
            result.append(current)
            # Reversed so the first child is popped (visited) first.
            stack.extend(reversed(current.children))
        return result

    def find_node_by_label(self, label: str) -> Optional[PTTNode]:
        """Find a node by label (case-insensitive partial match).

        Returns the first match, or None.
        """
        needle = label.lower()
        for node in self.nodes.values():
            if needle in node.label.lower():
                return node
        return None

    def get_stats(self) -> Dict[str, int]:
        """Get tree statistics.

        Returns a dict with ``'total'`` plus one count per status, keyed by
        the status's string value.
        """
        # Count in a single pass instead of re-scanning per status.
        counts = {status: 0 for status in NodeStatus}
        for node in self.nodes.values():
            if node.status in counts:
                counts[node.status] += 1
        stats = {'total': len(self.nodes)}
        for status, count in counts.items():
            stats[status.value] = count
        return stats

    def render_text(self) -> str:
        """Render full tree as indented text for terminal display."""
        if not self.root_nodes:
            return " (empty tree)"
        lines = [f"Target: {self.target}"]
        lines.append("")
        for root_id in self.root_nodes:
            self._render_node(root_id, lines, indent=0)
        return "\n".join(lines)

    def _render_node(self, node_id: str, lines: List[str], indent: int):
        """Append the text lines for ``node_id`` and its descendants to ``lines``."""
        node = self.nodes.get(node_id)
        if node is None:
            return
        prefix = " " * indent
        symbol = _STATUS_SYMBOLS.get(node.status, '[ ]')
        # The default priority (3) is omitted to keep the output compact.
        priority_str = f" P{node.priority}" if node.priority != 3 else ""
        lines.append(f"{prefix}{symbol} {node.label}{priority_str}")
        # Show at most the first three findings per node.
        for finding in node.findings[:3]:
            lines.append(f"{prefix} -> {finding}")
        for child_id in node.children:
            self._render_node(child_id, lines, indent + 1)

    def render_summary(self) -> str:
        """Render compact summary for LLM context injection.

        Designed to fit within tight token budgets (4096 ctx).
        Only shows TODO and IN_PROGRESS nodes with minimal detail.
        """
        stats = self.get_stats()
        lines = [
            f"Target: {self.target}",
            f"Nodes: {stats['total']} total, {stats['todo']} todo, "
            f"{stats['completed']} done, {stats['in_progress']} active",
        ]
        # Show active and todo nodes only.
        active = self.get_all_by_status(NodeStatus.IN_PROGRESS)
        todos = sorted(
            self.get_all_by_status(NodeStatus.TODO),
            key=lambda n: n.priority,
        )
        if active:
            lines.append("Active:")
            for n in active:
                lines.append(f" [{n.id}] {n.label}")
        if todos:
            lines.append("Todo:")
            for n in todos[:5]:
                lines.append(f" [{n.id}] P{n.priority} {n.label}")
            if len(todos) > 5:
                lines.append(f" ... and {len(todos) - 5} more")
        # Show the last 5 findings, taken in node insertion order (this is
        # not ordered by when each finding was recorded).
        all_findings = [
            finding
            for node in self.nodes.values()
            for finding in node.findings
        ]
        if all_findings:
            lines.append("Key findings:")
            for f in all_findings[-5:]:
                lines.append(f" - {f}")
        return "\n".join(lines)

    def initialize_standard_branches(self):
        """Create standard MITRE ATT&CK-aligned top-level branches.

        Each call adds a fresh set of root nodes; existing nodes are not
        checked for duplicates.
        """
        # (label, node type, priority, details)
        branches = [
            ("Reconnaissance", PTTNodeType.RECONNAISSANCE, 1,
             "Information gathering and target enumeration"),
            ("Initial Access", PTTNodeType.INITIAL_ACCESS, 2,
             "Gaining initial foothold on target"),
            ("Privilege Escalation", PTTNodeType.PRIVILEGE_ESCALATION, 3,
             "Escalating from initial access to higher privileges"),
            ("Lateral Movement", PTTNodeType.LATERAL_MOVEMENT, 4,
             "Moving to other systems in the network"),
            ("Credential Access", PTTNodeType.CREDENTIAL_ACCESS, 3,
             "Obtaining credentials and secrets"),
            ("Persistence", PTTNodeType.PERSISTENCE, 5,
             "Maintaining access to compromised systems"),
        ]
        for label, ntype, priority, details in branches:
            self.add_node(
                label=label,
                node_type=ntype,
                priority=priority,
                details=details,
            )

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the tree and all of its nodes."""
        return {
            'target': self.target,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
            'root_nodes': self.root_nodes.copy(),
            'nodes': {nid: n.to_dict() for nid, n in self.nodes.items()},
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'PentestTree':
        """Rebuild a tree from a dict shaped like the output of ``to_dict``.

        ``target`` is required. Missing or empty timestamps keep the
        values set at construction time rather than becoming empty strings.

        Raises:
            KeyError: if ``target`` (or a required node key) is absent.
        """
        tree = cls(target=data['target'])
        tree.created_at = data.get('created_at') or tree.created_at
        tree.updated_at = data.get('updated_at') or tree.updated_at
        # Copy so the tree never shares its root list with the caller's dict.
        tree.root_nodes = list(data.get('root_nodes') or [])
        for nid, ndata in (data.get('nodes') or {}).items():
            tree.nodes[nid] = PTTNode.from_dict(ndata)
        return tree