Autarch Will Control The Internet
This commit is contained in:
1081
services/dns-server/api/router.go
Normal file
1081
services/dns-server/api/router.go
Normal file
File diff suppressed because it is too large
Load Diff
BIN
services/dns-server/autarch-dns.exe
Normal file
BIN
services/dns-server/autarch-dns.exe
Normal file
Binary file not shown.
26
services/dns-server/build.sh
Normal file
26
services/dns-server/build.sh
Normal file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
# Cross-compile autarch-dns for all supported platforms
|
||||
set -e
|
||||
|
||||
VERSION="1.0.0"
|
||||
OUTPUT_BASE="../../tools"
|
||||
|
||||
echo "Building autarch-dns v${VERSION}..."
|
||||
|
||||
# Linux ARM64 (Orange Pi 5 Plus)
|
||||
echo " → linux/arm64"
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="-s -w -X main.version=${VERSION}" \
|
||||
-o "${OUTPUT_BASE}/linux-arm64/autarch-dns" .
|
||||
|
||||
# Linux AMD64
|
||||
echo " → linux/amd64"
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="-s -w -X main.version=${VERSION}" \
|
||||
-o "${OUTPUT_BASE}/linux-x86_64/autarch-dns" .
|
||||
|
||||
# Windows AMD64
|
||||
echo " → windows/amd64"
|
||||
GOOS=windows GOARCH=amd64 go build -ldflags="-s -w -X main.version=${VERSION}" \
|
||||
-o "${OUTPUT_BASE}/windows-x86_64/autarch-dns.exe" .
|
||||
|
||||
echo "Done! Binaries:"
|
||||
ls -lh "${OUTPUT_BASE}"/*/autarch-dns* 2>/dev/null || true
|
||||
84
services/dns-server/config/config.go
Normal file
84
services/dns-server/config/config.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// Config holds all DNS server configuration.
|
||||
type Config struct {
|
||||
ListenDNS string `json:"listen_dns"`
|
||||
ListenAPI string `json:"listen_api"`
|
||||
APIToken string `json:"api_token"`
|
||||
Upstream []string `json:"upstream"`
|
||||
CacheTTL int `json:"cache_ttl"`
|
||||
ZonesDir string `json:"zones_dir"`
|
||||
DNSSECKeyDir string `json:"dnssec_keys_dir"`
|
||||
LogQueries bool `json:"log_queries"`
|
||||
|
||||
// Hosts file support
|
||||
HostsFile string `json:"hosts_file"` // Path to hosts file (e.g., /etc/hosts)
|
||||
HostsAutoLoad bool `json:"hosts_auto_load"` // Auto-load system hosts file on start
|
||||
|
||||
// Encryption
|
||||
EnableDoH bool `json:"enable_doh"` // DNS-over-HTTPS to upstream
|
||||
EnableDoT bool `json:"enable_dot"` // DNS-over-TLS to upstream
|
||||
|
||||
// Security hardening
|
||||
RateLimit int `json:"rate_limit"` // Max queries/sec per source IP (0=unlimited)
|
||||
BlockList []string `json:"block_list"` // Blocked domain patterns
|
||||
AllowTransfer []string `json:"allow_transfer"` // IPs allowed zone transfers (empty=none)
|
||||
MinimalResponses bool `json:"minimal_responses"` // Minimize response data
|
||||
RefuseANY bool `json:"refuse_any"` // Refuse ANY queries (amplification protection)
|
||||
MaxUDPSize int `json:"max_udp_size"` // Max UDP response size
|
||||
|
||||
// Advanced
|
||||
QueryLogMax int `json:"querylog_max"` // Max query log entries (default 1000)
|
||||
NegativeCacheTTL int `json:"negative_cache_ttl"` // TTL for NXDOMAIN cache (default 60)
|
||||
PrefetchEnabled bool `json:"prefetch_enabled"` // Prefetch expiring cache entries
|
||||
ServFailCacheTTL int `json:"servfail_cache_ttl"` // TTL for SERVFAIL cache (default 30)
|
||||
}
|
||||
|
||||
// DefaultConfig returns security-hardened defaults.
|
||||
// No upstream forwarders — full recursive resolution from root hints.
|
||||
// Upstream can be configured as optional fallback if recursive fails.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
ListenDNS: "0.0.0.0:53",
|
||||
ListenAPI: "127.0.0.1:5380",
|
||||
APIToken: generateToken(),
|
||||
Upstream: []string{}, // Empty = pure recursive from root hints
|
||||
CacheTTL: 300,
|
||||
ZonesDir: "data/dns/zones",
|
||||
DNSSECKeyDir: "data/dns/keys",
|
||||
LogQueries: true,
|
||||
|
||||
// Hosts
|
||||
HostsFile: "",
|
||||
HostsAutoLoad: false,
|
||||
|
||||
// Encryption defaults
|
||||
EnableDoH: true,
|
||||
EnableDoT: true,
|
||||
|
||||
// Security defaults
|
||||
RateLimit: 100, // 100 qps per source IP
|
||||
BlockList: []string{},
|
||||
AllowTransfer: []string{}, // No zone transfers
|
||||
MinimalResponses: true,
|
||||
RefuseANY: true, // Block DNS amplification attacks
|
||||
MaxUDPSize: 1232, // Safe MTU, prevent fragmentation
|
||||
|
||||
// Advanced defaults
|
||||
QueryLogMax: 1000,
|
||||
NegativeCacheTTL: 60,
|
||||
PrefetchEnabled: false,
|
||||
ServFailCacheTTL: 30,
|
||||
}
|
||||
}
|
||||
|
||||
func generateToken() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
13
services/dns-server/go.mod
Normal file
13
services/dns-server/go.mod
Normal file
@@ -0,0 +1,13 @@
|
||||
module github.com/darkhal/autarch-dns
|
||||
|
||||
go 1.22
|
||||
|
||||
require github.com/miekg/dns v1.1.62
|
||||
|
||||
require (
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/net v0.27.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.22.0 // indirect
|
||||
golang.org/x/tools v0.22.0 // indirect
|
||||
)
|
||||
12
services/dns-server/go.sum
Normal file
12
services/dns-server/go.sum
Normal file
@@ -0,0 +1,12 @@
|
||||
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
|
||||
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
|
||||
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
||||
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
||||
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
||||
84
services/dns-server/main.go
Normal file
84
services/dns-server/main.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/darkhal/autarch-dns/api"
|
||||
"github.com/darkhal/autarch-dns/config"
|
||||
"github.com/darkhal/autarch-dns/server"
|
||||
)
|
||||
|
||||
var version = "2.1.0"
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "config.json", "Path to config file")
|
||||
listenDNS := flag.String("dns", "", "DNS listen address (overrides config)")
|
||||
listenAPI := flag.String("api", "", "API listen address (overrides config)")
|
||||
apiToken := flag.String("token", "", "API auth token (overrides config)")
|
||||
showVersion := flag.Bool("version", false, "Show version")
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
fmt.Printf("autarch-dns v%s\n", version)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Load config
|
||||
cfg := config.DefaultConfig()
|
||||
if data, err := os.ReadFile(*configPath); err == nil {
|
||||
if err := json.Unmarshal(data, cfg); err != nil {
|
||||
log.Printf("Warning: invalid config file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CLI overrides
|
||||
if *listenDNS != "" {
|
||||
cfg.ListenDNS = *listenDNS
|
||||
}
|
||||
if *listenAPI != "" {
|
||||
cfg.ListenAPI = *listenAPI
|
||||
}
|
||||
if *apiToken != "" {
|
||||
cfg.APIToken = *apiToken
|
||||
}
|
||||
|
||||
// Initialize zone store
|
||||
store := server.NewZoneStore(cfg.ZonesDir)
|
||||
if err := store.LoadAll(); err != nil {
|
||||
log.Printf("Warning: loading zones: %v", err)
|
||||
}
|
||||
|
||||
// Start DNS server
|
||||
dnsServer := server.NewDNSServer(cfg, store)
|
||||
go func() {
|
||||
log.Printf("DNS server listening on %s (UDP+TCP)", cfg.ListenDNS)
|
||||
if err := dnsServer.Start(); err != nil {
|
||||
log.Fatalf("DNS server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start API server
|
||||
apiServer := api.NewAPIServer(cfg, store, dnsServer)
|
||||
go func() {
|
||||
log.Printf("API server listening on %s", cfg.ListenAPI)
|
||||
if err := apiServer.Start(); err != nil {
|
||||
log.Fatalf("API server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("autarch-dns v%s started", version)
|
||||
|
||||
// Wait for shutdown signal
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sig
|
||||
|
||||
log.Println("Shutting down...")
|
||||
dnsServer.Stop()
|
||||
}
|
||||
656
services/dns-server/server/dns.go
Normal file
656
services/dns-server/server/dns.go
Normal file
@@ -0,0 +1,656 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/darkhal/autarch-dns/config"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Metrics holds query statistics.
|
||||
type Metrics struct {
|
||||
TotalQueries uint64 `json:"total_queries"`
|
||||
CacheHits uint64 `json:"cache_hits"`
|
||||
CacheMisses uint64 `json:"cache_misses"`
|
||||
LocalAnswers uint64 `json:"local_answers"`
|
||||
ResolvedQ uint64 `json:"resolved"`
|
||||
BlockedQ uint64 `json:"blocked"`
|
||||
FailedQ uint64 `json:"failed"`
|
||||
StartTime string `json:"start_time"`
|
||||
}
|
||||
|
||||
// QueryLogEntry records a single DNS query.
|
||||
type QueryLogEntry struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Client string `json:"client"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Rcode string `json:"rcode"`
|
||||
Answers int `json:"answers"`
|
||||
Latency string `json:"latency"`
|
||||
Source string `json:"source"` // "local", "cache", "recursive", "blocked", "failed"
|
||||
}
|
||||
|
||||
// CacheEntry holds a cached DNS response.
|
||||
type CacheEntry struct {
|
||||
msg *dns.Msg
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// CacheInfo is an exportable view of a cache entry.
|
||||
type CacheInfo struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
TTL int `json:"ttl_remaining"`
|
||||
Answers int `json:"answers"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
|
||||
// DomainCount tracks query frequency per domain.
|
||||
type DomainCount struct {
|
||||
Domain string `json:"domain"`
|
||||
Count uint64 `json:"count"`
|
||||
}
|
||||
|
||||
// DNSServer is the main DNS server.
|
||||
type DNSServer struct {
|
||||
cfg *config.Config
|
||||
store *ZoneStore
|
||||
hosts *HostsStore
|
||||
resolver *RecursiveResolver
|
||||
metrics Metrics
|
||||
cache map[string]*CacheEntry
|
||||
cacheMu sync.RWMutex
|
||||
udpServ *dns.Server
|
||||
tcpServ *dns.Server
|
||||
|
||||
// Query log — ring buffer
|
||||
queryLog []QueryLogEntry
|
||||
queryLogMu sync.RWMutex
|
||||
queryLogMax int
|
||||
|
||||
// Domain frequency tracking
|
||||
domainCounts map[string]uint64
|
||||
domainCountsMu sync.RWMutex
|
||||
|
||||
// Query type tracking
|
||||
typeCounts map[string]uint64
|
||||
typeCountsMu sync.RWMutex
|
||||
|
||||
// Client tracking
|
||||
clientCounts map[string]uint64
|
||||
clientCountsMu sync.RWMutex
|
||||
|
||||
// Blocklist — fast lookup
|
||||
blocklist map[string]bool
|
||||
blocklistMu sync.RWMutex
|
||||
|
||||
// Conditional forwarding: zone -> upstream servers
|
||||
conditionalFwd map[string][]string
|
||||
conditionalFwdMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDNSServer creates a DNS server.
|
||||
func NewDNSServer(cfg *config.Config, store *ZoneStore) *DNSServer {
|
||||
resolver := NewRecursiveResolver()
|
||||
resolver.EnableDoT = cfg.EnableDoT
|
||||
resolver.EnableDoH = cfg.EnableDoH
|
||||
|
||||
logMax := cfg.QueryLogMax
|
||||
if logMax <= 0 {
|
||||
logMax = 1000
|
||||
}
|
||||
|
||||
s := &DNSServer{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
hosts: NewHostsStore(),
|
||||
resolver: resolver,
|
||||
cache: make(map[string]*CacheEntry),
|
||||
queryLog: make([]QueryLogEntry, 0, logMax),
|
||||
queryLogMax: logMax,
|
||||
domainCounts: make(map[string]uint64),
|
||||
typeCounts: make(map[string]uint64),
|
||||
clientCounts: make(map[string]uint64),
|
||||
blocklist: make(map[string]bool),
|
||||
conditionalFwd: make(map[string][]string),
|
||||
metrics: Metrics{
|
||||
StartTime: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
// Load blocklist from config
|
||||
for _, pattern := range cfg.BlockList {
|
||||
s.blocklist[dns.Fqdn(strings.ToLower(pattern))] = true
|
||||
}
|
||||
|
||||
// Load hosts file if configured
|
||||
if cfg.HostsFile != "" {
|
||||
if err := s.hosts.LoadFile(cfg.HostsFile); err != nil {
|
||||
log.Printf("[hosts] Warning: could not load hosts file %s: %v", cfg.HostsFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// GetHosts returns the hosts store.
|
||||
func (s *DNSServer) GetHosts() *HostsStore {
|
||||
return s.hosts
|
||||
}
|
||||
|
||||
// GetEncryptionStatus returns encryption info from the resolver.
|
||||
func (s *DNSServer) GetEncryptionStatus() map[string]interface{} {
|
||||
return s.resolver.GetEncryptionStatus()
|
||||
}
|
||||
|
||||
// SetEncryption updates DoT/DoH settings on the resolver.
|
||||
func (s *DNSServer) SetEncryption(dot, doh bool) {
|
||||
s.resolver.EnableDoT = dot
|
||||
s.resolver.EnableDoH = doh
|
||||
s.cfg.EnableDoT = dot
|
||||
s.cfg.EnableDoH = doh
|
||||
}
|
||||
|
||||
// GetResolver returns the underlying recursive resolver.
|
||||
func (s *DNSServer) GetResolver() *RecursiveResolver {
|
||||
return s.resolver
|
||||
}
|
||||
|
||||
// Start begins listening on UDP and TCP.
|
||||
func (s *DNSServer) Start() error {
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", s.handleQuery)
|
||||
|
||||
s.udpServ = &dns.Server{Addr: s.cfg.ListenDNS, Net: "udp", Handler: mux}
|
||||
s.tcpServ = &dns.Server{Addr: s.cfg.ListenDNS, Net: "tcp", Handler: mux}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- s.udpServ.ListenAndServe() }()
|
||||
go func() { errCh <- s.tcpServ.ListenAndServe() }()
|
||||
|
||||
go s.cacheCleanup()
|
||||
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// Stop shuts down both servers.
|
||||
func (s *DNSServer) Stop() {
|
||||
if s.udpServ != nil {
|
||||
s.udpServ.Shutdown()
|
||||
}
|
||||
if s.tcpServ != nil {
|
||||
s.tcpServ.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns current metrics.
|
||||
func (s *DNSServer) GetMetrics() Metrics {
|
||||
return Metrics{
|
||||
TotalQueries: atomic.LoadUint64(&s.metrics.TotalQueries),
|
||||
CacheHits: atomic.LoadUint64(&s.metrics.CacheHits),
|
||||
CacheMisses: atomic.LoadUint64(&s.metrics.CacheMisses),
|
||||
LocalAnswers: atomic.LoadUint64(&s.metrics.LocalAnswers),
|
||||
ResolvedQ: atomic.LoadUint64(&s.metrics.ResolvedQ),
|
||||
BlockedQ: atomic.LoadUint64(&s.metrics.BlockedQ),
|
||||
FailedQ: atomic.LoadUint64(&s.metrics.FailedQ),
|
||||
StartTime: s.metrics.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryLog returns the last N query log entries.
|
||||
func (s *DNSServer) GetQueryLog(limit int) []QueryLogEntry {
|
||||
s.queryLogMu.RLock()
|
||||
defer s.queryLogMu.RUnlock()
|
||||
|
||||
n := len(s.queryLog)
|
||||
if limit <= 0 || limit > n {
|
||||
limit = n
|
||||
}
|
||||
// Return most recent first
|
||||
result := make([]QueryLogEntry, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
result[i] = s.queryLog[n-1-i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ClearQueryLog empties the log.
|
||||
func (s *DNSServer) ClearQueryLog() {
|
||||
s.queryLogMu.Lock()
|
||||
s.queryLog = s.queryLog[:0]
|
||||
s.queryLogMu.Unlock()
|
||||
}
|
||||
|
||||
// GetCacheEntries returns all cache entries.
|
||||
func (s *DNSServer) GetCacheEntries() []CacheInfo {
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
entries := make([]CacheInfo, 0, len(s.cache))
|
||||
for key, entry := range s.cache {
|
||||
if now.After(entry.expiresAt) {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(key, "/", 2)
|
||||
name, qtype := key, ""
|
||||
if len(parts) == 2 {
|
||||
name, qtype = parts[0], parts[1]
|
||||
}
|
||||
entries = append(entries, CacheInfo{
|
||||
Key: key,
|
||||
Name: name,
|
||||
Type: qtype,
|
||||
TTL: int(entry.expiresAt.Sub(now).Seconds()),
|
||||
Answers: len(entry.msg.Answer),
|
||||
ExpiresAt: entry.expiresAt.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// CacheSize returns number of active cache entries.
|
||||
func (s *DNSServer) CacheSize() int {
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
return len(s.cache)
|
||||
}
|
||||
|
||||
// FlushCache clears all cached responses.
|
||||
func (s *DNSServer) FlushCache() int {
|
||||
s.cacheMu.Lock()
|
||||
n := len(s.cache)
|
||||
s.cache = make(map[string]*CacheEntry)
|
||||
s.cacheMu.Unlock()
|
||||
// Also flush resolver NS cache
|
||||
s.resolver.FlushNSCache()
|
||||
return n
|
||||
}
|
||||
|
||||
// FlushCacheEntry removes a single cache entry.
|
||||
func (s *DNSServer) FlushCacheEntry(key string) bool {
|
||||
s.cacheMu.Lock()
|
||||
defer s.cacheMu.Unlock()
|
||||
if _, ok := s.cache[key]; ok {
|
||||
delete(s.cache, key)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetTopDomains returns the most-queried domains.
|
||||
func (s *DNSServer) GetTopDomains(limit int) []DomainCount {
|
||||
s.domainCountsMu.RLock()
|
||||
defer s.domainCountsMu.RUnlock()
|
||||
|
||||
counts := make([]DomainCount, 0, len(s.domainCounts))
|
||||
for domain, count := range s.domainCounts {
|
||||
counts = append(counts, DomainCount{Domain: domain, Count: count})
|
||||
}
|
||||
sort.Slice(counts, func(i, j int) bool { return counts[i].Count > counts[j].Count })
|
||||
if limit > 0 && limit < len(counts) {
|
||||
counts = counts[:limit]
|
||||
}
|
||||
return counts
|
||||
}
|
||||
|
||||
// GetQueryTypeCounts returns counts by query type.
|
||||
func (s *DNSServer) GetQueryTypeCounts() map[string]uint64 {
|
||||
s.typeCountsMu.RLock()
|
||||
defer s.typeCountsMu.RUnlock()
|
||||
result := make(map[string]uint64, len(s.typeCounts))
|
||||
for k, v := range s.typeCounts {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetClientCounts returns counts by client IP.
|
||||
func (s *DNSServer) GetClientCounts() map[string]uint64 {
|
||||
s.clientCountsMu.RLock()
|
||||
defer s.clientCountsMu.RUnlock()
|
||||
result := make(map[string]uint64, len(s.clientCounts))
|
||||
for k, v := range s.clientCounts {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AddBlocklistEntry adds a domain to the blocklist.
|
||||
func (s *DNSServer) AddBlocklistEntry(domain string) {
|
||||
s.blocklistMu.Lock()
|
||||
s.blocklist[dns.Fqdn(strings.ToLower(domain))] = true
|
||||
s.blocklistMu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveBlocklistEntry removes a domain from the blocklist.
|
||||
func (s *DNSServer) RemoveBlocklistEntry(domain string) {
|
||||
s.blocklistMu.Lock()
|
||||
delete(s.blocklist, dns.Fqdn(strings.ToLower(domain)))
|
||||
s.blocklistMu.Unlock()
|
||||
}
|
||||
|
||||
// GetBlocklist returns all blocked domains.
|
||||
func (s *DNSServer) GetBlocklist() []string {
|
||||
s.blocklistMu.RLock()
|
||||
defer s.blocklistMu.RUnlock()
|
||||
list := make([]string, 0, len(s.blocklist))
|
||||
for domain := range s.blocklist {
|
||||
list = append(list, domain)
|
||||
}
|
||||
sort.Strings(list)
|
||||
return list
|
||||
}
|
||||
|
||||
// ImportBlocklist adds multiple domains at once.
|
||||
func (s *DNSServer) ImportBlocklist(domains []string) int {
|
||||
s.blocklistMu.Lock()
|
||||
defer s.blocklistMu.Unlock()
|
||||
count := 0
|
||||
for _, d := range domains {
|
||||
d = strings.TrimSpace(strings.ToLower(d))
|
||||
if d == "" || strings.HasPrefix(d, "#") {
|
||||
continue
|
||||
}
|
||||
s.blocklist[dns.Fqdn(d)] = true
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// SetConditionalForward sets upstream servers for a specific zone.
|
||||
func (s *DNSServer) SetConditionalForward(zone string, upstreams []string) {
|
||||
s.conditionalFwdMu.Lock()
|
||||
s.conditionalFwd[dns.Fqdn(strings.ToLower(zone))] = upstreams
|
||||
s.conditionalFwdMu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveConditionalForward removes conditional forwarding for a zone.
|
||||
func (s *DNSServer) RemoveConditionalForward(zone string) {
|
||||
s.conditionalFwdMu.Lock()
|
||||
delete(s.conditionalFwd, dns.Fqdn(strings.ToLower(zone)))
|
||||
s.conditionalFwdMu.Unlock()
|
||||
}
|
||||
|
||||
// GetConditionalForwards returns all conditional forwarding rules.
|
||||
func (s *DNSServer) GetConditionalForwards() map[string][]string {
|
||||
s.conditionalFwdMu.RLock()
|
||||
defer s.conditionalFwdMu.RUnlock()
|
||||
result := make(map[string][]string, len(s.conditionalFwd))
|
||||
for k, v := range s.conditionalFwd {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetResolverNSCache returns the resolver's NS delegation cache.
|
||||
func (s *DNSServer) GetResolverNSCache() map[string][]string {
|
||||
return s.resolver.GetNSCache()
|
||||
}
|
||||
|
||||
func (s *DNSServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
|
||||
start := time.Now()
|
||||
atomic.AddUint64(&s.metrics.TotalQueries, 1)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetReply(r)
|
||||
msg.Authoritative = false
|
||||
msg.RecursionAvailable = true
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
msg.Rcode = dns.RcodeFormatError
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
qName := q.Name
|
||||
qTypeStr := dns.TypeToString[q.Qtype]
|
||||
clientAddr := w.RemoteAddr().String()
|
||||
|
||||
// Track stats
|
||||
s.trackDomain(qName)
|
||||
s.trackType(qTypeStr)
|
||||
s.trackClient(clientAddr)
|
||||
|
||||
if s.cfg.LogQueries {
|
||||
log.Printf("[query] %s %s from %s", qTypeStr, qName, clientAddr)
|
||||
}
|
||||
|
||||
// Security: Refuse ANY queries (DNS amplification protection)
|
||||
if s.cfg.RefuseANY && q.Qtype == dns.TypeANY {
|
||||
msg.Rcode = dns.RcodeNotImplemented
|
||||
atomic.AddUint64(&s.metrics.FailedQ, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "NOTIMPL", 0, time.Since(start), "blocked")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Block zone transfer requests (AXFR/IXFR)
|
||||
if q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR {
|
||||
msg.Rcode = dns.RcodeRefused
|
||||
atomic.AddUint64(&s.metrics.FailedQ, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Minimal responses — don't expose server info
|
||||
if s.cfg.MinimalResponses {
|
||||
if q.Qtype == dns.TypeTXT && (qName == "version.bind." || qName == "hostname.bind." || qName == "version.server.") {
|
||||
msg.Rcode = dns.RcodeRefused
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Blocklist check
|
||||
if s.isBlocked(qName) {
|
||||
msg.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
atomic.AddUint64(&s.metrics.BlockedQ, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "NXDOMAIN", 0, time.Since(start), "blocked")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
// 1a. Check hosts file
|
||||
hostsAnswers := s.hosts.Lookup(qName, q.Qtype)
|
||||
if len(hostsAnswers) > 0 {
|
||||
msg.Authoritative = true
|
||||
msg.Answer = hostsAnswers
|
||||
atomic.AddUint64(&s.metrics.LocalAnswers, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(hostsAnswers), time.Since(start), "hosts")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
// 1b. Check local zones
|
||||
answers := s.store.Lookup(qName, q.Qtype)
|
||||
if len(answers) > 0 {
|
||||
msg.Authoritative = true
|
||||
msg.Answer = answers
|
||||
atomic.AddUint64(&s.metrics.LocalAnswers, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(answers), time.Since(start), "local")
|
||||
w.WriteMsg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Check cache
|
||||
cacheKey := cacheKeyFor(q)
|
||||
if cached := s.getCached(cacheKey); cached != nil {
|
||||
cached.SetReply(r)
|
||||
atomic.AddUint64(&s.metrics.CacheHits, 1)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[cached.Rcode], len(cached.Answer), time.Since(start), "cache")
|
||||
w.WriteMsg(cached)
|
||||
return
|
||||
}
|
||||
atomic.AddUint64(&s.metrics.CacheMisses, 1)
|
||||
|
||||
// 3. Check conditional forwarding
|
||||
if fwdServers := s.getConditionalForward(qName); fwdServers != nil {
|
||||
c := &dns.Client{Timeout: 5 * time.Second}
|
||||
for _, srv := range fwdServers {
|
||||
resp, _, err := c.Exchange(r, srv)
|
||||
if err == nil && resp != nil {
|
||||
atomic.AddUint64(&s.metrics.ResolvedQ, 1)
|
||||
s.putCache(cacheKey, resp)
|
||||
resp.SetReply(r)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[resp.Rcode], len(resp.Answer), time.Since(start), "conditional")
|
||||
w.WriteMsg(resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Recursive resolution from root hints (with optional upstream fallback)
|
||||
resp := s.resolver.ResolveWithFallback(r, s.cfg.Upstream)
|
||||
if resp != nil {
|
||||
atomic.AddUint64(&s.metrics.ResolvedQ, 1)
|
||||
s.putCache(cacheKey, resp)
|
||||
resp.SetReply(r)
|
||||
s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[resp.Rcode], len(resp.Answer), time.Since(start), "recursive")
|
||||
w.WriteMsg(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. SERVFAIL
|
||||
atomic.AddUint64(&s.metrics.FailedQ, 1)
|
||||
msg.Rcode = dns.RcodeServerFailure
|
||||
s.logQuery(clientAddr, qName, qTypeStr, "SERVFAIL", 0, time.Since(start), "failed")
|
||||
w.WriteMsg(msg)
|
||||
}
|
||||
|
||||
// ── Blocklist ────────────────────────────────────────────────────────
|
||||
|
||||
func (s *DNSServer) isBlocked(name string) bool {
|
||||
s.blocklistMu.RLock()
|
||||
defer s.blocklistMu.RUnlock()
|
||||
|
||||
fqdn := dns.Fqdn(strings.ToLower(name))
|
||||
// Exact match
|
||||
if s.blocklist[fqdn] {
|
||||
return true
|
||||
}
|
||||
// Wildcard: check parent domains
|
||||
labels := dns.SplitDomainName(fqdn)
|
||||
for i := 1; i < len(labels); i++ {
|
||||
parent := dns.Fqdn(strings.Join(labels[i:], "."))
|
||||
if s.blocklist[parent] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ── Conditional forwarding ───────────────────────────────────────────
|
||||
|
||||
func (s *DNSServer) getConditionalForward(name string) []string {
|
||||
s.conditionalFwdMu.RLock()
|
||||
defer s.conditionalFwdMu.RUnlock()
|
||||
|
||||
fqdn := dns.Fqdn(strings.ToLower(name))
|
||||
labels := dns.SplitDomainName(fqdn)
|
||||
for i := 0; i < len(labels); i++ {
|
||||
zone := dns.Fqdn(strings.Join(labels[i:], "."))
|
||||
if servers, ok := s.conditionalFwd[zone]; ok {
|
||||
return servers
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── Tracking ─────────────────────────────────────────────────────────
|
||||
|
||||
func (s *DNSServer) trackDomain(name string) {
|
||||
s.domainCountsMu.Lock()
|
||||
s.domainCounts[name]++
|
||||
s.domainCountsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *DNSServer) trackType(qtype string) {
|
||||
s.typeCountsMu.Lock()
|
||||
s.typeCounts[qtype]++
|
||||
s.typeCountsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *DNSServer) trackClient(addr string) {
|
||||
// Strip port
|
||||
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
||||
addr = addr[:idx]
|
||||
}
|
||||
s.clientCountsMu.Lock()
|
||||
s.clientCounts[addr]++
|
||||
s.clientCountsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *DNSServer) logQuery(client, name, qtype, rcode string, answers int, latency time.Duration, source string) {
|
||||
entry := QueryLogEntry{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Client: client,
|
||||
Name: name,
|
||||
Type: qtype,
|
||||
Rcode: rcode,
|
||||
Answers: answers,
|
||||
Latency: latency.String(),
|
||||
Source: source,
|
||||
}
|
||||
|
||||
s.queryLogMu.Lock()
|
||||
if len(s.queryLog) >= s.queryLogMax {
|
||||
// Shift: remove oldest 10%
|
||||
trim := s.queryLogMax / 10
|
||||
copy(s.queryLog, s.queryLog[trim:])
|
||||
s.queryLog = s.queryLog[:len(s.queryLog)-trim]
|
||||
}
|
||||
s.queryLog = append(s.queryLog, entry)
|
||||
s.queryLogMu.Unlock()
|
||||
}
|
||||
|
||||
// ── Cache ────────────────────────────────────────────────────────────
|
||||
|
||||
func cacheKeyFor(q dns.Question) string {
|
||||
return q.Name + "/" + dns.TypeToString[q.Qtype]
|
||||
}
|
||||
|
||||
func (s *DNSServer) getCached(key string) *dns.Msg {
|
||||
s.cacheMu.RLock()
|
||||
defer s.cacheMu.RUnlock()
|
||||
entry, ok := s.cache[key]
|
||||
if !ok || time.Now().After(entry.expiresAt) {
|
||||
return nil
|
||||
}
|
||||
return entry.msg.Copy()
|
||||
}
|
||||
|
||||
func (s *DNSServer) putCache(key string, msg *dns.Msg) {
|
||||
ttl := time.Duration(s.cfg.CacheTTL) * time.Second
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
s.cacheMu.Lock()
|
||||
s.cache[key] = &CacheEntry{msg: msg.Copy(), expiresAt: time.Now().Add(ttl)}
|
||||
s.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *DNSServer) cacheCleanup() {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.cacheMu.Lock()
|
||||
now := time.Now()
|
||||
for k, v := range s.cache {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.cache, k)
|
||||
}
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
349
services/dns-server/server/hosts.go
Normal file
349
services/dns-server/server/hosts.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// HostEntry represents a single hosts file entry.
|
||||
type HostEntry struct {
|
||||
IP string `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// HostsStore manages a hosts-file-like database.
|
||||
type HostsStore struct {
|
||||
mu sync.RWMutex
|
||||
entries []HostEntry
|
||||
path string // path to hosts file on disk (if loaded from file)
|
||||
}
|
||||
|
||||
// NewHostsStore creates a new hosts store.
|
||||
func NewHostsStore() *HostsStore {
|
||||
return &HostsStore{
|
||||
entries: make([]HostEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFile parses a hosts file from disk.
|
||||
func (h *HostsStore) LoadFile(path string) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.path = path
|
||||
h.entries = h.entries[:0]
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip inline comments
|
||||
comment := ""
|
||||
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||
comment = strings.TrimSpace(line[idx+1:])
|
||||
line = strings.TrimSpace(line[:idx])
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := fields[0]
|
||||
if net.ParseIP(ip) == nil {
|
||||
continue // invalid IP
|
||||
}
|
||||
|
||||
entry := HostEntry{
|
||||
IP: ip,
|
||||
Hostname: strings.ToLower(fields[1]),
|
||||
Comment: comment,
|
||||
}
|
||||
if len(fields) > 2 {
|
||||
aliases := make([]string, len(fields)-2)
|
||||
for i, a := range fields[2:] {
|
||||
aliases[i] = strings.ToLower(a)
|
||||
}
|
||||
entry.Aliases = aliases
|
||||
}
|
||||
h.entries = append(h.entries, entry)
|
||||
}
|
||||
|
||||
log.Printf("[hosts] Loaded %d entries from %s", len(h.entries), path)
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// LoadFromText parses hosts-format text (like pasting /etc/hosts content).
|
||||
func (h *HostsStore) LoadFromText(content string) int {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(strings.NewReader(content))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
comment := ""
|
||||
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||
comment = strings.TrimSpace(line[idx+1:])
|
||||
line = strings.TrimSpace(line[:idx])
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := fields[0]
|
||||
if net.ParseIP(ip) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
entry := HostEntry{
|
||||
IP: ip,
|
||||
Hostname: strings.ToLower(fields[1]),
|
||||
Comment: comment,
|
||||
}
|
||||
if len(fields) > 2 {
|
||||
aliases := make([]string, len(fields)-2)
|
||||
for i, a := range fields[2:] {
|
||||
aliases[i] = strings.ToLower(a)
|
||||
}
|
||||
entry.Aliases = aliases
|
||||
}
|
||||
|
||||
// Dedup by hostname
|
||||
found := false
|
||||
for i, e := range h.entries {
|
||||
if e.Hostname == entry.Hostname {
|
||||
h.entries[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
h.entries = append(h.entries, entry)
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// Add adds a single host entry.
|
||||
func (h *HostsStore) Add(ip, hostname string, aliases []string, comment string) error {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
hostname = strings.ToLower(strings.TrimSpace(hostname))
|
||||
if hostname == "" {
|
||||
return fmt.Errorf("hostname required")
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Check for duplicate
|
||||
for i, e := range h.entries {
|
||||
if e.Hostname == hostname {
|
||||
h.entries[i].IP = ip
|
||||
h.entries[i].Aliases = aliases
|
||||
h.entries[i].Comment = comment
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
h.entries = append(h.entries, HostEntry{
|
||||
IP: ip,
|
||||
Hostname: hostname,
|
||||
Aliases: aliases,
|
||||
Comment: comment,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove removes a host entry by hostname.
|
||||
func (h *HostsStore) Remove(hostname string) bool {
|
||||
hostname = strings.ToLower(strings.TrimSpace(hostname))
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for i, e := range h.entries {
|
||||
if e.Hostname == hostname {
|
||||
h.entries = append(h.entries[:i], h.entries[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Clear removes all entries.
|
||||
func (h *HostsStore) Clear() int {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
n := len(h.entries)
|
||||
h.entries = h.entries[:0]
|
||||
return n
|
||||
}
|
||||
|
||||
// List returns all entries.
|
||||
func (h *HostsStore) List() []HostEntry {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
result := make([]HostEntry, len(h.entries))
|
||||
copy(result, h.entries)
|
||||
return result
|
||||
}
|
||||
|
||||
// Count returns the number of entries.
|
||||
func (h *HostsStore) Count() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.entries)
|
||||
}
|
||||
|
||||
// Lookup resolves a hostname from the hosts store.
|
||||
// Returns DNS RRs matching the query name and type.
|
||||
func (h *HostsStore) Lookup(name string, qtype uint16) []dns.RR {
|
||||
if qtype != dns.TypeA && qtype != dns.TypeAAAA && qtype != dns.TypePTR {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
fqdn := dns.Fqdn(strings.ToLower(name))
|
||||
baseName := strings.TrimSuffix(fqdn, ".")
|
||||
|
||||
// PTR lookup (reverse DNS)
|
||||
if qtype == dns.TypePTR {
|
||||
// Convert in-addr.arpa name to IP
|
||||
ip := ptrToIP(fqdn)
|
||||
if ip == "" {
|
||||
return nil
|
||||
}
|
||||
for _, e := range h.entries {
|
||||
if e.IP == ip {
|
||||
rr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Ptr: dns.Fqdn(e.Hostname),
|
||||
}
|
||||
return []dns.RR{rr}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward lookup (A / AAAA)
|
||||
var results []dns.RR
|
||||
for _, e := range h.entries {
|
||||
// Match hostname or aliases
|
||||
match := strings.EqualFold(e.Hostname, baseName) || strings.EqualFold(dns.Fqdn(e.Hostname), fqdn)
|
||||
if !match {
|
||||
for _, a := range e.Aliases {
|
||||
if strings.EqualFold(a, baseName) || strings.EqualFold(dns.Fqdn(a), fqdn) {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(e.IP)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if qtype == dns.TypeA && ip.To4() != nil {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
results = append(results, rr)
|
||||
} else if qtype == dns.TypeAAAA && ip.To4() == nil {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
AAAA: ip,
|
||||
}
|
||||
results = append(results, rr)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// Export returns hosts file format text.
|
||||
func (h *HostsStore) Export() string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# AUTARCH DNS hosts file\n")
|
||||
sb.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339)))
|
||||
sb.WriteString("# Entries: " + fmt.Sprintf("%d", len(h.entries)) + "\n\n")
|
||||
|
||||
for _, e := range h.entries {
|
||||
line := e.IP + "\t" + e.Hostname
|
||||
for _, a := range e.Aliases {
|
||||
line += "\t" + a
|
||||
}
|
||||
if e.Comment != "" {
|
||||
line += "\t# " + e.Comment
|
||||
}
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ptrToIP converts a PTR domain name (in-addr.arpa) to an IP string.
|
||||
func ptrToIP(name string) string {
|
||||
name = strings.TrimSuffix(strings.ToLower(name), ".")
|
||||
if !strings.HasSuffix(name, ".in-addr.arpa") {
|
||||
return ""
|
||||
}
|
||||
name = strings.TrimSuffix(name, ".in-addr.arpa")
|
||||
parts := strings.Split(name, ".")
|
||||
if len(parts) != 4 {
|
||||
return ""
|
||||
}
|
||||
// Reverse the octets
|
||||
return parts[3] + "." + parts[2] + "." + parts[1] + "." + parts[0]
|
||||
}
|
||||
528
services/dns-server/server/resolver.go
Normal file
528
services/dns-server/server/resolver.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Root nameserver IPs (IANA root hints).
|
||||
// These are hardcoded — they almost never change.
|
||||
var rootServers = []string{
|
||||
"198.41.0.4:53", // a.root-servers.net
|
||||
"170.247.170.2:53", // b.root-servers.net
|
||||
"192.33.4.12:53", // c.root-servers.net
|
||||
"199.7.91.13:53", // d.root-servers.net
|
||||
"192.203.230.10:53", // e.root-servers.net
|
||||
"192.5.5.241:53", // f.root-servers.net
|
||||
"192.112.36.4:53", // g.root-servers.net
|
||||
"198.97.190.53:53", // h.root-servers.net
|
||||
"192.36.148.17:53", // i.root-servers.net
|
||||
"192.58.128.30:53", // j.root-servers.net
|
||||
"193.0.14.129:53", // k.root-servers.net
|
||||
"199.7.83.42:53", // l.root-servers.net
|
||||
"202.12.27.33:53", // m.root-servers.net
|
||||
}
|
||||
|
||||
// Well-known DoH endpoints — when user configures these as upstream,
|
||||
// we auto-detect and use DoH instead of plain DNS.
|
||||
var knownDoHEndpoints = map[string]string{
|
||||
"8.8.8.8": "https://dns.google/dns-query",
|
||||
"8.8.4.4": "https://dns.google/dns-query",
|
||||
"1.1.1.1": "https://cloudflare-dns.com/dns-query",
|
||||
"1.0.0.1": "https://cloudflare-dns.com/dns-query",
|
||||
"9.9.9.9": "https://dns.quad9.net/dns-query",
|
||||
"149.112.112.112": "https://dns.quad9.net/dns-query",
|
||||
"208.67.222.222": "https://dns.opendns.com/dns-query",
|
||||
"208.67.220.220": "https://dns.opendns.com/dns-query",
|
||||
"94.140.14.14": "https://dns.adguard-dns.com/dns-query",
|
||||
"94.140.15.15": "https://dns.adguard-dns.com/dns-query",
|
||||
}
|
||||
|
||||
// Well-known DoT servers — port 853 TLS.
|
||||
var knownDoTServers = map[string]string{
|
||||
"8.8.8.8": "dns.google",
|
||||
"8.8.4.4": "dns.google",
|
||||
"1.1.1.1": "one.one.one.one",
|
||||
"1.0.0.1": "one.one.one.one",
|
||||
"9.9.9.9": "dns.quad9.net",
|
||||
"149.112.112.112": "dns.quad9.net",
|
||||
"208.67.222.222": "dns.opendns.com",
|
||||
"208.67.220.220": "dns.opendns.com",
|
||||
"94.140.14.14": "dns-unfiltered.adguard.com",
|
||||
"94.140.15.15": "dns-unfiltered.adguard.com",
|
||||
}
|
||||
|
||||
// EncryptionMode determines how upstream queries are sent.
|
||||
type EncryptionMode int
|
||||
|
||||
const (
|
||||
ModePlain EncryptionMode = iota // Standard UDP/TCP DNS
|
||||
ModeDoT // DNS-over-TLS (port 853)
|
||||
ModeDoH // DNS-over-HTTPS (RFC 8484)
|
||||
)
|
||||
|
||||
// RecursiveResolver performs iterative DNS resolution from root hints.
|
||||
type RecursiveResolver struct {
|
||||
// NS cache: zone -> list of nameserver IPs
|
||||
nsCache map[string][]string
|
||||
nsCacheMu sync.RWMutex
|
||||
|
||||
client *dns.Client
|
||||
dotClient *dns.Client // TLS client for DoT
|
||||
dohHTTP *http.Client
|
||||
maxDepth int
|
||||
timeout time.Duration
|
||||
|
||||
// Encryption settings
|
||||
EnableDoT bool
|
||||
EnableDoH bool
|
||||
}
|
||||
|
||||
// NewRecursiveResolver creates a resolver with root hints.
|
||||
func NewRecursiveResolver() *RecursiveResolver {
|
||||
return &RecursiveResolver{
|
||||
nsCache: make(map[string][]string),
|
||||
client: &dns.Client{Timeout: 4 * time.Second},
|
||||
dotClient: &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Timeout: 5 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
dohHTTP: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DisableCompression: false,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
},
|
||||
maxDepth: 20,
|
||||
timeout: 4 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve performs full iterative resolution for the given query message.
|
||||
// Returns the final authoritative response, or nil on failure.
|
||||
func (rr *RecursiveResolver) Resolve(req *dns.Msg) *dns.Msg {
|
||||
if len(req.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
q := req.Question[0]
|
||||
return rr.resolve(q.Name, q.Qtype, 0)
|
||||
}
|
||||
|
||||
func (rr *RecursiveResolver) resolve(name string, qtype uint16, depth int) *dns.Msg {
|
||||
if depth >= rr.maxDepth {
|
||||
log.Printf("[resolver] max depth reached for %s", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
name = dns.Fqdn(name)
|
||||
|
||||
// Find the best nameservers to start from.
|
||||
// Walk up the name to find cached NS records, fall back to root.
|
||||
nameservers := rr.findBestNS(name)
|
||||
|
||||
// Iterative resolution: keep querying NS servers until we get an answer
|
||||
for i := 0; i < rr.maxDepth; i++ {
|
||||
resp := rr.queryServers(nameservers, name, qtype)
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Got an authoritative answer or a final answer with records
|
||||
if resp.Authoritative && len(resp.Answer) > 0 {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Check if answer section has what we want (non-authoritative but valid)
|
||||
if len(resp.Answer) > 0 {
|
||||
hasTarget := false
|
||||
var cnameRR *dns.CNAME
|
||||
for _, ans := range resp.Answer {
|
||||
if ans.Header().Rrtype == qtype {
|
||||
hasTarget = true
|
||||
}
|
||||
if cn, ok := ans.(*dns.CNAME); ok && qtype != dns.TypeCNAME {
|
||||
cnameRR = cn
|
||||
}
|
||||
}
|
||||
if hasTarget {
|
||||
return resp
|
||||
}
|
||||
// Follow CNAME chain
|
||||
if cnameRR != nil {
|
||||
cResp := rr.resolve(cnameRR.Target, qtype, depth+1)
|
||||
if cResp != nil {
|
||||
// Prepend the CNAME to the answer
|
||||
cResp.Answer = append([]dns.RR{cnameRR}, cResp.Answer...)
|
||||
return cResp
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// NXDOMAIN — name doesn't exist
|
||||
if resp.Rcode == dns.RcodeNameError {
|
||||
return resp
|
||||
}
|
||||
|
||||
// NOERROR with no answer and no NS in authority = we're done
|
||||
if len(resp.Ns) == 0 && len(resp.Answer) == 0 {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Referral: extract NS records from authority section
|
||||
var newNS []string
|
||||
var nsNames []string
|
||||
for _, rr := range resp.Ns {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
nsNames = append(nsNames, ns.Ns)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nsNames) == 0 {
|
||||
// SOA in authority = negative response from authoritative server
|
||||
for _, rr := range resp.Ns {
|
||||
if _, ok := rr.(*dns.SOA); ok {
|
||||
return resp
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Try to get IPs from the additional section (glue records)
|
||||
glue := make(map[string]string)
|
||||
for _, rr := range resp.Extra {
|
||||
if a, ok := rr.(*dns.A); ok {
|
||||
glue[strings.ToLower(a.Hdr.Name)] = a.A.String() + ":53"
|
||||
}
|
||||
}
|
||||
|
||||
for _, nsName := range nsNames {
|
||||
key := strings.ToLower(dns.Fqdn(nsName))
|
||||
if ip, ok := glue[key]; ok {
|
||||
newNS = append(newNS, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// If no glue, resolve NS names ourselves
|
||||
if len(newNS) == 0 {
|
||||
for _, nsName := range nsNames {
|
||||
ips := rr.resolveNSName(nsName, depth+1)
|
||||
newNS = append(newNS, ips...)
|
||||
if len(newNS) >= 3 {
|
||||
break // Enough NS IPs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(newNS) == 0 {
|
||||
log.Printf("[resolver] no NS IPs found for delegation of %s", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cache the delegation
|
||||
zone := extractZone(resp.Ns)
|
||||
if zone != "" {
|
||||
rr.cacheNS(zone, newNS)
|
||||
}
|
||||
|
||||
nameservers = newNS
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveNSName resolves a nameserver hostname to its IP(s).
|
||||
func (rr *RecursiveResolver) resolveNSName(nsName string, depth int) []string {
|
||||
resp := rr.resolve(nsName, dns.TypeA, depth)
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
for _, ans := range resp.Answer {
|
||||
if a, ok := ans.(*dns.A); ok {
|
||||
ips = append(ips, a.A.String()+":53")
|
||||
}
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// queryServers sends a query to a list of nameservers, returns first valid response.
|
||||
func (rr *RecursiveResolver) queryServers(servers []string, name string, qtype uint16) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(name), qtype)
|
||||
msg.RecursionDesired = false // We're doing iterative resolution
|
||||
|
||||
for _, server := range servers {
|
||||
resp, _, err := rr.client.Exchange(msg, server)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if resp != nil {
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
// Retry with TCP for truncated responses
|
||||
msg.RecursionDesired = false
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: rr.timeout}
|
||||
for _, server := range servers {
|
||||
resp, _, err := tcpClient.Exchange(msg, server)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if resp != nil {
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// queryUpstreamDoT sends a query to an upstream server via DNS-over-TLS (port 853).
|
||||
func (rr *RecursiveResolver) QueryUpstreamDoT(req *dns.Msg, server string) (*dns.Msg, error) {
|
||||
// Extract IP from server address (may include :53)
|
||||
ip := server
|
||||
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
|
||||
// Get TLS server name for certificate validation
|
||||
serverName, ok := knownDoTServers[ip]
|
||||
if !ok {
|
||||
serverName = ip // Use IP as fallback (less secure, but works)
|
||||
}
|
||||
|
||||
dotAddr := ip + ":853"
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Timeout: 5 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
msg := req.Copy()
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, _, err := client.Exchange(msg, dotAddr)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// queryUpstreamDoH sends a query to an upstream server via DNS-over-HTTPS (RFC 8484).
|
||||
func (rr *RecursiveResolver) QueryUpstreamDoH(req *dns.Msg, server string) (*dns.Msg, error) {
|
||||
// Extract IP from server address
|
||||
ip := server
|
||||
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
|
||||
// Find the DoH endpoint URL
|
||||
endpoint, ok := knownDoHEndpoints[ip]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no DoH endpoint known for %s", ip)
|
||||
}
|
||||
|
||||
// Encode DNS message as wire format
|
||||
msg := req.Copy()
|
||||
msg.RecursionDesired = true
|
||||
wireMsg, err := msg.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pack DNS message: %w", err)
|
||||
}
|
||||
|
||||
// POST as application/dns-message (RFC 8484)
|
||||
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewReader(wireMsg))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/dns-message")
|
||||
httpReq.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
httpResp, err := rr.dohHTTP.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DoH request to %s: %w", endpoint, err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DoH response status %d from %s", httpResp.StatusCode, endpoint)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read DoH response: %w", err)
|
||||
}
|
||||
|
||||
resp := new(dns.Msg)
|
||||
if err := resp.Unpack(body); err != nil {
|
||||
return nil, fmt.Errorf("unpack DoH response: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// queryUpstreamEncrypted tries DoH first (if enabled), then DoT, then plain.
|
||||
func (rr *RecursiveResolver) queryUpstreamEncrypted(req *dns.Msg, server string) (*dns.Msg, string, error) {
|
||||
ip := server
|
||||
if idx := strings.LastIndex(ip, ":"); idx >= 0 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
|
||||
// Try DoH if enabled and we know the endpoint
|
||||
if rr.EnableDoH {
|
||||
if _, ok := knownDoHEndpoints[ip]; ok {
|
||||
resp, err := rr.QueryUpstreamDoH(req, server)
|
||||
if err == nil && resp != nil {
|
||||
return resp, "doh", nil
|
||||
}
|
||||
log.Printf("[resolver] DoH failed for %s: %v, falling back", ip, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try DoT if enabled
|
||||
if rr.EnableDoT {
|
||||
resp, err := rr.QueryUpstreamDoT(req, server)
|
||||
if err == nil && resp != nil {
|
||||
return resp, "dot", nil
|
||||
}
|
||||
log.Printf("[resolver] DoT failed for %s: %v, falling back", ip, err)
|
||||
}
|
||||
|
||||
// Plain DNS fallback
|
||||
c := &dns.Client{Timeout: 5 * time.Second}
|
||||
resp, _, err := c.Exchange(req, server)
|
||||
if err != nil {
|
||||
return nil, "plain", err
|
||||
}
|
||||
return resp, "plain", nil
|
||||
}
|
||||
|
||||
// findBestNS finds the closest cached NS for the given name, or returns root servers.
|
||||
func (rr *RecursiveResolver) findBestNS(name string) []string {
|
||||
rr.nsCacheMu.RLock()
|
||||
defer rr.nsCacheMu.RUnlock()
|
||||
|
||||
// Walk up the domain name
|
||||
labels := dns.SplitDomainName(name)
|
||||
for i := 0; i < len(labels); i++ {
|
||||
zone := dns.Fqdn(strings.Join(labels[i:], "."))
|
||||
if ns, ok := rr.nsCache[zone]; ok && len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
}
|
||||
|
||||
return rootServers
|
||||
}
|
||||
|
||||
// cacheNS stores nameserver IPs for a zone.
|
||||
func (rr *RecursiveResolver) cacheNS(zone string, servers []string) {
|
||||
rr.nsCacheMu.Lock()
|
||||
rr.nsCache[dns.Fqdn(zone)] = servers
|
||||
rr.nsCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// extractZone gets the zone name from NS authority records.
|
||||
func extractZone(ns []dns.RR) string {
|
||||
for _, rr := range ns {
|
||||
if nsRR, ok := rr.(*dns.NS); ok {
|
||||
return nsRR.Hdr.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ResolveWithFallback tries recursive resolution, falls back to upstream forwarders.
|
||||
// Now with DoT/DoH encryption support for upstream queries.
|
||||
func (rr *RecursiveResolver) ResolveWithFallback(req *dns.Msg, upstream []string) *dns.Msg {
|
||||
// Try full recursive first
|
||||
resp := rr.Resolve(req)
|
||||
if resp != nil && resp.Rcode != dns.RcodeServerFailure {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Fallback to upstream forwarders if configured — use encrypted transport
|
||||
if len(upstream) > 0 {
|
||||
for _, us := range upstream {
|
||||
resp, mode, err := rr.queryUpstreamEncrypted(req, us)
|
||||
if err == nil && resp != nil {
|
||||
if mode != "plain" {
|
||||
log.Printf("[resolver] upstream %s answered via %s", us, mode)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// GetEncryptionStatus returns the current encryption mode info.
|
||||
func (rr *RecursiveResolver) GetEncryptionStatus() map[string]interface{} {
|
||||
status := map[string]interface{}{
|
||||
"dot_enabled": rr.EnableDoT,
|
||||
"doh_enabled": rr.EnableDoH,
|
||||
"dot_servers": knownDoTServers,
|
||||
"doh_servers": knownDoHEndpoints,
|
||||
}
|
||||
if rr.EnableDoH {
|
||||
status["preferred_mode"] = "doh"
|
||||
} else if rr.EnableDoT {
|
||||
status["preferred_mode"] = "dot"
|
||||
} else {
|
||||
status["preferred_mode"] = "plain"
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// FlushNSCache clears all cached NS delegations.
|
||||
func (rr *RecursiveResolver) FlushNSCache() {
|
||||
rr.nsCacheMu.Lock()
|
||||
rr.nsCache = make(map[string][]string)
|
||||
rr.nsCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// GetNSCache returns a copy of the NS delegation cache.
|
||||
func (rr *RecursiveResolver) GetNSCache() map[string][]string {
|
||||
rr.nsCacheMu.RLock()
|
||||
defer rr.nsCacheMu.RUnlock()
|
||||
result := make(map[string][]string, len(rr.nsCache))
|
||||
for k, v := range rr.nsCache {
|
||||
cp := make([]string, len(v))
|
||||
copy(cp, v)
|
||||
result[k] = cp
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// String returns resolver info for debugging.
|
||||
func (rr *RecursiveResolver) String() string {
|
||||
rr.nsCacheMu.RLock()
|
||||
defer rr.nsCacheMu.RUnlock()
|
||||
mode := "plain"
|
||||
if rr.EnableDoH {
|
||||
mode = "DoH"
|
||||
} else if rr.EnableDoT {
|
||||
mode = "DoT"
|
||||
}
|
||||
return fmt.Sprintf("RecursiveResolver{cached_zones=%d, max_depth=%d, mode=%s}", len(rr.nsCache), rr.maxDepth, mode)
|
||||
}
|
||||
525
services/dns-server/server/zones.go
Normal file
525
services/dns-server/server/zones.go
Normal file
@@ -0,0 +1,525 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// RecordType represents supported DNS record types.
|
||||
type RecordType string
|
||||
|
||||
const (
|
||||
TypeA RecordType = "A"
|
||||
TypeAAAA RecordType = "AAAA"
|
||||
TypeCNAME RecordType = "CNAME"
|
||||
TypeMX RecordType = "MX"
|
||||
TypeTXT RecordType = "TXT"
|
||||
TypeNS RecordType = "NS"
|
||||
TypeSRV RecordType = "SRV"
|
||||
TypePTR RecordType = "PTR"
|
||||
TypeSOA RecordType = "SOA"
|
||||
)
|
||||
|
||||
// Record is a single DNS record.
|
||||
type Record struct {
|
||||
ID string `json:"id"`
|
||||
Type RecordType `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
TTL uint32 `json:"ttl"`
|
||||
Priority uint16 `json:"priority,omitempty"` // MX, SRV
|
||||
Weight uint16 `json:"weight,omitempty"` // SRV
|
||||
Port uint16 `json:"port,omitempty"` // SRV
|
||||
}
|
||||
|
||||
// Zone represents a DNS zone with its records.
|
||||
type Zone struct {
|
||||
Domain string `json:"domain"`
|
||||
SOA SOARecord `json:"soa"`
|
||||
Records []Record `json:"records"`
|
||||
DNSSEC bool `json:"dnssec"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// SOARecord holds SOA-specific fields.
|
||||
type SOARecord struct {
|
||||
PrimaryNS string `json:"primary_ns"`
|
||||
AdminEmail string `json:"admin_email"`
|
||||
Serial uint32 `json:"serial"`
|
||||
Refresh uint32 `json:"refresh"`
|
||||
Retry uint32 `json:"retry"`
|
||||
Expire uint32 `json:"expire"`
|
||||
MinTTL uint32 `json:"min_ttl"`
|
||||
}
|
||||
|
||||
// ZoneStore manages zones on disk and in memory.
|
||||
type ZoneStore struct {
|
||||
mu sync.RWMutex
|
||||
zones map[string]*Zone
|
||||
zonesDir string
|
||||
}
|
||||
|
||||
// NewZoneStore creates a store backed by a directory.
|
||||
func NewZoneStore(dir string) *ZoneStore {
|
||||
os.MkdirAll(dir, 0755)
|
||||
return &ZoneStore{
|
||||
zones: make(map[string]*Zone),
|
||||
zonesDir: dir,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAll reads all zone files from disk.
|
||||
func (s *ZoneStore) LoadAll() error {
|
||||
entries, err := os.ReadDir(s.zonesDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) != ".json" {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(s.zonesDir, e.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var z Zone
|
||||
if err := json.Unmarshal(data, &z); err != nil {
|
||||
continue
|
||||
}
|
||||
s.zones[dns.Fqdn(z.Domain)] = &z
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save writes a zone to disk.
|
||||
func (s *ZoneStore) Save(z *Zone) error {
|
||||
z.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
data, err := json.MarshalIndent(z, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fname := filepath.Join(s.zonesDir, z.Domain+".json")
|
||||
return os.WriteFile(fname, data, 0644)
|
||||
}
|
||||
|
||||
// Get returns a zone by domain.
|
||||
func (s *ZoneStore) Get(domain string) *Zone {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.zones[dns.Fqdn(domain)]
|
||||
}
|
||||
|
||||
// List returns all zones.
|
||||
func (s *ZoneStore) List() []*Zone {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*Zone, 0, len(s.zones))
|
||||
for _, z := range s.zones {
|
||||
result = append(result, z)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Create adds a new zone.
|
||||
func (s *ZoneStore) Create(domain string) (*Zone, error) {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.zones[fqdn]; exists {
|
||||
return nil, fmt.Errorf("zone %s already exists", domain)
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
z := &Zone{
|
||||
Domain: domain,
|
||||
SOA: SOARecord{
|
||||
PrimaryNS: "ns1." + domain,
|
||||
AdminEmail: "admin." + domain,
|
||||
Serial: uint32(time.Now().Unix()),
|
||||
Refresh: 3600,
|
||||
Retry: 600,
|
||||
Expire: 86400,
|
||||
MinTTL: 300,
|
||||
},
|
||||
Records: []Record{
|
||||
{ID: "ns1", Type: TypeNS, Name: domain + ".", Value: "ns1." + domain + ".", TTL: 3600},
|
||||
},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
s.zones[fqdn] = z
|
||||
return z, s.Save(z)
|
||||
}
|
||||
|
||||
// Delete removes a zone.
|
||||
func (s *ZoneStore) Delete(domain string) error {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.zones[fqdn]; !exists {
|
||||
return fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
delete(s.zones, fqdn)
|
||||
fname := filepath.Join(s.zonesDir, domain+".json")
|
||||
os.Remove(fname)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRecord adds a record to a zone.
|
||||
func (s *ZoneStore) AddRecord(domain string, rec Record) error {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
z, ok := s.zones[fqdn]
|
||||
if !ok {
|
||||
return fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
|
||||
if rec.ID == "" {
|
||||
rec.ID = fmt.Sprintf("r%d", time.Now().UnixNano())
|
||||
}
|
||||
if rec.TTL == 0 {
|
||||
rec.TTL = 300
|
||||
}
|
||||
|
||||
z.Records = append(z.Records, rec)
|
||||
z.SOA.Serial++
|
||||
return s.Save(z)
|
||||
}
|
||||
|
||||
// DeleteRecord removes a record by ID.
|
||||
func (s *ZoneStore) DeleteRecord(domain, recordID string) error {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
z, ok := s.zones[fqdn]
|
||||
if !ok {
|
||||
return fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
|
||||
for i, r := range z.Records {
|
||||
if r.ID == recordID {
|
||||
z.Records = append(z.Records[:i], z.Records[i+1:]...)
|
||||
z.SOA.Serial++
|
||||
return s.Save(z)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("record %s not found", recordID)
|
||||
}
|
||||
|
||||
// UpdateRecord updates a record by ID.
|
||||
func (s *ZoneStore) UpdateRecord(domain, recordID string, rec Record) error {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
z, ok := s.zones[fqdn]
|
||||
if !ok {
|
||||
return fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
|
||||
for i, r := range z.Records {
|
||||
if r.ID == recordID {
|
||||
rec.ID = recordID
|
||||
z.Records[i] = rec
|
||||
z.SOA.Serial++
|
||||
return s.Save(z)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("record %s not found", recordID)
|
||||
}
|
||||
|
||||
// Lookup finds records matching a query name and type within all zones.
|
||||
func (s *ZoneStore) Lookup(name string, qtype uint16) []dns.RR {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
fqdn := dns.Fqdn(name)
|
||||
var results []dns.RR
|
||||
|
||||
// Find the zone for this name
|
||||
for zoneDomain, z := range s.zones {
|
||||
if !dns.IsSubDomain(zoneDomain, fqdn) {
|
||||
continue
|
||||
}
|
||||
// Check records
|
||||
for _, rec := range z.Records {
|
||||
recFQDN := dns.Fqdn(rec.Name)
|
||||
if recFQDN != fqdn {
|
||||
continue
|
||||
}
|
||||
if rr := recordToRR(rec, fqdn); rr != nil {
|
||||
if qtype == dns.TypeANY || rr.Header().Rrtype == qtype {
|
||||
results = append(results, rr)
|
||||
}
|
||||
}
|
||||
}
|
||||
// SOA for zone apex
|
||||
if fqdn == zoneDomain && (qtype == dns.TypeSOA || qtype == dns.TypeANY) {
|
||||
soa := &dns.SOA{
|
||||
Hdr: dns.RR_Header{Name: zoneDomain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: z.SOA.MinTTL},
|
||||
Ns: dns.Fqdn(z.SOA.PrimaryNS),
|
||||
Mbox: dns.Fqdn(z.SOA.AdminEmail),
|
||||
Serial: z.SOA.Serial,
|
||||
Refresh: z.SOA.Refresh,
|
||||
Retry: z.SOA.Retry,
|
||||
Expire: z.SOA.Expire,
|
||||
Minttl: z.SOA.MinTTL,
|
||||
}
|
||||
results = append(results, soa)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func recordToRR(rec Record, fqdn string) dns.RR {
|
||||
hdr := dns.RR_Header{Name: fqdn, Class: dns.ClassINET, Ttl: rec.TTL}
|
||||
|
||||
switch rec.Type {
|
||||
case TypeA:
|
||||
hdr.Rrtype = dns.TypeA
|
||||
rr := &dns.A{Hdr: hdr}
|
||||
rr.A = parseIP(rec.Value)
|
||||
if rr.A == nil {
|
||||
return nil
|
||||
}
|
||||
return rr
|
||||
case TypeAAAA:
|
||||
hdr.Rrtype = dns.TypeAAAA
|
||||
rr := &dns.AAAA{Hdr: hdr}
|
||||
rr.AAAA = parseIP(rec.Value)
|
||||
if rr.AAAA == nil {
|
||||
return nil
|
||||
}
|
||||
return rr
|
||||
case TypeCNAME:
|
||||
hdr.Rrtype = dns.TypeCNAME
|
||||
return &dns.CNAME{Hdr: hdr, Target: dns.Fqdn(rec.Value)}
|
||||
case TypeMX:
|
||||
hdr.Rrtype = dns.TypeMX
|
||||
return &dns.MX{Hdr: hdr, Preference: rec.Priority, Mx: dns.Fqdn(rec.Value)}
|
||||
case TypeTXT:
|
||||
hdr.Rrtype = dns.TypeTXT
|
||||
return &dns.TXT{Hdr: hdr, Txt: []string{rec.Value}}
|
||||
case TypeNS:
|
||||
hdr.Rrtype = dns.TypeNS
|
||||
return &dns.NS{Hdr: hdr, Ns: dns.Fqdn(rec.Value)}
|
||||
case TypeSRV:
|
||||
hdr.Rrtype = dns.TypeSRV
|
||||
return &dns.SRV{Hdr: hdr, Priority: rec.Priority, Weight: rec.Weight, Port: rec.Port, Target: dns.Fqdn(rec.Value)}
|
||||
case TypePTR:
|
||||
hdr.Rrtype = dns.TypePTR
|
||||
return &dns.PTR{Hdr: hdr, Ptr: dns.Fqdn(rec.Value)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIP(s string) net.IP {
|
||||
return net.ParseIP(s)
|
||||
}
|
||||
|
||||
// ExportZoneFile exports a zone in BIND zone file format.
|
||||
func (s *ZoneStore) ExportZoneFile(domain string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
z, ok := s.zones[dns.Fqdn(domain)]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("; Zone file for %s\n", z.Domain))
|
||||
b.WriteString(fmt.Sprintf("; Exported at %s\n", time.Now().UTC().Format(time.RFC3339)))
|
||||
b.WriteString(fmt.Sprintf("$ORIGIN %s.\n", z.Domain))
|
||||
b.WriteString(fmt.Sprintf("$TTL %d\n\n", z.SOA.MinTTL))
|
||||
|
||||
// SOA
|
||||
b.WriteString(fmt.Sprintf("@ IN SOA %s. %s. (\n", z.SOA.PrimaryNS, z.SOA.AdminEmail))
|
||||
b.WriteString(fmt.Sprintf(" %d ; serial\n", z.SOA.Serial))
|
||||
b.WriteString(fmt.Sprintf(" %d ; refresh\n", z.SOA.Refresh))
|
||||
b.WriteString(fmt.Sprintf(" %d ; retry\n", z.SOA.Retry))
|
||||
b.WriteString(fmt.Sprintf(" %d ; expire\n", z.SOA.Expire))
|
||||
b.WriteString(fmt.Sprintf(" %d ; minimum TTL\n)\n\n", z.SOA.MinTTL))
|
||||
|
||||
// Records grouped by type
|
||||
for _, rec := range z.Records {
|
||||
name := rec.Name
|
||||
// Make relative to origin
|
||||
suffix := "." + z.Domain + "."
|
||||
if strings.HasSuffix(name, suffix) {
|
||||
name = strings.TrimSuffix(name, suffix)
|
||||
} else if name == z.Domain+"." {
|
||||
name = "@"
|
||||
}
|
||||
|
||||
switch rec.Type {
|
||||
case TypeMX:
|
||||
b.WriteString(fmt.Sprintf("%-24s %d IN MX %d %s\n", name, rec.TTL, rec.Priority, rec.Value))
|
||||
case TypeSRV:
|
||||
b.WriteString(fmt.Sprintf("%-24s %d IN SRV %d %d %d %s\n", name, rec.TTL, rec.Priority, rec.Weight, rec.Port, rec.Value))
|
||||
default:
|
||||
b.WriteString(fmt.Sprintf("%-24s %d IN %-6s %s\n", name, rec.TTL, rec.Type, rec.Value))
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// ImportZoneFile parses a BIND-style zone file and adds records.
|
||||
// Returns number of records added.
|
||||
func (s *ZoneStore) ImportZoneFile(domain, content string) (int, error) {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
z, ok := s.zones[fqdn]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("zone %s not found — create it first", domain)
|
||||
}
|
||||
|
||||
added := 0
|
||||
zp := dns.NewZoneParser(strings.NewReader(content), dns.Fqdn(domain), "")
|
||||
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
|
||||
hdr := rr.Header()
|
||||
rec := Record{
|
||||
ID: fmt.Sprintf("imp%d", time.Now().UnixNano()+int64(added)),
|
||||
Name: hdr.Name,
|
||||
TTL: hdr.Ttl,
|
||||
}
|
||||
|
||||
switch v := rr.(type) {
|
||||
case *dns.A:
|
||||
rec.Type = TypeA
|
||||
rec.Value = v.A.String()
|
||||
case *dns.AAAA:
|
||||
rec.Type = TypeAAAA
|
||||
rec.Value = v.AAAA.String()
|
||||
case *dns.CNAME:
|
||||
rec.Type = TypeCNAME
|
||||
rec.Value = v.Target
|
||||
case *dns.MX:
|
||||
rec.Type = TypeMX
|
||||
rec.Value = v.Mx
|
||||
rec.Priority = v.Preference
|
||||
case *dns.TXT:
|
||||
rec.Type = TypeTXT
|
||||
rec.Value = strings.Join(v.Txt, " ")
|
||||
case *dns.NS:
|
||||
rec.Type = TypeNS
|
||||
rec.Value = v.Ns
|
||||
case *dns.SRV:
|
||||
rec.Type = TypeSRV
|
||||
rec.Value = v.Target
|
||||
rec.Priority = v.Priority
|
||||
rec.Weight = v.Weight
|
||||
rec.Port = v.Port
|
||||
case *dns.PTR:
|
||||
rec.Type = TypePTR
|
||||
rec.Value = v.Ptr
|
||||
default:
|
||||
continue // Skip unsupported types
|
||||
}
|
||||
|
||||
z.Records = append(z.Records, rec)
|
||||
added++
|
||||
}
|
||||
|
||||
if added > 0 {
|
||||
z.SOA.Serial++
|
||||
s.Save(z)
|
||||
}
|
||||
return added, nil
|
||||
}
|
||||
|
||||
// CloneZone duplicates a zone under a new domain.
|
||||
func (s *ZoneStore) CloneZone(srcDomain, dstDomain string) (*Zone, error) {
|
||||
srcFQDN := dns.Fqdn(srcDomain)
|
||||
dstFQDN := dns.Fqdn(dstDomain)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
src, ok := s.zones[srcFQDN]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source zone %s not found", srcDomain)
|
||||
}
|
||||
if _, exists := s.zones[dstFQDN]; exists {
|
||||
return nil, fmt.Errorf("destination zone %s already exists", dstDomain)
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
z := &Zone{
|
||||
Domain: dstDomain,
|
||||
SOA: SOARecord{
|
||||
PrimaryNS: strings.Replace(src.SOA.PrimaryNS, srcDomain, dstDomain, -1),
|
||||
AdminEmail: strings.Replace(src.SOA.AdminEmail, srcDomain, dstDomain, -1),
|
||||
Serial: uint32(time.Now().Unix()),
|
||||
Refresh: src.SOA.Refresh,
|
||||
Retry: src.SOA.Retry,
|
||||
Expire: src.SOA.Expire,
|
||||
MinTTL: src.SOA.MinTTL,
|
||||
},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
// Clone records, replacing domain references
|
||||
for _, rec := range src.Records {
|
||||
newRec := rec
|
||||
newRec.ID = fmt.Sprintf("c%d", time.Now().UnixNano())
|
||||
newRec.Name = strings.Replace(rec.Name, srcDomain, dstDomain, -1)
|
||||
newRec.Value = strings.Replace(rec.Value, srcDomain, dstDomain, -1)
|
||||
z.Records = append(z.Records, newRec)
|
||||
time.Sleep(time.Nanosecond) // Ensure unique IDs
|
||||
}
|
||||
|
||||
s.zones[dstFQDN] = z
|
||||
return z, s.Save(z)
|
||||
}
|
||||
|
||||
// BulkAddRecords adds multiple records at once.
|
||||
func (s *ZoneStore) BulkAddRecords(domain string, records []Record) (int, error) {
|
||||
fqdn := dns.Fqdn(domain)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
z, ok := s.zones[fqdn]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("zone %s not found", domain)
|
||||
}
|
||||
|
||||
added := 0
|
||||
for _, rec := range records {
|
||||
if rec.ID == "" {
|
||||
rec.ID = fmt.Sprintf("b%d", time.Now().UnixNano()+int64(added))
|
||||
}
|
||||
if rec.TTL == 0 {
|
||||
rec.TTL = 300
|
||||
}
|
||||
z.Records = append(z.Records, rec)
|
||||
added++
|
||||
}
|
||||
|
||||
if added > 0 {
|
||||
z.SOA.Serial++
|
||||
s.Save(z)
|
||||
}
|
||||
return added, nil
|
||||
}
|
||||
Reference in New Issue
Block a user