"""Input sanitization — prevent command injection in SSH commands.""" import re def hostname(val): """Validate and sanitize a hostname/domain.""" val = str(val).strip().lower() if not re.match(r'^[a-z0-9]([a-z0-9\-\.]{0,253}[a-z0-9])?$', val): raise ValueError(f"Invalid hostname: {val}") return val def ip_address(val): """Validate an IPv4 or IPv6 address.""" val = str(val).strip() # IPv4 if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', val): parts = val.split(".") if all(0 <= int(p) <= 255 for p in parts): return val # IPv6 if re.match(r'^[0-9a-fA-F:]+$', val) and "::" in val or val.count(":") >= 2: return val # CIDR if re.match(r'^(\d{1,3}\.){3}\d{1,3}/\d{1,2}$', val): return val raise ValueError(f"Invalid IP address: {val}") def port(val): """Validate a port number.""" val = int(val) if not 1 <= val <= 65535: raise ValueError(f"Invalid port: {val}") return val def filepath(val, allow_absolute=True): """Sanitize a file path — block shell metacharacters and traversal.""" val = str(val).strip() # Block shell metacharacters dangerous = set(';&|`$(){}[]!#~<>\\"\'\n\r\t') if any(c in val for c in dangerous): raise ValueError(f"Path contains forbidden characters: {val}") # Block path traversal if ".." in val: raise ValueError(f"Path traversal not allowed: {val}") if not allow_absolute and val.startswith("/"): raise ValueError(f"Absolute paths not allowed: {val}") return val def shell_arg(val): """Sanitize a value for safe use in a shell command. Only allows alphanumeric, dash, underscore, dot, slash, colon, @, space.""" val = str(val).strip() if not re.match(r'^[a-zA-Z0-9\-_\./:@ ]+$', val): raise ValueError(f"Invalid characters in argument: {val}") return val def service_name(val): """Validate a systemd service name.""" val = str(val).strip() if not re.match(r'^[a-zA-Z0-9\-_@\.]+$', val): raise ValueError(f"Invalid service name: {val}") return val def container_name(val): """Validate a Docker container name.""" val = str(val).strip() if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9_\.\-]+$', val): raise ValueError(f"Invalid container name: {val}") return val def dns_record_type(val): """Validate a DNS record type.""" val = str(val).strip().upper() allowed = {"A", "AAAA", "CNAME", "TXT", "MX", "NS", "SRV", "CAA", "PTR"} if val not in allowed: raise ValueError(f"Invalid DNS record type: {val}") return val def email_address(val): """Basic email validation.""" val = str(val).strip() if not re.match(r'^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$', val): raise ValueError(f"Invalid email: {val}") return val def positive_int(val, max_val=10000): """Validate a positive integer within bounds.""" val = int(val) if val < 1 or val > max_val: raise ValueError(f"Value out of range (1-{max_val}): {val}") return val