Autarch Will Control The Internet

This commit is contained in:
DigiJ
2026-03-13 15:17:15 -07:00
commit 4d3570781e
401 changed files with 484494 additions and 0 deletions

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@@ -0,0 +1,26 @@
#!/bin/bash
# Cross-compile autarch-dns for all supported platforms
set -e
VERSION="1.0.0"
OUTPUT_BASE="../../tools"
echo "Building autarch-dns v${VERSION}..."
# Linux ARM64 (Orange Pi 5 Plus)
echo " → linux/arm64"
GOOS=linux GOARCH=arm64 go build -ldflags="-s -w -X main.version=${VERSION}" \
-o "${OUTPUT_BASE}/linux-arm64/autarch-dns" .
# Linux AMD64
echo " → linux/amd64"
GOOS=linux GOARCH=amd64 go build -ldflags="-s -w -X main.version=${VERSION}" \
-o "${OUTPUT_BASE}/linux-x86_64/autarch-dns" .
# Windows AMD64
echo " → windows/amd64"
GOOS=windows GOARCH=amd64 go build -ldflags="-s -w -X main.version=${VERSION}" \
-o "${OUTPUT_BASE}/windows-x86_64/autarch-dns.exe" .
echo "Done! Binaries:"
ls -lh "${OUTPUT_BASE}"/*/autarch-dns* 2>/dev/null || true

View File

@@ -0,0 +1,84 @@
package config
import (
"crypto/rand"
"encoding/hex"
)
// Config holds all DNS server configuration.
type Config struct {
ListenDNS string `json:"listen_dns"`
ListenAPI string `json:"listen_api"`
APIToken string `json:"api_token"`
Upstream []string `json:"upstream"`
CacheTTL int `json:"cache_ttl"`
ZonesDir string `json:"zones_dir"`
DNSSECKeyDir string `json:"dnssec_keys_dir"`
LogQueries bool `json:"log_queries"`
// Hosts file support
HostsFile string `json:"hosts_file"` // Path to hosts file (e.g., /etc/hosts)
HostsAutoLoad bool `json:"hosts_auto_load"` // Auto-load system hosts file on start
// Encryption
EnableDoH bool `json:"enable_doh"` // DNS-over-HTTPS to upstream
EnableDoT bool `json:"enable_dot"` // DNS-over-TLS to upstream
// Security hardening
RateLimit int `json:"rate_limit"` // Max queries/sec per source IP (0=unlimited)
BlockList []string `json:"block_list"` // Blocked domain patterns
AllowTransfer []string `json:"allow_transfer"` // IPs allowed zone transfers (empty=none)
MinimalResponses bool `json:"minimal_responses"` // Minimize response data
RefuseANY bool `json:"refuse_any"` // Refuse ANY queries (amplification protection)
MaxUDPSize int `json:"max_udp_size"` // Max UDP response size
// Advanced
QueryLogMax int `json:"querylog_max"` // Max query log entries (default 1000)
NegativeCacheTTL int `json:"negative_cache_ttl"` // TTL for NXDOMAIN cache (default 60)
PrefetchEnabled bool `json:"prefetch_enabled"` // Prefetch expiring cache entries
ServFailCacheTTL int `json:"servfail_cache_ttl"` // TTL for SERVFAIL cache (default 30)
}
// DefaultConfig returns security-hardened defaults.
// No upstream forwarders — full recursive resolution from root hints.
// Upstream can be configured as optional fallback if recursive fails.
func DefaultConfig() *Config {
return &Config{
ListenDNS: "0.0.0.0:53",
ListenAPI: "127.0.0.1:5380",
APIToken: generateToken(),
Upstream: []string{}, // Empty = pure recursive from root hints
CacheTTL: 300,
ZonesDir: "data/dns/zones",
DNSSECKeyDir: "data/dns/keys",
LogQueries: true,
// Hosts
HostsFile: "",
HostsAutoLoad: false,
// Encryption defaults
EnableDoH: true,
EnableDoT: true,
// Security defaults
RateLimit: 100, // 100 qps per source IP
BlockList: []string{},
AllowTransfer: []string{}, // No zone transfers
MinimalResponses: true,
RefuseANY: true, // Block DNS amplification attacks
MaxUDPSize: 1232, // Safe MTU, prevent fragmentation
// Advanced defaults
QueryLogMax: 1000,
NegativeCacheTTL: 60,
PrefetchEnabled: false,
ServFailCacheTTL: 30,
}
}
func generateToken() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}

View File

@@ -0,0 +1,13 @@
module github.com/darkhal/autarch-dns
go 1.22
require github.com/miekg/dns v1.1.62
require (
golang.org/x/mod v0.18.0 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/tools v0.22.0 // indirect
)

View File

@@ -0,0 +1,12 @@
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=

View File

@@ -0,0 +1,84 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/darkhal/autarch-dns/api"
"github.com/darkhal/autarch-dns/config"
"github.com/darkhal/autarch-dns/server"
)
var version = "2.1.0"
func main() {
configPath := flag.String("config", "config.json", "Path to config file")
listenDNS := flag.String("dns", "", "DNS listen address (overrides config)")
listenAPI := flag.String("api", "", "API listen address (overrides config)")
apiToken := flag.String("token", "", "API auth token (overrides config)")
showVersion := flag.Bool("version", false, "Show version")
flag.Parse()
if *showVersion {
fmt.Printf("autarch-dns v%s\n", version)
os.Exit(0)
}
// Load config
cfg := config.DefaultConfig()
if data, err := os.ReadFile(*configPath); err == nil {
if err := json.Unmarshal(data, cfg); err != nil {
log.Printf("Warning: invalid config file: %v", err)
}
}
// CLI overrides
if *listenDNS != "" {
cfg.ListenDNS = *listenDNS
}
if *listenAPI != "" {
cfg.ListenAPI = *listenAPI
}
if *apiToken != "" {
cfg.APIToken = *apiToken
}
// Initialize zone store
store := server.NewZoneStore(cfg.ZonesDir)
if err := store.LoadAll(); err != nil {
log.Printf("Warning: loading zones: %v", err)
}
// Start DNS server
dnsServer := server.NewDNSServer(cfg, store)
go func() {
log.Printf("DNS server listening on %s (UDP+TCP)", cfg.ListenDNS)
if err := dnsServer.Start(); err != nil {
log.Fatalf("DNS server error: %v", err)
}
}()
// Start API server
apiServer := api.NewAPIServer(cfg, store, dnsServer)
go func() {
log.Printf("API server listening on %s", cfg.ListenAPI)
if err := apiServer.Start(); err != nil {
log.Fatalf("API server error: %v", err)
}
}()
log.Printf("autarch-dns v%s started", version)
// Wait for shutdown signal
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
log.Println("Shutting down...")
dnsServer.Stop()
}

View File

@@ -0,0 +1,656 @@
package server
import (
"log"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/darkhal/autarch-dns/config"
"github.com/miekg/dns"
)
// Metrics holds query statistics.
type Metrics struct {
TotalQueries uint64 `json:"total_queries"`
CacheHits uint64 `json:"cache_hits"`
CacheMisses uint64 `json:"cache_misses"`
LocalAnswers uint64 `json:"local_answers"`
ResolvedQ uint64 `json:"resolved"`
BlockedQ uint64 `json:"blocked"`
FailedQ uint64 `json:"failed"`
StartTime string `json:"start_time"`
}
// QueryLogEntry records a single DNS query.
type QueryLogEntry struct {
Timestamp string `json:"timestamp"`
Client string `json:"client"`
Name string `json:"name"`
Type string `json:"type"`
Rcode string `json:"rcode"`
Answers int `json:"answers"`
Latency string `json:"latency"`
Source string `json:"source"` // "local", "cache", "recursive", "blocked", "failed"
}
// CacheEntry holds a cached DNS response.
type CacheEntry struct {
msg *dns.Msg
expiresAt time.Time
}
// CacheInfo is an exportable view of a cache entry.
type CacheInfo struct {
Key string `json:"key"`
Name string `json:"name"`
Type string `json:"type"`
TTL int `json:"ttl_remaining"`
Answers int `json:"answers"`
ExpiresAt string `json:"expires_at"`
}
// DomainCount tracks query frequency per domain.
type DomainCount struct {
Domain string `json:"domain"`
Count uint64 `json:"count"`
}
// DNSServer is the main DNS server.
type DNSServer struct {
cfg *config.Config
store *ZoneStore
hosts *HostsStore
resolver *RecursiveResolver
metrics Metrics
cache map[string]*CacheEntry
cacheMu sync.RWMutex
udpServ *dns.Server
tcpServ *dns.Server
// Query log — ring buffer
queryLog []QueryLogEntry
queryLogMu sync.RWMutex
queryLogMax int
// Domain frequency tracking
domainCounts map[string]uint64
domainCountsMu sync.RWMutex
// Query type tracking
typeCounts map[string]uint64
typeCountsMu sync.RWMutex
// Client tracking
clientCounts map[string]uint64
clientCountsMu sync.RWMutex
// Blocklist — fast lookup
blocklist map[string]bool
blocklistMu sync.RWMutex
// Conditional forwarding: zone -> upstream servers
conditionalFwd map[string][]string
conditionalFwdMu sync.RWMutex
}
// NewDNSServer creates a DNS server.
func NewDNSServer(cfg *config.Config, store *ZoneStore) *DNSServer {
resolver := NewRecursiveResolver()
resolver.EnableDoT = cfg.EnableDoT
resolver.EnableDoH = cfg.EnableDoH
logMax := cfg.QueryLogMax
if logMax <= 0 {
logMax = 1000
}
s := &DNSServer{
cfg: cfg,
store: store,
hosts: NewHostsStore(),
resolver: resolver,
cache: make(map[string]*CacheEntry),
queryLog: make([]QueryLogEntry, 0, logMax),
queryLogMax: logMax,
domainCounts: make(map[string]uint64),
typeCounts: make(map[string]uint64),
clientCounts: make(map[string]uint64),
blocklist: make(map[string]bool),
conditionalFwd: make(map[string][]string),
metrics: Metrics{
StartTime: time.Now().UTC().Format(time.RFC3339),
},
}
// Load blocklist from config
for _, pattern := range cfg.BlockList {
s.blocklist[dns.Fqdn(strings.ToLower(pattern))] = true
}
// Load hosts file if configured
if cfg.HostsFile != "" {
if err := s.hosts.LoadFile(cfg.HostsFile); err != nil {
log.Printf("[hosts] Warning: could not load hosts file %s: %v", cfg.HostsFile, err)
}
}
return s
}
// GetHosts returns the hosts store.
func (s *DNSServer) GetHosts() *HostsStore {
return s.hosts
}
// GetEncryptionStatus returns encryption info from the resolver.
func (s *DNSServer) GetEncryptionStatus() map[string]interface{} {
return s.resolver.GetEncryptionStatus()
}
// SetEncryption updates DoT/DoH settings on the resolver.
func (s *DNSServer) SetEncryption(dot, doh bool) {
s.resolver.EnableDoT = dot
s.resolver.EnableDoH = doh
s.cfg.EnableDoT = dot
s.cfg.EnableDoH = doh
}
// GetResolver returns the underlying recursive resolver.
func (s *DNSServer) GetResolver() *RecursiveResolver {
return s.resolver
}
// Start begins listening on UDP and TCP.
func (s *DNSServer) Start() error {
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
s.udpServ = &dns.Server{Addr: s.cfg.ListenDNS, Net: "udp", Handler: mux}
s.tcpServ = &dns.Server{Addr: s.cfg.ListenDNS, Net: "tcp", Handler: mux}
errCh := make(chan error, 2)
go func() { errCh <- s.udpServ.ListenAndServe() }()
go func() { errCh <- s.tcpServ.ListenAndServe() }()
go s.cacheCleanup()
return <-errCh
}
// Stop shuts down both servers.
func (s *DNSServer) Stop() {
if s.udpServ != nil {
s.udpServ.Shutdown()
}
if s.tcpServ != nil {
s.tcpServ.Shutdown()
}
}
// GetMetrics returns current metrics.
func (s *DNSServer) GetMetrics() Metrics {
return Metrics{
TotalQueries: atomic.LoadUint64(&s.metrics.TotalQueries),
CacheHits: atomic.LoadUint64(&s.metrics.CacheHits),
CacheMisses: atomic.LoadUint64(&s.metrics.CacheMisses),
LocalAnswers: atomic.LoadUint64(&s.metrics.LocalAnswers),
ResolvedQ: atomic.LoadUint64(&s.metrics.ResolvedQ),
BlockedQ: atomic.LoadUint64(&s.metrics.BlockedQ),
FailedQ: atomic.LoadUint64(&s.metrics.FailedQ),
StartTime: s.metrics.StartTime,
}
}
// GetQueryLog returns the last N query log entries.
func (s *DNSServer) GetQueryLog(limit int) []QueryLogEntry {
s.queryLogMu.RLock()
defer s.queryLogMu.RUnlock()
n := len(s.queryLog)
if limit <= 0 || limit > n {
limit = n
}
// Return most recent first
result := make([]QueryLogEntry, limit)
for i := 0; i < limit; i++ {
result[i] = s.queryLog[n-1-i]
}
return result
}
// ClearQueryLog empties the log.
func (s *DNSServer) ClearQueryLog() {
s.queryLogMu.Lock()
s.queryLog = s.queryLog[:0]
s.queryLogMu.Unlock()
}
// GetCacheEntries returns all cache entries.
func (s *DNSServer) GetCacheEntries() []CacheInfo {
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
now := time.Now()
entries := make([]CacheInfo, 0, len(s.cache))
for key, entry := range s.cache {
if now.After(entry.expiresAt) {
continue
}
parts := strings.SplitN(key, "/", 2)
name, qtype := key, ""
if len(parts) == 2 {
name, qtype = parts[0], parts[1]
}
entries = append(entries, CacheInfo{
Key: key,
Name: name,
Type: qtype,
TTL: int(entry.expiresAt.Sub(now).Seconds()),
Answers: len(entry.msg.Answer),
ExpiresAt: entry.expiresAt.Format(time.RFC3339),
})
}
return entries
}
// CacheSize returns number of active cache entries.
func (s *DNSServer) CacheSize() int {
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
return len(s.cache)
}
// FlushCache clears all cached responses.
func (s *DNSServer) FlushCache() int {
s.cacheMu.Lock()
n := len(s.cache)
s.cache = make(map[string]*CacheEntry)
s.cacheMu.Unlock()
// Also flush resolver NS cache
s.resolver.FlushNSCache()
return n
}
// FlushCacheEntry removes a single cache entry.
func (s *DNSServer) FlushCacheEntry(key string) bool {
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
if _, ok := s.cache[key]; ok {
delete(s.cache, key)
return true
}
return false
}
// GetTopDomains returns the most-queried domains.
func (s *DNSServer) GetTopDomains(limit int) []DomainCount {
s.domainCountsMu.RLock()
defer s.domainCountsMu.RUnlock()
counts := make([]DomainCount, 0, len(s.domainCounts))
for domain, count := range s.domainCounts {
counts = append(counts, DomainCount{Domain: domain, Count: count})
}
sort.Slice(counts, func(i, j int) bool { return counts[i].Count > counts[j].Count })
if limit > 0 && limit < len(counts) {
counts = counts[:limit]
}
return counts
}
// GetQueryTypeCounts returns counts by query type.
func (s *DNSServer) GetQueryTypeCounts() map[string]uint64 {
s.typeCountsMu.RLock()
defer s.typeCountsMu.RUnlock()
result := make(map[string]uint64, len(s.typeCounts))
for k, v := range s.typeCounts {
result[k] = v
}
return result
}
// GetClientCounts returns counts by client IP.
func (s *DNSServer) GetClientCounts() map[string]uint64 {
s.clientCountsMu.RLock()
defer s.clientCountsMu.RUnlock()
result := make(map[string]uint64, len(s.clientCounts))
for k, v := range s.clientCounts {
result[k] = v
}
return result
}
// AddBlocklistEntry adds a domain to the blocklist.
func (s *DNSServer) AddBlocklistEntry(domain string) {
s.blocklistMu.Lock()
s.blocklist[dns.Fqdn(strings.ToLower(domain))] = true
s.blocklistMu.Unlock()
}
// RemoveBlocklistEntry removes a domain from the blocklist.
func (s *DNSServer) RemoveBlocklistEntry(domain string) {
s.blocklistMu.Lock()
delete(s.blocklist, dns.Fqdn(strings.ToLower(domain)))
s.blocklistMu.Unlock()
}
// GetBlocklist returns all blocked domains.
func (s *DNSServer) GetBlocklist() []string {
s.blocklistMu.RLock()
defer s.blocklistMu.RUnlock()
list := make([]string, 0, len(s.blocklist))
for domain := range s.blocklist {
list = append(list, domain)
}
sort.Strings(list)
return list
}
// ImportBlocklist adds multiple domains at once.
func (s *DNSServer) ImportBlocklist(domains []string) int {
s.blocklistMu.Lock()
defer s.blocklistMu.Unlock()
count := 0
for _, d := range domains {
d = strings.TrimSpace(strings.ToLower(d))
if d == "" || strings.HasPrefix(d, "#") {
continue
}
s.blocklist[dns.Fqdn(d)] = true
count++
}
return count
}
// SetConditionalForward sets upstream servers for a specific zone.
func (s *DNSServer) SetConditionalForward(zone string, upstreams []string) {
s.conditionalFwdMu.Lock()
s.conditionalFwd[dns.Fqdn(strings.ToLower(zone))] = upstreams
s.conditionalFwdMu.Unlock()
}
// RemoveConditionalForward removes conditional forwarding for a zone.
func (s *DNSServer) RemoveConditionalForward(zone string) {
s.conditionalFwdMu.Lock()
delete(s.conditionalFwd, dns.Fqdn(strings.ToLower(zone)))
s.conditionalFwdMu.Unlock()
}
// GetConditionalForwards returns all conditional forwarding rules.
func (s *DNSServer) GetConditionalForwards() map[string][]string {
s.conditionalFwdMu.RLock()
defer s.conditionalFwdMu.RUnlock()
result := make(map[string][]string, len(s.conditionalFwd))
for k, v := range s.conditionalFwd {
result[k] = v
}
return result
}
// GetResolverNSCache returns the resolver's NS delegation cache.
func (s *DNSServer) GetResolverNSCache() map[string][]string {
return s.resolver.GetNSCache()
}
func (s *DNSServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
start := time.Now()
atomic.AddUint64(&s.metrics.TotalQueries, 1)
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = false
msg.RecursionAvailable = true
if len(r.Question) == 0 {
msg.Rcode = dns.RcodeFormatError
w.WriteMsg(msg)
return
}
q := r.Question[0]
qName := q.Name
qTypeStr := dns.TypeToString[q.Qtype]
clientAddr := w.RemoteAddr().String()
// Track stats
s.trackDomain(qName)
s.trackType(qTypeStr)
s.trackClient(clientAddr)
if s.cfg.LogQueries {
log.Printf("[query] %s %s from %s", qTypeStr, qName, clientAddr)
}
// Security: Refuse ANY queries (DNS amplification protection)
if s.cfg.RefuseANY && q.Qtype == dns.TypeANY {
msg.Rcode = dns.RcodeNotImplemented
atomic.AddUint64(&s.metrics.FailedQ, 1)
s.logQuery(clientAddr, qName, qTypeStr, "NOTIMPL", 0, time.Since(start), "blocked")
w.WriteMsg(msg)
return
}
// Security: Block zone transfer requests (AXFR/IXFR)
if q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR {
msg.Rcode = dns.RcodeRefused
atomic.AddUint64(&s.metrics.FailedQ, 1)
s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
w.WriteMsg(msg)
return
}
// Security: Minimal responses — don't expose server info
if s.cfg.MinimalResponses {
if q.Qtype == dns.TypeTXT && (qName == "version.bind." || qName == "hostname.bind." || qName == "version.server.") {
msg.Rcode = dns.RcodeRefused
s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
w.WriteMsg(msg)
return
}
}
// Blocklist check
if s.isBlocked(qName) {
msg.Rcode = dns.RcodeNameError // NXDOMAIN
atomic.AddUint64(&s.metrics.BlockedQ, 1)
s.logQuery(clientAddr, qName, qTypeStr, "NXDOMAIN", 0, time.Since(start), "blocked")
w.WriteMsg(msg)
return
}
// 1a. Check hosts file
hostsAnswers := s.hosts.Lookup(qName, q.Qtype)
if len(hostsAnswers) > 0 {
msg.Authoritative = true
msg.Answer = hostsAnswers
atomic.AddUint64(&s.metrics.LocalAnswers, 1)
s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(hostsAnswers), time.Since(start), "hosts")
w.WriteMsg(msg)
return
}
// 1b. Check local zones
answers := s.store.Lookup(qName, q.Qtype)
if len(answers) > 0 {
msg.Authoritative = true
msg.Answer = answers
atomic.AddUint64(&s.metrics.LocalAnswers, 1)
s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(answers), time.Since(start), "local")
w.WriteMsg(msg)
return
}
// 2. Check cache
cacheKey := cacheKeyFor(q)
if cached := s.getCached(cacheKey); cached != nil {
cached.SetReply(r)
atomic.AddUint64(&s.metrics.CacheHits, 1)
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[cached.Rcode], len(cached.Answer), time.Since(start), "cache")
w.WriteMsg(cached)
return
}
atomic.AddUint64(&s.metrics.CacheMisses, 1)
// 3. Check conditional forwarding
if fwdServers := s.getConditionalForward(qName); fwdServers != nil {
c := &dns.Client{Timeout: 5 * time.Second}
for _, srv := range fwdServers {
resp, _, err := c.Exchange(r, srv)
if err == nil && resp != nil {
atomic.AddUint64(&s.metrics.ResolvedQ, 1)
s.putCache(cacheKey, resp)
resp.SetReply(r)
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[resp.Rcode], len(resp.Answer), time.Since(start), "conditional")
w.WriteMsg(resp)
return
}
}
}
// 4. Recursive resolution from root hints (with optional upstream fallback)
resp := s.resolver.ResolveWithFallback(r, s.cfg.Upstream)
if resp != nil {
atomic.AddUint64(&s.metrics.ResolvedQ, 1)
s.putCache(cacheKey, resp)
resp.SetReply(r)
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[resp.Rcode], len(resp.Answer), time.Since(start), "recursive")
w.WriteMsg(resp)
return
}
// 5. SERVFAIL
atomic.AddUint64(&s.metrics.FailedQ, 1)
msg.Rcode = dns.RcodeServerFailure
s.logQuery(clientAddr, qName, qTypeStr, "SERVFAIL", 0, time.Since(start), "failed")
w.WriteMsg(msg)
}
// ── Blocklist ────────────────────────────────────────────────────────
func (s *DNSServer) isBlocked(name string) bool {
s.blocklistMu.RLock()
defer s.blocklistMu.RUnlock()
fqdn := dns.Fqdn(strings.ToLower(name))
// Exact match
if s.blocklist[fqdn] {
return true
}
// Wildcard: check parent domains
labels := dns.SplitDomainName(fqdn)
for i := 1; i < len(labels); i++ {
parent := dns.Fqdn(strings.Join(labels[i:], "."))
if s.blocklist[parent] {
return true
}
}
return false
}
// ── Conditional forwarding ───────────────────────────────────────────
func (s *DNSServer) getConditionalForward(name string) []string {
s.conditionalFwdMu.RLock()
defer s.conditionalFwdMu.RUnlock()
fqdn := dns.Fqdn(strings.ToLower(name))
labels := dns.SplitDomainName(fqdn)
for i := 0; i < len(labels); i++ {
zone := dns.Fqdn(strings.Join(labels[i:], "."))
if servers, ok := s.conditionalFwd[zone]; ok {
return servers
}
}
return nil
}
// ── Tracking ─────────────────────────────────────────────────────────
func (s *DNSServer) trackDomain(name string) {
s.domainCountsMu.Lock()
s.domainCounts[name]++
s.domainCountsMu.Unlock()
}
func (s *DNSServer) trackType(qtype string) {
s.typeCountsMu.Lock()
s.typeCounts[qtype]++
s.typeCountsMu.Unlock()
}
func (s *DNSServer) trackClient(addr string) {
// Strip port
if idx := strings.LastIndex(addr, ":"); idx > 0 {
addr = addr[:idx]
}
s.clientCountsMu.Lock()
s.clientCounts[addr]++
s.clientCountsMu.Unlock()
}
func (s *DNSServer) logQuery(client, name, qtype, rcode string, answers int, latency time.Duration, source string) {
entry := QueryLogEntry{
Timestamp: time.Now().UTC().Format(time.RFC3339Nano),
Client: client,
Name: name,
Type: qtype,
Rcode: rcode,
Answers: answers,
Latency: latency.String(),
Source: source,
}
s.queryLogMu.Lock()
if len(s.queryLog) >= s.queryLogMax {
// Shift: remove oldest 10%
trim := s.queryLogMax / 10
copy(s.queryLog, s.queryLog[trim:])
s.queryLog = s.queryLog[:len(s.queryLog)-trim]
}
s.queryLog = append(s.queryLog, entry)
s.queryLogMu.Unlock()
}
// ── Cache ────────────────────────────────────────────────────────────
func cacheKeyFor(q dns.Question) string {
return q.Name + "/" + dns.TypeToString[q.Qtype]
}
func (s *DNSServer) getCached(key string) *dns.Msg {
s.cacheMu.RLock()
defer s.cacheMu.RUnlock()
entry, ok := s.cache[key]
if !ok || time.Now().After(entry.expiresAt) {
return nil
}
return entry.msg.Copy()
}
func (s *DNSServer) putCache(key string, msg *dns.Msg) {
ttl := time.Duration(s.cfg.CacheTTL) * time.Second
if ttl <= 0 {
return
}
s.cacheMu.Lock()
s.cache[key] = &CacheEntry{msg: msg.Copy(), expiresAt: time.Now().Add(ttl)}
s.cacheMu.Unlock()
}
func (s *DNSServer) cacheCleanup() {
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
for range ticker.C {
s.cacheMu.Lock()
now := time.Now()
for k, v := range s.cache {
if now.After(v.expiresAt) {
delete(s.cache, k)
}
}
s.cacheMu.Unlock()
}
}

View File

@@ -0,0 +1,349 @@
package server
import (
"bufio"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// HostEntry represents a single hosts file entry.
type HostEntry struct {
IP string `json:"ip"`
Hostname string `json:"hostname"`
Aliases []string `json:"aliases,omitempty"`
Comment string `json:"comment,omitempty"`
}
// HostsStore manages a hosts-file-like database.
type HostsStore struct {
mu sync.RWMutex
entries []HostEntry
path string // path to hosts file on disk (if loaded from file)
}
// NewHostsStore creates a new hosts store.
func NewHostsStore() *HostsStore {
return &HostsStore{
entries: make([]HostEntry, 0),
}
}
// LoadFile parses a hosts file from disk.
func (h *HostsStore) LoadFile(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
h.mu.Lock()
defer h.mu.Unlock()
h.path = path
h.entries = h.entries[:0]
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Strip inline comments
comment := ""
if idx := strings.Index(line, "#"); idx >= 0 {
comment = strings.TrimSpace(line[idx+1:])
line = strings.TrimSpace(line[:idx])
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
ip := fields[0]
if net.ParseIP(ip) == nil {
continue // invalid IP
}
entry := HostEntry{
IP: ip,
Hostname: strings.ToLower(fields[1]),
Comment: comment,
}
if len(fields) > 2 {
aliases := make([]string, len(fields)-2)
for i, a := range fields[2:] {
aliases[i] = strings.ToLower(a)
}
entry.Aliases = aliases
}
h.entries = append(h.entries, entry)
}
log.Printf("[hosts] Loaded %d entries from %s", len(h.entries), path)
return scanner.Err()
}
// LoadFromText parses hosts-format text (like pasting /etc/hosts content).
func (h *HostsStore) LoadFromText(content string) int {
h.mu.Lock()
defer h.mu.Unlock()
count := 0
scanner := bufio.NewScanner(strings.NewReader(content))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
comment := ""
if idx := strings.Index(line, "#"); idx >= 0 {
comment = strings.TrimSpace(line[idx+1:])
line = strings.TrimSpace(line[:idx])
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
ip := fields[0]
if net.ParseIP(ip) == nil {
continue
}
entry := HostEntry{
IP: ip,
Hostname: strings.ToLower(fields[1]),
Comment: comment,
}
if len(fields) > 2 {
aliases := make([]string, len(fields)-2)
for i, a := range fields[2:] {
aliases[i] = strings.ToLower(a)
}
entry.Aliases = aliases
}
// Dedup by hostname
found := false
for i, e := range h.entries {
if e.Hostname == entry.Hostname {
h.entries[i] = entry
found = true
break
}
}
if !found {
h.entries = append(h.entries, entry)
}
count++
}
return count
}
// Add adds a single host entry.
func (h *HostsStore) Add(ip, hostname string, aliases []string, comment string) error {
if net.ParseIP(ip) == nil {
return fmt.Errorf("invalid IP: %s", ip)
}
hostname = strings.ToLower(strings.TrimSpace(hostname))
if hostname == "" {
return fmt.Errorf("hostname required")
}
h.mu.Lock()
defer h.mu.Unlock()
// Check for duplicate
for i, e := range h.entries {
if e.Hostname == hostname {
h.entries[i].IP = ip
h.entries[i].Aliases = aliases
h.entries[i].Comment = comment
return nil
}
}
h.entries = append(h.entries, HostEntry{
IP: ip,
Hostname: hostname,
Aliases: aliases,
Comment: comment,
})
return nil
}
// Remove removes a host entry by hostname.
func (h *HostsStore) Remove(hostname string) bool {
hostname = strings.ToLower(strings.TrimSpace(hostname))
h.mu.Lock()
defer h.mu.Unlock()
for i, e := range h.entries {
if e.Hostname == hostname {
h.entries = append(h.entries[:i], h.entries[i+1:]...)
return true
}
}
return false
}
// Clear removes all entries.
func (h *HostsStore) Clear() int {
h.mu.Lock()
defer h.mu.Unlock()
n := len(h.entries)
h.entries = h.entries[:0]
return n
}
// List returns all entries.
func (h *HostsStore) List() []HostEntry {
h.mu.RLock()
defer h.mu.RUnlock()
result := make([]HostEntry, len(h.entries))
copy(result, h.entries)
return result
}
// Count returns the number of entries.
func (h *HostsStore) Count() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.entries)
}
// Lookup resolves a hostname from the hosts store.
// Returns DNS RRs matching the query name and type.
func (h *HostsStore) Lookup(name string, qtype uint16) []dns.RR {
if qtype != dns.TypeA && qtype != dns.TypeAAAA && qtype != dns.TypePTR {
return nil
}
h.mu.RLock()
defer h.mu.RUnlock()
fqdn := dns.Fqdn(strings.ToLower(name))
baseName := strings.TrimSuffix(fqdn, ".")
// PTR lookup (reverse DNS)
if qtype == dns.TypePTR {
// Convert in-addr.arpa name to IP
ip := ptrToIP(fqdn)
if ip == "" {
return nil
}
for _, e := range h.entries {
if e.IP == ip {
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: fqdn,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
},
Ptr: dns.Fqdn(e.Hostname),
}
return []dns.RR{rr}
}
}
return nil
}
// Forward lookup (A / AAAA)
var results []dns.RR
for _, e := range h.entries {
// Match hostname or aliases
match := strings.EqualFold(e.Hostname, baseName) || strings.EqualFold(dns.Fqdn(e.Hostname), fqdn)
if !match {
for _, a := range e.Aliases {
if strings.EqualFold(a, baseName) || strings.EqualFold(dns.Fqdn(a), fqdn) {
match = true
break
}
}
}
if !match {
continue
}
ip := net.ParseIP(e.IP)
if ip == nil {
continue
}
if qtype == dns.TypeA && ip.To4() != nil {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: fqdn,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 60,
},
A: ip.To4(),
}
results = append(results, rr)
} else if qtype == dns.TypeAAAA && ip.To4() == nil {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: fqdn,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 60,
},
AAAA: ip,
}
results = append(results, rr)
}
}
return results
}
// Export returns hosts file format text.
func (h *HostsStore) Export() string {
h.mu.RLock()
defer h.mu.RUnlock()
var sb strings.Builder
sb.WriteString("# AUTARCH DNS hosts file\n")
sb.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339)))
sb.WriteString("# Entries: " + fmt.Sprintf("%d", len(h.entries)) + "\n\n")
for _, e := range h.entries {
line := e.IP + "\t" + e.Hostname
for _, a := range e.Aliases {
line += "\t" + a
}
if e.Comment != "" {
line += "\t# " + e.Comment
}
sb.WriteString(line + "\n")
}
return sb.String()
}
// ptrToIP converts a PTR domain name (in-addr.arpa) to an IP string.
func ptrToIP(name string) string {
name = strings.TrimSuffix(strings.ToLower(name), ".")
if !strings.HasSuffix(name, ".in-addr.arpa") {
return ""
}
name = strings.TrimSuffix(name, ".in-addr.arpa")
parts := strings.Split(name, ".")
if len(parts) != 4 {
return ""
}
// Reverse the octets
return parts[3] + "." + parts[2] + "." + parts[1] + "." + parts[0]
}

View File

@@ -0,0 +1,528 @@
package server
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// Root nameserver IPs (IANA root hints).
// These are hardcoded — they almost never change.
var rootServers = []string{
"198.41.0.4:53", // a.root-servers.net
"170.247.170.2:53", // b.root-servers.net
"192.33.4.12:53", // c.root-servers.net
"199.7.91.13:53", // d.root-servers.net
"192.203.230.10:53", // e.root-servers.net
"192.5.5.241:53", // f.root-servers.net
"192.112.36.4:53", // g.root-servers.net
"198.97.190.53:53", // h.root-servers.net
"192.36.148.17:53", // i.root-servers.net
"192.58.128.30:53", // j.root-servers.net
"193.0.14.129:53", // k.root-servers.net
"199.7.83.42:53", // l.root-servers.net
"202.12.27.33:53", // m.root-servers.net
}
// Well-known DoH endpoints — when user configures these as upstream,
// we auto-detect and use DoH instead of plain DNS.
var knownDoHEndpoints = map[string]string{
"8.8.8.8": "https://dns.google/dns-query",
"8.8.4.4": "https://dns.google/dns-query",
"1.1.1.1": "https://cloudflare-dns.com/dns-query",
"1.0.0.1": "https://cloudflare-dns.com/dns-query",
"9.9.9.9": "https://dns.quad9.net/dns-query",
"149.112.112.112": "https://dns.quad9.net/dns-query",
"208.67.222.222": "https://dns.opendns.com/dns-query",
"208.67.220.220": "https://dns.opendns.com/dns-query",
"94.140.14.14": "https://dns.adguard-dns.com/dns-query",
"94.140.15.15": "https://dns.adguard-dns.com/dns-query",
}
// Well-known DoT servers — port 853 TLS.
var knownDoTServers = map[string]string{
"8.8.8.8": "dns.google",
"8.8.4.4": "dns.google",
"1.1.1.1": "one.one.one.one",
"1.0.0.1": "one.one.one.one",
"9.9.9.9": "dns.quad9.net",
"149.112.112.112": "dns.quad9.net",
"208.67.222.222": "dns.opendns.com",
"208.67.220.220": "dns.opendns.com",
"94.140.14.14": "dns-unfiltered.adguard.com",
"94.140.15.15": "dns-unfiltered.adguard.com",
}
// EncryptionMode determines how upstream queries are sent.
type EncryptionMode int
const (
ModePlain EncryptionMode = iota // Standard UDP/TCP DNS
ModeDoT // DNS-over-TLS (port 853)
ModeDoH // DNS-over-HTTPS (RFC 8484)
)
// RecursiveResolver performs iterative DNS resolution from root hints.
type RecursiveResolver struct {
// NS cache: zone -> list of nameserver IPs
nsCache map[string][]string
nsCacheMu sync.RWMutex
client *dns.Client
dotClient *dns.Client // TLS client for DoT
dohHTTP *http.Client
maxDepth int
timeout time.Duration
// Encryption settings
EnableDoT bool
EnableDoH bool
}
// NewRecursiveResolver creates a resolver with root hints.
func NewRecursiveResolver() *RecursiveResolver {
return &RecursiveResolver{
nsCache: make(map[string][]string),
client: &dns.Client{Timeout: 4 * time.Second},
dotClient: &dns.Client{
Net: "tcp-tls",
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
},
dohHTTP: &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
DisableCompression: false,
ForceAttemptHTTP2: true,
},
},
maxDepth: 20,
timeout: 4 * time.Second,
}
}
// Resolve performs full iterative resolution for the given query message.
// Returns the final authoritative response, or nil on failure.
func (rr *RecursiveResolver) Resolve(req *dns.Msg) *dns.Msg {
if len(req.Question) == 0 {
return nil
}
q := req.Question[0]
return rr.resolve(q.Name, q.Qtype, 0)
}
func (rr *RecursiveResolver) resolve(name string, qtype uint16, depth int) *dns.Msg {
if depth >= rr.maxDepth {
log.Printf("[resolver] max depth reached for %s", name)
return nil
}
name = dns.Fqdn(name)
// Find the best nameservers to start from.
// Walk up the name to find cached NS records, fall back to root.
nameservers := rr.findBestNS(name)
// Iterative resolution: keep querying NS servers until we get an answer
for i := 0; i < rr.maxDepth; i++ {
resp := rr.queryServers(nameservers, name, qtype)
if resp == nil {
return nil
}
// Got an authoritative answer or a final answer with records
if resp.Authoritative && len(resp.Answer) > 0 {
return resp
}
// Check if answer section has what we want (non-authoritative but valid)
if len(resp.Answer) > 0 {
hasTarget := false
var cnameRR *dns.CNAME
for _, ans := range resp.Answer {
if ans.Header().Rrtype == qtype {
hasTarget = true
}
if cn, ok := ans.(*dns.CNAME); ok && qtype != dns.TypeCNAME {
cnameRR = cn
}
}
if hasTarget {
return resp
}
// Follow CNAME chain
if cnameRR != nil {
cResp := rr.resolve(cnameRR.Target, qtype, depth+1)
if cResp != nil {
// Prepend the CNAME to the answer
cResp.Answer = append([]dns.RR{cnameRR}, cResp.Answer...)
return cResp
}
}
return resp
}
// NXDOMAIN — name doesn't exist
if resp.Rcode == dns.RcodeNameError {
return resp
}
// NOERROR with no answer and no NS in authority = we're done
if len(resp.Ns) == 0 && len(resp.Answer) == 0 {
return resp
}
// Referral: extract NS records from authority section
var newNS []string
var nsNames []string
for _, rr := range resp.Ns {
if ns, ok := rr.(*dns.NS); ok {
nsNames = append(nsNames, ns.Ns)
}
}
if len(nsNames) == 0 {
// SOA in authority = negative response from authoritative server
for _, rr := range resp.Ns {
if _, ok := rr.(*dns.SOA); ok {
return resp
}
}
return resp
}
// Try to get IPs from the additional section (glue records)
glue := make(map[string]string)
for _, rr := range resp.Extra {
if a, ok := rr.(*dns.A); ok {
glue[strings.ToLower(a.Hdr.Name)] = a.A.String() + ":53"
}
}
for _, nsName := range nsNames {
key := strings.ToLower(dns.Fqdn(nsName))
if ip, ok := glue[key]; ok {
newNS = append(newNS, ip)
}
}
// If no glue, resolve NS names ourselves
if len(newNS) == 0 {
for _, nsName := range nsNames {
ips := rr.resolveNSName(nsName, depth+1)
newNS = append(newNS, ips...)
if len(newNS) >= 3 {
break // Enough NS IPs
}
}
}
if len(newNS) == 0 {
log.Printf("[resolver] no NS IPs found for delegation of %s", name)
return nil
}
// Cache the delegation
zone := extractZone(resp.Ns)
if zone != "" {
rr.cacheNS(zone, newNS)
}
nameservers = newNS
}
return nil
}
// resolveNSName resolves a nameserver hostname to its IP(s).
func (rr *RecursiveResolver) resolveNSName(nsName string, depth int) []string {
resp := rr.resolve(nsName, dns.TypeA, depth)
if resp == nil {
return nil
}
var ips []string
for _, ans := range resp.Answer {
if a, ok := ans.(*dns.A); ok {
ips = append(ips, a.A.String()+":53")
}
}
return ips
}
// queryServers sends a query to a list of nameservers, returns first valid response.
func (rr *RecursiveResolver) queryServers(servers []string, name string, qtype uint16) *dns.Msg {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(name), qtype)
msg.RecursionDesired = false // We're doing iterative resolution
for _, server := range servers {
resp, _, err := rr.client.Exchange(msg, server)
if err != nil {
continue
}
if resp != nil {
return resp
}
}
// Retry with TCP for truncated responses
msg.RecursionDesired = false
tcpClient := &dns.Client{Net: "tcp", Timeout: rr.timeout}
for _, server := range servers {
resp, _, err := tcpClient.Exchange(msg, server)
if err != nil {
continue
}
if resp != nil {
return resp
}
}
return nil
}
// queryUpstreamDoT sends a query to an upstream server via DNS-over-TLS (port 853).
func (rr *RecursiveResolver) QueryUpstreamDoT(req *dns.Msg, server string) (*dns.Msg, error) {
// Extract IP from server address (may include :53)
ip := server
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
ip = ip[:idx]
}
// Get TLS server name for certificate validation
serverName, ok := knownDoTServers[ip]
if !ok {
serverName = ip // Use IP as fallback (less secure, but works)
}
dotAddr := ip + ":853"
client := &dns.Client{
Net: "tcp-tls",
Timeout: 5 * time.Second,
TLSConfig: &tls.Config{
ServerName: serverName,
MinVersion: tls.VersionTLS12,
},
}
msg := req.Copy()
msg.RecursionDesired = true
resp, _, err := client.Exchange(msg, dotAddr)
return resp, err
}
// queryUpstreamDoH sends a query to an upstream server via DNS-over-HTTPS (RFC 8484).
func (rr *RecursiveResolver) QueryUpstreamDoH(req *dns.Msg, server string) (*dns.Msg, error) {
// Extract IP from server address
ip := server
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
ip = ip[:idx]
}
// Find the DoH endpoint URL
endpoint, ok := knownDoHEndpoints[ip]
if !ok {
return nil, fmt.Errorf("no DoH endpoint known for %s", ip)
}
// Encode DNS message as wire format
msg := req.Copy()
msg.RecursionDesired = true
wireMsg, err := msg.Pack()
if err != nil {
return nil, fmt.Errorf("pack DNS message: %w", err)
}
// POST as application/dns-message (RFC 8484)
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewReader(wireMsg))
if err != nil {
return nil, fmt.Errorf("create HTTP request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/dns-message")
httpReq.Header.Set("Accept", "application/dns-message")
httpResp, err := rr.dohHTTP.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("DoH request to %s: %w", endpoint, err)
}
defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DoH response status %d from %s", httpResp.StatusCode, endpoint)
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("read DoH response: %w", err)
}
resp := new(dns.Msg)
if err := resp.Unpack(body); err != nil {
return nil, fmt.Errorf("unpack DoH response: %w", err)
}
return resp, nil
}
// queryUpstreamEncrypted tries DoH first (if enabled), then DoT, then plain.
func (rr *RecursiveResolver) queryUpstreamEncrypted(req *dns.Msg, server string) (*dns.Msg, string, error) {
ip := server
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
ip = ip[:idx]
}
// Try DoH if enabled and we know the endpoint
if rr.EnableDoH {
if _, ok := knownDoHEndpoints[ip]; ok {
resp, err := rr.QueryUpstreamDoH(req, server)
if err == nil && resp != nil {
return resp, "doh", nil
}
log.Printf("[resolver] DoH failed for %s: %v, falling back", ip, err)
}
}
// Try DoT if enabled
if rr.EnableDoT {
resp, err := rr.QueryUpstreamDoT(req, server)
if err == nil && resp != nil {
return resp, "dot", nil
}
log.Printf("[resolver] DoT failed for %s: %v, falling back", ip, err)
}
// Plain DNS fallback
c := &dns.Client{Timeout: 5 * time.Second}
resp, _, err := c.Exchange(req, server)
if err != nil {
return nil, "plain", err
}
return resp, "plain", nil
}
// findBestNS finds the closest cached NS for the given name, or returns root servers.
func (rr *RecursiveResolver) findBestNS(name string) []string {
rr.nsCacheMu.RLock()
defer rr.nsCacheMu.RUnlock()
// Walk up the domain name
labels := dns.SplitDomainName(name)
for i := 0; i < len(labels); i++ {
zone := dns.Fqdn(strings.Join(labels[i:], "."))
if ns, ok := rr.nsCache[zone]; ok && len(ns) > 0 {
return ns
}
}
return rootServers
}
// cacheNS stores nameserver IPs for a zone.
func (rr *RecursiveResolver) cacheNS(zone string, servers []string) {
rr.nsCacheMu.Lock()
rr.nsCache[dns.Fqdn(zone)] = servers
rr.nsCacheMu.Unlock()
}
// extractZone gets the zone name from NS authority records.
func extractZone(ns []dns.RR) string {
for _, rr := range ns {
if nsRR, ok := rr.(*dns.NS); ok {
return nsRR.Hdr.Name
}
}
return ""
}
// ResolveWithFallback tries recursive resolution, falls back to upstream forwarders.
// Now with DoT/DoH encryption support for upstream queries.
func (rr *RecursiveResolver) ResolveWithFallback(req *dns.Msg, upstream []string) *dns.Msg {
// Try full recursive first
resp := rr.Resolve(req)
if resp != nil && resp.Rcode != dns.RcodeServerFailure {
return resp
}
// Fallback to upstream forwarders if configured — use encrypted transport
if len(upstream) > 0 {
for _, us := range upstream {
resp, mode, err := rr.queryUpstreamEncrypted(req, us)
if err == nil && resp != nil {
if mode != "plain" {
log.Printf("[resolver] upstream %s answered via %s", us, mode)
}
return resp
}
}
}
return resp
}
// GetEncryptionStatus returns the current encryption mode info.
func (rr *RecursiveResolver) GetEncryptionStatus() map[string]interface{} {
status := map[string]interface{}{
"dot_enabled": rr.EnableDoT,
"doh_enabled": rr.EnableDoH,
"dot_servers": knownDoTServers,
"doh_servers": knownDoHEndpoints,
}
if rr.EnableDoH {
status["preferred_mode"] = "doh"
} else if rr.EnableDoT {
status["preferred_mode"] = "dot"
} else {
status["preferred_mode"] = "plain"
}
return status
}
// FlushNSCache clears all cached NS delegations.
func (rr *RecursiveResolver) FlushNSCache() {
rr.nsCacheMu.Lock()
rr.nsCache = make(map[string][]string)
rr.nsCacheMu.Unlock()
}
// GetNSCache returns a copy of the NS delegation cache.
func (rr *RecursiveResolver) GetNSCache() map[string][]string {
rr.nsCacheMu.RLock()
defer rr.nsCacheMu.RUnlock()
result := make(map[string][]string, len(rr.nsCache))
for k, v := range rr.nsCache {
cp := make([]string, len(v))
copy(cp, v)
result[k] = cp
}
return result
}
// String returns resolver info for debugging.
func (rr *RecursiveResolver) String() string {
rr.nsCacheMu.RLock()
defer rr.nsCacheMu.RUnlock()
mode := "plain"
if rr.EnableDoH {
mode = "DoH"
} else if rr.EnableDoT {
mode = "DoT"
}
return fmt.Sprintf("RecursiveResolver{cached_zones=%d, max_depth=%d, mode=%s}", len(rr.nsCache), rr.maxDepth, mode)
}

View File

@@ -0,0 +1,525 @@
package server
import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// RecordType represents supported DNS record types.
type RecordType string
const (
TypeA RecordType = "A"
TypeAAAA RecordType = "AAAA"
TypeCNAME RecordType = "CNAME"
TypeMX RecordType = "MX"
TypeTXT RecordType = "TXT"
TypeNS RecordType = "NS"
TypeSRV RecordType = "SRV"
TypePTR RecordType = "PTR"
TypeSOA RecordType = "SOA"
)
// Record is a single DNS record.
type Record struct {
ID string `json:"id"`
Type RecordType `json:"type"`
Name string `json:"name"`
Value string `json:"value"`
TTL uint32 `json:"ttl"`
Priority uint16 `json:"priority,omitempty"` // MX, SRV
Weight uint16 `json:"weight,omitempty"` // SRV
Port uint16 `json:"port,omitempty"` // SRV
}
// Zone represents a DNS zone with its records.
type Zone struct {
Domain string `json:"domain"`
SOA SOARecord `json:"soa"`
Records []Record `json:"records"`
DNSSEC bool `json:"dnssec"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// SOARecord holds SOA-specific fields.
type SOARecord struct {
PrimaryNS string `json:"primary_ns"`
AdminEmail string `json:"admin_email"`
Serial uint32 `json:"serial"`
Refresh uint32 `json:"refresh"`
Retry uint32 `json:"retry"`
Expire uint32 `json:"expire"`
MinTTL uint32 `json:"min_ttl"`
}
// ZoneStore manages zones on disk and in memory.
type ZoneStore struct {
mu sync.RWMutex
zones map[string]*Zone
zonesDir string
}
// NewZoneStore creates a store backed by a directory.
func NewZoneStore(dir string) *ZoneStore {
os.MkdirAll(dir, 0755)
return &ZoneStore{
zones: make(map[string]*Zone),
zonesDir: dir,
}
}
// LoadAll reads all zone files from disk.
func (s *ZoneStore) LoadAll() error {
entries, err := os.ReadDir(s.zonesDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
for _, e := range entries {
if filepath.Ext(e.Name()) != ".json" {
continue
}
data, err := os.ReadFile(filepath.Join(s.zonesDir, e.Name()))
if err != nil {
continue
}
var z Zone
if err := json.Unmarshal(data, &z); err != nil {
continue
}
s.zones[dns.Fqdn(z.Domain)] = &z
}
return nil
}
// Save writes a zone to disk.
func (s *ZoneStore) Save(z *Zone) error {
z.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
data, err := json.MarshalIndent(z, "", " ")
if err != nil {
return err
}
fname := filepath.Join(s.zonesDir, z.Domain+".json")
return os.WriteFile(fname, data, 0644)
}
// Get returns a zone by domain.
func (s *ZoneStore) Get(domain string) *Zone {
s.mu.RLock()
defer s.mu.RUnlock()
return s.zones[dns.Fqdn(domain)]
}
// List returns all zones.
func (s *ZoneStore) List() []*Zone {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]*Zone, 0, len(s.zones))
for _, z := range s.zones {
result = append(result, z)
}
return result
}
// Create adds a new zone.
func (s *ZoneStore) Create(domain string) (*Zone, error) {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.zones[fqdn]; exists {
return nil, fmt.Errorf("zone %s already exists", domain)
}
now := time.Now().UTC().Format(time.RFC3339)
z := &Zone{
Domain: domain,
SOA: SOARecord{
PrimaryNS: "ns1." + domain,
AdminEmail: "admin." + domain,
Serial: uint32(time.Now().Unix()),
Refresh: 3600,
Retry: 600,
Expire: 86400,
MinTTL: 300,
},
Records: []Record{
{ID: "ns1", Type: TypeNS, Name: domain + ".", Value: "ns1." + domain + ".", TTL: 3600},
},
CreatedAt: now,
UpdatedAt: now,
}
s.zones[fqdn] = z
return z, s.Save(z)
}
// Delete removes a zone.
func (s *ZoneStore) Delete(domain string) error {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.zones[fqdn]; !exists {
return fmt.Errorf("zone %s not found", domain)
}
delete(s.zones, fqdn)
fname := filepath.Join(s.zonesDir, domain+".json")
os.Remove(fname)
return nil
}
// AddRecord adds a record to a zone.
func (s *ZoneStore) AddRecord(domain string, rec Record) error {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
z, ok := s.zones[fqdn]
if !ok {
return fmt.Errorf("zone %s not found", domain)
}
if rec.ID == "" {
rec.ID = fmt.Sprintf("r%d", time.Now().UnixNano())
}
if rec.TTL == 0 {
rec.TTL = 300
}
z.Records = append(z.Records, rec)
z.SOA.Serial++
return s.Save(z)
}
// DeleteRecord removes a record by ID.
func (s *ZoneStore) DeleteRecord(domain, recordID string) error {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
z, ok := s.zones[fqdn]
if !ok {
return fmt.Errorf("zone %s not found", domain)
}
for i, r := range z.Records {
if r.ID == recordID {
z.Records = append(z.Records[:i], z.Records[i+1:]...)
z.SOA.Serial++
return s.Save(z)
}
}
return fmt.Errorf("record %s not found", recordID)
}
// UpdateRecord updates a record by ID.
func (s *ZoneStore) UpdateRecord(domain, recordID string, rec Record) error {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
z, ok := s.zones[fqdn]
if !ok {
return fmt.Errorf("zone %s not found", domain)
}
for i, r := range z.Records {
if r.ID == recordID {
rec.ID = recordID
z.Records[i] = rec
z.SOA.Serial++
return s.Save(z)
}
}
return fmt.Errorf("record %s not found", recordID)
}
// Lookup finds records matching a query name and type within all zones.
func (s *ZoneStore) Lookup(name string, qtype uint16) []dns.RR {
s.mu.RLock()
defer s.mu.RUnlock()
fqdn := dns.Fqdn(name)
var results []dns.RR
// Find the zone for this name
for zoneDomain, z := range s.zones {
if !dns.IsSubDomain(zoneDomain, fqdn) {
continue
}
// Check records
for _, rec := range z.Records {
recFQDN := dns.Fqdn(rec.Name)
if recFQDN != fqdn {
continue
}
if rr := recordToRR(rec, fqdn); rr != nil {
if qtype == dns.TypeANY || rr.Header().Rrtype == qtype {
results = append(results, rr)
}
}
}
// SOA for zone apex
if fqdn == zoneDomain && (qtype == dns.TypeSOA || qtype == dns.TypeANY) {
soa := &dns.SOA{
Hdr: dns.RR_Header{Name: zoneDomain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: z.SOA.MinTTL},
Ns: dns.Fqdn(z.SOA.PrimaryNS),
Mbox: dns.Fqdn(z.SOA.AdminEmail),
Serial: z.SOA.Serial,
Refresh: z.SOA.Refresh,
Retry: z.SOA.Retry,
Expire: z.SOA.Expire,
Minttl: z.SOA.MinTTL,
}
results = append(results, soa)
}
}
return results
}
func recordToRR(rec Record, fqdn string) dns.RR {
hdr := dns.RR_Header{Name: fqdn, Class: dns.ClassINET, Ttl: rec.TTL}
switch rec.Type {
case TypeA:
hdr.Rrtype = dns.TypeA
rr := &dns.A{Hdr: hdr}
rr.A = parseIP(rec.Value)
if rr.A == nil {
return nil
}
return rr
case TypeAAAA:
hdr.Rrtype = dns.TypeAAAA
rr := &dns.AAAA{Hdr: hdr}
rr.AAAA = parseIP(rec.Value)
if rr.AAAA == nil {
return nil
}
return rr
case TypeCNAME:
hdr.Rrtype = dns.TypeCNAME
return &dns.CNAME{Hdr: hdr, Target: dns.Fqdn(rec.Value)}
case TypeMX:
hdr.Rrtype = dns.TypeMX
return &dns.MX{Hdr: hdr, Preference: rec.Priority, Mx: dns.Fqdn(rec.Value)}
case TypeTXT:
hdr.Rrtype = dns.TypeTXT
return &dns.TXT{Hdr: hdr, Txt: []string{rec.Value}}
case TypeNS:
hdr.Rrtype = dns.TypeNS
return &dns.NS{Hdr: hdr, Ns: dns.Fqdn(rec.Value)}
case TypeSRV:
hdr.Rrtype = dns.TypeSRV
return &dns.SRV{Hdr: hdr, Priority: rec.Priority, Weight: rec.Weight, Port: rec.Port, Target: dns.Fqdn(rec.Value)}
case TypePTR:
hdr.Rrtype = dns.TypePTR
return &dns.PTR{Hdr: hdr, Ptr: dns.Fqdn(rec.Value)}
}
return nil
}
func parseIP(s string) net.IP {
return net.ParseIP(s)
}
// ExportZoneFile exports a zone in BIND zone file format.
func (s *ZoneStore) ExportZoneFile(domain string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
z, ok := s.zones[dns.Fqdn(domain)]
if !ok {
return "", fmt.Errorf("zone %s not found", domain)
}
var b strings.Builder
b.WriteString(fmt.Sprintf("; Zone file for %s\n", z.Domain))
b.WriteString(fmt.Sprintf("; Exported at %s\n", time.Now().UTC().Format(time.RFC3339)))
b.WriteString(fmt.Sprintf("$ORIGIN %s.\n", z.Domain))
b.WriteString(fmt.Sprintf("$TTL %d\n\n", z.SOA.MinTTL))
// SOA
b.WriteString(fmt.Sprintf("@ IN SOA %s. %s. (\n", z.SOA.PrimaryNS, z.SOA.AdminEmail))
b.WriteString(fmt.Sprintf(" %d ; serial\n", z.SOA.Serial))
b.WriteString(fmt.Sprintf(" %d ; refresh\n", z.SOA.Refresh))
b.WriteString(fmt.Sprintf(" %d ; retry\n", z.SOA.Retry))
b.WriteString(fmt.Sprintf(" %d ; expire\n", z.SOA.Expire))
b.WriteString(fmt.Sprintf(" %d ; minimum TTL\n)\n\n", z.SOA.MinTTL))
// Records grouped by type
for _, rec := range z.Records {
name := rec.Name
// Make relative to origin
suffix := "." + z.Domain + "."
if strings.HasSuffix(name, suffix) {
name = strings.TrimSuffix(name, suffix)
} else if name == z.Domain+"." {
name = "@"
}
switch rec.Type {
case TypeMX:
b.WriteString(fmt.Sprintf("%-24s %d IN MX %d %s\n", name, rec.TTL, rec.Priority, rec.Value))
case TypeSRV:
b.WriteString(fmt.Sprintf("%-24s %d IN SRV %d %d %d %s\n", name, rec.TTL, rec.Priority, rec.Weight, rec.Port, rec.Value))
default:
b.WriteString(fmt.Sprintf("%-24s %d IN %-6s %s\n", name, rec.TTL, rec.Type, rec.Value))
}
}
return b.String(), nil
}
// ImportZoneFile parses a BIND-style zone file and adds records.
// Returns number of records added.
func (s *ZoneStore) ImportZoneFile(domain, content string) (int, error) {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
z, ok := s.zones[fqdn]
if !ok {
return 0, fmt.Errorf("zone %s not found — create it first", domain)
}
added := 0
zp := dns.NewZoneParser(strings.NewReader(content), dns.Fqdn(domain), "")
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
hdr := rr.Header()
rec := Record{
ID: fmt.Sprintf("imp%d", time.Now().UnixNano()+int64(added)),
Name: hdr.Name,
TTL: hdr.Ttl,
}
switch v := rr.(type) {
case *dns.A:
rec.Type = TypeA
rec.Value = v.A.String()
case *dns.AAAA:
rec.Type = TypeAAAA
rec.Value = v.AAAA.String()
case *dns.CNAME:
rec.Type = TypeCNAME
rec.Value = v.Target
case *dns.MX:
rec.Type = TypeMX
rec.Value = v.Mx
rec.Priority = v.Preference
case *dns.TXT:
rec.Type = TypeTXT
rec.Value = strings.Join(v.Txt, " ")
case *dns.NS:
rec.Type = TypeNS
rec.Value = v.Ns
case *dns.SRV:
rec.Type = TypeSRV
rec.Value = v.Target
rec.Priority = v.Priority
rec.Weight = v.Weight
rec.Port = v.Port
case *dns.PTR:
rec.Type = TypePTR
rec.Value = v.Ptr
default:
continue // Skip unsupported types
}
z.Records = append(z.Records, rec)
added++
}
if added > 0 {
z.SOA.Serial++
s.Save(z)
}
return added, nil
}
// CloneZone duplicates a zone under a new domain.
func (s *ZoneStore) CloneZone(srcDomain, dstDomain string) (*Zone, error) {
srcFQDN := dns.Fqdn(srcDomain)
dstFQDN := dns.Fqdn(dstDomain)
s.mu.Lock()
defer s.mu.Unlock()
src, ok := s.zones[srcFQDN]
if !ok {
return nil, fmt.Errorf("source zone %s not found", srcDomain)
}
if _, exists := s.zones[dstFQDN]; exists {
return nil, fmt.Errorf("destination zone %s already exists", dstDomain)
}
now := time.Now().UTC().Format(time.RFC3339)
z := &Zone{
Domain: dstDomain,
SOA: SOARecord{
PrimaryNS: strings.Replace(src.SOA.PrimaryNS, srcDomain, dstDomain, -1),
AdminEmail: strings.Replace(src.SOA.AdminEmail, srcDomain, dstDomain, -1),
Serial: uint32(time.Now().Unix()),
Refresh: src.SOA.Refresh,
Retry: src.SOA.Retry,
Expire: src.SOA.Expire,
MinTTL: src.SOA.MinTTL,
},
CreatedAt: now,
UpdatedAt: now,
}
// Clone records, replacing domain references
for _, rec := range src.Records {
newRec := rec
newRec.ID = fmt.Sprintf("c%d", time.Now().UnixNano())
newRec.Name = strings.Replace(rec.Name, srcDomain, dstDomain, -1)
newRec.Value = strings.Replace(rec.Value, srcDomain, dstDomain, -1)
z.Records = append(z.Records, newRec)
time.Sleep(time.Nanosecond) // Ensure unique IDs
}
s.zones[dstFQDN] = z
return z, s.Save(z)
}
// BulkAddRecords adds multiple records at once.
func (s *ZoneStore) BulkAddRecords(domain string, records []Record) (int, error) {
fqdn := dns.Fqdn(domain)
s.mu.Lock()
defer s.mu.Unlock()
z, ok := s.zones[fqdn]
if !ok {
return 0, fmt.Errorf("zone %s not found", domain)
}
added := 0
for _, rec := range records {
if rec.ID == "" {
rec.ID = fmt.Sprintf("b%d", time.Now().UnixNano()+int64(added))
}
if rec.TTL == 0 {
rec.TTL = 300
}
z.Records = append(z.Records, rec)
added++
}
if added > 0 {
z.SOA.Serial++
s.Save(z)
}
return added, nil
}