657 lines
18 KiB
Go
Raw Normal View History

package server
import (
"log"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/darkhal/autarch-dns/config"
"github.com/miekg/dns"
)
// Metrics holds query statistics.
//
// The uint64 counters are incremented by handleQuery with sync/atomic and read
// back by GetMetrics with atomic loads. NOTE(review): on 32-bit platforms
// sync/atomic requires 64-bit words to be 8-byte aligned, so keep the counters
// first and contiguous. Field order is also the JSON key order.
type Metrics struct {
	// TotalQueries counts every request that reaches handleQuery.
	TotalQueries uint64 `json:"total_queries"`
	// CacheHits counts queries answered from the response cache.
	CacheHits uint64 `json:"cache_hits"`
	// CacheMisses counts queries that missed hosts, local zones and the cache.
	CacheMisses uint64 `json:"cache_misses"`
	// LocalAnswers counts queries answered from the hosts file or local zones.
	LocalAnswers uint64 `json:"local_answers"`
	// ResolvedQ counts queries answered by a conditional forwarder or the
	// recursive resolver.
	ResolvedQ uint64 `json:"resolved"`
	// BlockedQ counts queries answered NXDOMAIN because of the blocklist.
	BlockedQ uint64 `json:"blocked"`
	// FailedQ counts refused ANY / zone-transfer queries and SERVFAILs.
	FailedQ uint64 `json:"failed"`
	// StartTime is the RFC 3339 UTC time NewDNSServer created the server.
	StartTime string `json:"start_time"`
}
// QueryLogEntry records a single DNS query as stored by logQuery and returned
// by GetQueryLog.
type QueryLogEntry struct {
	// Timestamp is the UTC time the entry was logged, in RFC 3339 with
	// nanoseconds.
	Timestamp string `json:"timestamp"`
	// Client is the full remote address reported by the ResponseWriter,
	// including the port.
	Client string `json:"client"`
	// Name is the queried name exactly as received (case preserved).
	Name string `json:"name"`
	// Type is the textual query type, e.g. "A" or "AAAA".
	Type string `json:"type"`
	// Rcode is the textual response code, e.g. "NOERROR" or "NXDOMAIN".
	Rcode string `json:"rcode"`
	// Answers is the number of records in the answer section.
	Answers int `json:"answers"`
	// Latency is time.Duration.String() of the handling time, e.g. "1.2ms".
	Latency string `json:"latency"`
	// Source is where the answer came from: "hosts", "local", "cache",
	// "conditional", "recursive", "blocked" or "failed".
	Source string `json:"source"` // "local", "hosts", "cache", "conditional", "recursive", "blocked", "failed"
}
// CacheEntry holds a cached DNS response.
//
// msg is a private copy: putCache stores msg.Copy() and getCached hands out
// another Copy(), so callers never mutate the cached message itself.
type CacheEntry struct {
	msg *dns.Msg
	// expiresAt is when the entry stops being served; expired entries stay in
	// the map until cacheCleanup removes them.
	expiresAt time.Time
}
// CacheInfo is an exportable view of a cache entry, as built by
// GetCacheEntries.
type CacheInfo struct {
	// Key is the raw cache key, "<name>/<TYPE>".
	Key string `json:"key"`
	// Name and Type are Key split at its first "/".
	Name string `json:"name"`
	Type string `json:"type"`
	// TTL is the whole number of seconds left before expiry (truncated).
	TTL int `json:"ttl_remaining"`
	// Answers is the number of records in the cached answer section.
	Answers int `json:"answers"`
	// ExpiresAt is the expiry time in RFC 3339, in the zone of time.Now()
	// (not forced to UTC).
	ExpiresAt string `json:"expires_at"`
}
// DomainCount tracks query frequency per domain; GetTopDomains returns a
// slice of these ordered by descending Count.
type DomainCount struct {
	// Domain is the queried name as stored by trackDomain.
	Domain string `json:"domain"`
	// Count is how many queries for Domain have been seen.
	Count uint64 `json:"count"`
}
// DNSServer is the main DNS server.
//
// It answers queries from a hosts file, local zones, a response cache,
// conditional forwarders and a recursive resolver (see handleQuery). It also
// keeps traffic statistics. Construct it with NewDNSServer: the maps below are
// allocated there, and writing to them on a zero DNSServer would panic. Each
// group of mutable state is guarded by its own mutex.
type DNSServer struct {
	cfg      *config.Config
	store    *ZoneStore
	hosts    *HostsStore
	resolver *RecursiveResolver
	// metrics is updated with sync/atomic rather than a mutex.
	// NOTE(review): on 32-bit platforms 64-bit atomics need 8-byte alignment;
	// re-check the offset of metrics if fields above it are changed.
	metrics Metrics
	// cache maps cacheKeyFor keys to responses; guarded by cacheMu.
	cache   map[string]*CacheEntry
	cacheMu sync.RWMutex
	// udpServ and tcpServ are created by Start and shut down by Stop.
	udpServ *dns.Server
	tcpServ *dns.Server
	// Query log — a bounded slice appended to by logQuery, which drops the
	// oldest entries in bulk once queryLogMax is reached (not a true ring
	// buffer). Guarded by queryLogMu.
	queryLog    []QueryLogEntry
	queryLogMu  sync.RWMutex
	queryLogMax int
	// Domain frequency tracking (trackDomain / GetTopDomains)
	domainCounts   map[string]uint64
	domainCountsMu sync.RWMutex
	// Query type tracking (trackType / GetQueryTypeCounts)
	typeCounts   map[string]uint64
	typeCountsMu sync.RWMutex
	// Client tracking (trackClient / GetClientCounts), keyed by address
	// without port
	clientCounts   map[string]uint64
	clientCountsMu sync.RWMutex
	// Blocklist — fast lookup; keys are lower-case fully-qualified names
	blocklist   map[string]bool
	blocklistMu sync.RWMutex
	// Conditional forwarding: zone -> upstream servers (zones are lower-case
	// fully-qualified names)
	conditionalFwd   map[string][]string
	conditionalFwdMu sync.RWMutex
}
// NewDNSServer creates a DNS server that serves local zones from store.
//
// It builds a RecursiveResolver with the DoT/DoH flags from cfg. The query
// log is sized from cfg.QueryLogMax, falling back to 1000 when that is not
// positive. The blocklist is seeded from cfg.BlockList. If cfg.HostsFile is
// set, that hosts file is loaded; a load failure is only logged, so the server
// still starts. Nothing listens until Start is called.
func NewDNSServer(cfg *config.Config, store *ZoneStore) *DNSServer {
	resolver := NewRecursiveResolver()
	resolver.EnableDoT = cfg.EnableDoT
	resolver.EnableDoH = cfg.EnableDoH

	logMax := cfg.QueryLogMax
	if logMax <= 0 {
		logMax = 1000
	}

	s := &DNSServer{
		cfg:            cfg,
		store:          store,
		hosts:          NewHostsStore(),
		resolver:       resolver,
		cache:          make(map[string]*CacheEntry),
		queryLog:       make([]QueryLogEntry, 0, logMax),
		queryLogMax:    logMax,
		domainCounts:   make(map[string]uint64),
		typeCounts:     make(map[string]uint64),
		clientCounts:   make(map[string]uint64),
		blocklist:      make(map[string]bool),
		conditionalFwd: make(map[string][]string),
		metrics: Metrics{
			StartTime: time.Now().UTC().Format(time.RFC3339),
		},
	}

	// Load the configured blocklist through ImportBlocklist so it gets the
	// same normalisation as bulk imports:
	//   - whitespace is trimmed, so " ads.example" becomes a matchable key;
	//   - blank and "#" entries are skipped, so a blank line does not turn
	//     into a "." (root) entry.
	s.ImportBlocklist(cfg.BlockList)

	// Load hosts file if configured
	if cfg.HostsFile != "" {
		if err := s.hosts.LoadFile(cfg.HostsFile); err != nil {
			log.Printf("[hosts] Warning: could not load hosts file %s: %v", cfg.HostsFile, err)
		}
	}
	return s
}
// GetHosts returns the HostsStore that handleQuery consults before the local
// zones.
func (s *DNSServer) GetHosts() *HostsStore {
	hs := s.hosts
	return hs
}
// GetEncryptionStatus reports the resolver's encryption settings. The map's
// contents are produced entirely by RecursiveResolver.GetEncryptionStatus.
func (s *DNSServer) GetEncryptionStatus() map[string]interface{} {
	r := s.resolver
	status := r.GetEncryptionStatus()
	return status
}
// SetEncryption sets the DoT (dot) and DoH (doh) flags on both the live
// resolver and the shared config, so later reads of either see the new
// values.
//
// NOTE(review): these are plain field writes with no locking. If the resolver
// reads EnableDoT/EnableDoH while queries are in flight, this is a data race —
// confirm how RecursiveResolver guards them.
func (s *DNSServer) SetEncryption(dot, doh bool) {
	r, c := s.resolver, s.cfg
	r.EnableDoT, r.EnableDoH = dot, doh
	c.EnableDoT, c.EnableDoH = dot, doh
}
// GetResolver exposes the RecursiveResolver that handleQuery falls back to
// when no local, cached or conditionally-forwarded answer exists.
func (s *DNSServer) GetResolver() *RecursiveResolver {
	res := s.resolver
	return res
}
// Start serves DNS on cfg.ListenDNS over both UDP and TCP and launches the
// background cacheCleanup goroutine.
//
// It blocks until the first of the two listeners returns, and returns that
// listener's error. The other listener and the cleanup goroutine keep
// running. The channel has room for both results, so neither serving
// goroutine blocks on its send once Start has returned.
func (s *DNSServer) Start() error {
	handler := dns.NewServeMux()
	handler.HandleFunc(".", s.handleQuery)

	addr := s.cfg.ListenDNS
	s.udpServ = &dns.Server{Addr: addr, Net: "udp", Handler: handler}
	s.tcpServ = &dns.Server{Addr: addr, Net: "tcp", Handler: handler}

	errs := make(chan error, 2)
	for _, srv := range []*dns.Server{s.udpServ, s.tcpServ} {
		srv := srv // per-iteration copy for pre-Go 1.22 closures
		go func() { errs <- srv.ListenAndServe() }()
	}
	go s.cacheCleanup()

	return <-errs
}
// Stop shuts down the UDP and TCP listeners that Start created, in that
// order. Listeners that were never created (Start not called) are skipped,
// and Shutdown errors are ignored. The cacheCleanup goroutine is not stopped.
func (s *DNSServer) Stop() {
	for _, srv := range []*dns.Server{s.udpServ, s.tcpServ} {
		if srv == nil {
			continue
		}
		srv.Shutdown()
	}
}
// GetMetrics returns a copy of the query counters.
//
// Each counter is read with its own atomic load, so the copy is not a single
// consistent snapshot: a query already counted in TotalQueries may not yet
// show up in its outcome counter. StartTime is copied as-is.
func (s *DNSServer) GetMetrics() Metrics {
	live := &s.metrics
	var snap Metrics
	snap.TotalQueries = atomic.LoadUint64(&live.TotalQueries)
	snap.CacheHits = atomic.LoadUint64(&live.CacheHits)
	snap.CacheMisses = atomic.LoadUint64(&live.CacheMisses)
	snap.LocalAnswers = atomic.LoadUint64(&live.LocalAnswers)
	snap.ResolvedQ = atomic.LoadUint64(&live.ResolvedQ)
	snap.BlockedQ = atomic.LoadUint64(&live.BlockedQ)
	snap.FailedQ = atomic.LoadUint64(&live.FailedQ)
	snap.StartTime = live.StartTime
	return snap
}
// GetQueryLog returns up to limit of the most recent log entries, newest
// first. A limit that is non-positive or larger than the log returns the
// whole log. The result is a fresh, non-nil slice.
func (s *DNSServer) GetQueryLog(limit int) []QueryLogEntry {
	s.queryLogMu.RLock()
	defer s.queryLogMu.RUnlock()

	total := len(s.queryLog)
	count := limit
	if count <= 0 || count > total {
		count = total
	}

	out := make([]QueryLogEntry, 0, count)
	for idx := total - 1; idx >= total-count; idx-- {
		out = append(out, s.queryLog[idx])
	}
	return out
}
// ClearQueryLog discards every log entry. The backing array is kept so that
// later appends can reuse its capacity.
func (s *DNSServer) ClearQueryLog() {
	s.queryLogMu.Lock()
	defer s.queryLogMu.Unlock()
	s.queryLog = s.queryLog[:0]
}
// GetCacheEntries lists every unexpired cached response, in no particular
// order. Each key is split at its first "/" into Name and Type. A key
// without "/" is reported whole as Name with an empty Type.
func (s *DNSServer) GetCacheEntries() []CacheInfo {
	s.cacheMu.RLock()
	defer s.cacheMu.RUnlock()

	now := time.Now()
	out := make([]CacheInfo, 0, len(s.cache))
	for key, e := range s.cache {
		if now.After(e.expiresAt) {
			continue // awaiting cacheCleanup; never served
		}

		name, qtype := key, ""
		if sep := strings.Index(key, "/"); sep >= 0 {
			name, qtype = key[:sep], key[sep+1:]
		}

		remaining := e.expiresAt.Sub(now)
		out = append(out, CacheInfo{
			Key:       key,
			Name:      name,
			Type:      qtype,
			TTL:       int(remaining.Seconds()),
			Answers:   len(e.msg.Answer),
			ExpiresAt: e.expiresAt.Format(time.RFC3339),
		})
	}
	return out
}
// CacheSize returns the number of cache entries that have not yet expired.
//
// Expired entries stay in the map until cacheCleanup's next sweep, which runs
// once a minute. They are excluded here so the count matches
// GetCacheEntries and what getCached would actually serve.
func (s *DNSServer) CacheSize() int {
	s.cacheMu.RLock()
	defer s.cacheMu.RUnlock()

	now := time.Now()
	active := 0
	for _, entry := range s.cache {
		if !now.After(entry.expiresAt) {
			active++
		}
	}
	return active
}
// FlushCache drops every cached response and then the resolver's NS cache.
// It returns how many response entries were removed, expired ones included.
func (s *DNSServer) FlushCache() int {
	fresh := make(map[string]*CacheEntry)

	s.cacheMu.Lock()
	removed := len(s.cache)
	s.cache = fresh
	s.cacheMu.Unlock()

	s.resolver.FlushNSCache()
	return removed
}
// FlushCacheEntry removes the cache entry stored under key (in the
// "<name>/<TYPE>" form shown by GetCacheEntries). It reports whether such an
// entry existed.
func (s *DNSServer) FlushCacheEntry(key string) bool {
	s.cacheMu.Lock()
	defer s.cacheMu.Unlock()

	_, found := s.cache[key]
	delete(s.cache, key) // no-op when absent
	return found
}
// GetTopDomains returns the most-queried domains, ordered by descending
// count. Equal counts are ordered by domain name, so the result (and which
// domains survive the limit cut-off) is deterministic rather than depending
// on map iteration order. A non-positive limit returns every domain.
func (s *DNSServer) GetTopDomains(limit int) []DomainCount {
	// Copy under the read lock, then sort without holding it so trackDomain
	// is not blocked for the duration of the sort.
	s.domainCountsMu.RLock()
	counts := make([]DomainCount, 0, len(s.domainCounts))
	for domain, count := range s.domainCounts {
		counts = append(counts, DomainCount{Domain: domain, Count: count})
	}
	s.domainCountsMu.RUnlock()

	sort.Slice(counts, func(i, j int) bool {
		if counts[i].Count != counts[j].Count {
			return counts[i].Count > counts[j].Count
		}
		return counts[i].Domain < counts[j].Domain
	})
	if limit > 0 && limit < len(counts) {
		counts = counts[:limit]
	}
	return counts
}
// GetQueryTypeCounts returns a copy of the per-query-type counters (e.g.
// "A" -> n). Callers may modify the returned map freely.
func (s *DNSServer) GetQueryTypeCounts() map[string]uint64 {
	s.typeCountsMu.RLock()
	snapshot := make(map[string]uint64, len(s.typeCounts))
	for qtype, n := range s.typeCounts {
		snapshot[qtype] = n
	}
	s.typeCountsMu.RUnlock()
	return snapshot
}
// GetClientCounts returns a copy of the per-client query counters, keyed by
// client address without port. Callers may modify the returned map freely.
func (s *DNSServer) GetClientCounts() map[string]uint64 {
	s.clientCountsMu.RLock()
	snapshot := make(map[string]uint64, len(s.clientCounts))
	for client, n := range s.clientCounts {
		snapshot[client] = n
	}
	s.clientCountsMu.RUnlock()
	return snapshot
}
// AddBlocklistEntry blocks domain, lower-cased and made fully qualified.
// Because isBlocked also checks parent names, every name beneath domain is
// blocked too.
func (s *DNSServer) AddBlocklistEntry(domain string) {
	key := dns.Fqdn(strings.ToLower(domain))
	s.blocklistMu.Lock()
	defer s.blocklistMu.Unlock()
	s.blocklist[key] = true
}
// RemoveBlocklistEntry unblocks domain, normalised the same way
// AddBlocklistEntry stores it. Removing an absent domain is a no-op.
func (s *DNSServer) RemoveBlocklistEntry(domain string) {
	key := dns.Fqdn(strings.ToLower(domain))
	s.blocklistMu.Lock()
	defer s.blocklistMu.Unlock()
	delete(s.blocklist, key)
}
// GetBlocklist returns every blocked domain (lower-case, fully qualified),
// sorted alphabetically.
func (s *DNSServer) GetBlocklist() []string {
	s.blocklistMu.RLock()
	domains := make([]string, 0, len(s.blocklist))
	for d := range s.blocklist {
		domains = append(domains, d)
	}
	s.blocklistMu.RUnlock()

	sort.Strings(domains)
	return domains
}
// ImportBlocklist adds many domains under a single lock and returns how many
// entries were accepted. Each entry is lower-cased, trimmed and made fully
// qualified. Blank entries and entries starting with "#" are skipped.
// Domains already on the list are still counted.
func (s *DNSServer) ImportBlocklist(domains []string) int {
	s.blocklistMu.Lock()
	defer s.blocklistMu.Unlock()

	added := 0
	for _, raw := range domains {
		d := strings.TrimSpace(strings.ToLower(raw))
		if d == "" || d[0] == '#' {
			continue
		}
		s.blocklist[dns.Fqdn(d)] = true
		added++
	}
	return added
}
// SetConditionalForward routes queries for zone, and every name beneath it,
// to upstreams instead of the recursive resolver. The zone is lower-cased
// and made fully qualified. Any existing rule for the same zone is replaced.
//
// upstreams is copied. The stored slice is iterated by handleQuery after
// getConditionalForward has released conditionalFwdMu, so keeping the
// caller's slice would let later changes to it race with in-flight queries.
// A nil list stays nil and an empty one stays empty.
func (s *DNSServer) SetConditionalForward(zone string, upstreams []string) {
	var servers []string
	if upstreams != nil {
		servers = make([]string, len(upstreams))
		copy(servers, upstreams)
	}

	s.conditionalFwdMu.Lock()
	s.conditionalFwd[dns.Fqdn(strings.ToLower(zone))] = servers
	s.conditionalFwdMu.Unlock()
}
// RemoveConditionalForward deletes the forwarding rule for zone, normalised
// the same way SetConditionalForward stores it. Removing a zone with no rule
// is a no-op.
func (s *DNSServer) RemoveConditionalForward(zone string) {
	key := dns.Fqdn(strings.ToLower(zone))
	s.conditionalFwdMu.Lock()
	defer s.conditionalFwdMu.Unlock()
	delete(s.conditionalFwd, key)
}
// GetConditionalForwards returns a snapshot of all conditional forwarding
// rules, keyed by lower-case fully-qualified zone.
//
// Both the map and each upstream slice are copies. Handing out the internal
// slices would let a caller that mutates them corrupt the rules handleQuery
// reads without holding conditionalFwdMu. A nil upstream list stays nil.
func (s *DNSServer) GetConditionalForwards() map[string][]string {
	s.conditionalFwdMu.RLock()
	defer s.conditionalFwdMu.RUnlock()

	result := make(map[string][]string, len(s.conditionalFwd))
	for zone, servers := range s.conditionalFwd {
		if servers == nil {
			result[zone] = nil
			continue
		}
		dup := make([]string, len(servers))
		copy(dup, servers)
		result[zone] = dup
	}
	return result
}
// GetResolverNSCache returns the resolver's NS delegation cache exactly as
// RecursiveResolver.GetNSCache produces it. Presumably it maps zone to name
// servers — the key/value format is defined there.
func (s *DNSServer) GetResolverNSCache() map[string][]string {
	nsCache := s.resolver.GetNSCache()
	return nsCache
}
// handleQuery is the handler Start registers for every name (".") on both
// the UDP and TCP listeners. Only the first question of r is considered. The
// request is answered by the first of these stages that applies:
//
//   - empty question section: FORMERR;
//   - ANY when cfg.RefuseANY: NOTIMP;
//   - AXFR/IXFR: REFUSED (zone transfers are never served);
//   - version.bind / hostname.bind / version.server TXT when
//     cfg.MinimalResponses: REFUSED;
//   - blocklisted name (or parent): NXDOMAIN;
//   - hosts file, then local zones: authoritative answer;
//   - response cache;
//   - conditional forwarders for the most specific matching zone, tried in
//     order (if all fail, resolution falls through to the next stage);
//   - recursive resolution, with cfg.Upstream as fallback;
//   - otherwise SERVFAIL.
//
// Every path except FORMERR updates the per-domain/type/client counters and
// is recorded with logQuery.
func (s *DNSServer) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
	start := time.Now()
	atomic.AddUint64(&s.metrics.TotalQueries, 1)

	msg := new(dns.Msg)
	msg.SetReply(r)
	msg.Authoritative = false
	msg.RecursionAvailable = true

	if len(r.Question) == 0 {
		msg.Rcode = dns.RcodeFormatError
		w.WriteMsg(msg)
		return
	}

	q := r.Question[0]
	qName := q.Name
	// dns.Type.String yields "TYPE<n>" for codes missing from
	// dns.TypeToString. Unknown types therefore stay distinguishable in the
	// stats and query log instead of all collapsing to "".
	qTypeStr := dns.Type(q.Qtype).String()
	clientAddr := w.RemoteAddr().String()

	// Track stats
	s.trackDomain(qName)
	s.trackType(qTypeStr)
	s.trackClient(clientAddr)
	if s.cfg.LogQueries {
		log.Printf("[query] %s %s from %s", qTypeStr, qName, clientAddr)
	}

	// reply sends resp (a cached or upstream message) as the answer to r.
	// SetReply rewrites the ID, flags and question for this client, but it
	// also resets Rcode to NOERROR, so the upstream rcode is saved and
	// restored. Without that, NXDOMAIN and SERVFAIL answers would be sent and
	// logged as empty NOERROR responses.
	reply := func(resp *dns.Msg, source string) {
		rcode := resp.Rcode
		resp.SetReply(r)
		resp.Rcode = rcode
		s.logQuery(clientAddr, qName, qTypeStr, dns.RcodeToString[rcode], len(resp.Answer), time.Since(start), source)
		w.WriteMsg(resp)
	}

	// Security: Refuse ANY queries (DNS amplification protection)
	if s.cfg.RefuseANY && q.Qtype == dns.TypeANY {
		msg.Rcode = dns.RcodeNotImplemented
		atomic.AddUint64(&s.metrics.FailedQ, 1)
		s.logQuery(clientAddr, qName, qTypeStr, "NOTIMPL", 0, time.Since(start), "blocked")
		w.WriteMsg(msg)
		return
	}

	// Security: Block zone transfer requests (AXFR/IXFR)
	if q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR {
		msg.Rcode = dns.RcodeRefused
		atomic.AddUint64(&s.metrics.FailedQ, 1)
		s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
		w.WriteMsg(msg)
		return
	}

	// Security: Minimal responses — refuse the conventional server-identity
	// probes. DNS names are case-insensitive, so compare lower-cased to stop
	// e.g. "VERSION.BIND." slipping past. Counted as FailedQ like the other
	// policy refusals above.
	if s.cfg.MinimalResponses && q.Qtype == dns.TypeTXT {
		switch strings.ToLower(qName) {
		case "version.bind.", "hostname.bind.", "version.server.":
			msg.Rcode = dns.RcodeRefused
			atomic.AddUint64(&s.metrics.FailedQ, 1)
			s.logQuery(clientAddr, qName, qTypeStr, "REFUSED", 0, time.Since(start), "blocked")
			w.WriteMsg(msg)
			return
		}
	}

	// Blocklist check
	if s.isBlocked(qName) {
		msg.Rcode = dns.RcodeNameError // NXDOMAIN
		atomic.AddUint64(&s.metrics.BlockedQ, 1)
		s.logQuery(clientAddr, qName, qTypeStr, "NXDOMAIN", 0, time.Since(start), "blocked")
		w.WriteMsg(msg)
		return
	}

	// 1a. Check hosts file
	if hostsAnswers := s.hosts.Lookup(qName, q.Qtype); len(hostsAnswers) > 0 {
		msg.Authoritative = true
		msg.Answer = hostsAnswers
		atomic.AddUint64(&s.metrics.LocalAnswers, 1)
		s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(hostsAnswers), time.Since(start), "hosts")
		w.WriteMsg(msg)
		return
	}

	// 1b. Check local zones
	if answers := s.store.Lookup(qName, q.Qtype); len(answers) > 0 {
		msg.Authoritative = true
		msg.Answer = answers
		atomic.AddUint64(&s.metrics.LocalAnswers, 1)
		s.logQuery(clientAddr, qName, qTypeStr, "NOERROR", len(answers), time.Since(start), "local")
		w.WriteMsg(msg)
		return
	}

	// 2. Check cache (getCached returns a private copy, safe to rewrite)
	cacheKey := cacheKeyFor(q)
	if cached := s.getCached(cacheKey); cached != nil {
		atomic.AddUint64(&s.metrics.CacheHits, 1)
		reply(cached, "cache")
		return
	}
	atomic.AddUint64(&s.metrics.CacheMisses, 1)

	// 3. Check conditional forwarding
	if fwdServers := s.getConditionalForward(qName); fwdServers != nil {
		c := &dns.Client{Timeout: 5 * time.Second}
		for _, srv := range fwdServers {
			resp, _, err := c.Exchange(r, srv)
			if err == nil && resp != nil {
				atomic.AddUint64(&s.metrics.ResolvedQ, 1)
				s.putCache(cacheKey, resp)
				reply(resp, "conditional")
				return
			}
		}
	}

	// 4. Recursive resolution from root hints (with optional upstream fallback)
	if resp := s.resolver.ResolveWithFallback(r, s.cfg.Upstream); resp != nil {
		atomic.AddUint64(&s.metrics.ResolvedQ, 1)
		s.putCache(cacheKey, resp)
		reply(resp, "recursive")
		return
	}

	// 5. SERVFAIL
	atomic.AddUint64(&s.metrics.FailedQ, 1)
	msg.Rcode = dns.RcodeServerFailure
	s.logQuery(clientAddr, qName, qTypeStr, "SERVFAIL", 0, time.Since(start), "failed")
	w.WriteMsg(msg)
}
// ── Blocklist ────────────────────────────────────────────────────────

// isBlocked reports whether name, or any proper parent of it, is on the
// blocklist. For example, "ads.example.com." is blocked by "example.com." or
// "com.". Matching is case-insensitive. The root "." is only matched when
// name itself is the root.
func (s *DNSServer) isBlocked(name string) bool {
	fqdn := dns.Fqdn(strings.ToLower(name))
	labels := dns.SplitDomainName(fqdn)

	s.blocklistMu.RLock()
	defer s.blocklistMu.RUnlock()

	if s.blocklist[fqdn] {
		return true
	}
	for depth := 1; depth < len(labels); depth++ {
		parent := dns.Fqdn(strings.Join(labels[depth:], "."))
		if s.blocklist[parent] {
			return true
		}
	}
	return false
}
// ── Conditional forwarding ───────────────────────────────────────────

// getConditionalForward returns the upstreams of the most specific
// forwarding zone that encloses name, checking the longest suffix first. It
// returns nil when no rule applies. The returned slice is the stored one, not
// a copy.
func (s *DNSServer) getConditionalForward(name string) []string {
	labels := dns.SplitDomainName(dns.Fqdn(strings.ToLower(name)))

	s.conditionalFwdMu.RLock()
	defer s.conditionalFwdMu.RUnlock()

	for i := range labels {
		zone := dns.Fqdn(strings.Join(labels[i:], "."))
		if servers, found := s.conditionalFwd[zone]; found {
			return servers
		}
	}
	return nil
}
// ── Tracking ─────────────────────────────────────────────────────────

// trackDomain increments the query counter for name. The name is lower-cased
// first because DNS names are case-insensitive (RFC 4343). Otherwise
// "Example.COM." and "example.com." would be listed separately by
// GetTopDomains, and clients using 0x20 case randomisation would inflate the
// map with one key per casing.
//
// NOTE(review): domainCounts keeps one key per distinct name ever queried and
// is never pruned; consider a cap if clients are untrusted.
func (s *DNSServer) trackDomain(name string) {
	key := strings.ToLower(name)
	s.domainCountsMu.Lock()
	s.domainCounts[key]++
	s.domainCountsMu.Unlock()
}
// trackType increments the counter for the textual query type qtype.
func (s *DNSServer) trackType(qtype string) {
	mu := &s.typeCountsMu
	mu.Lock()
	defer mu.Unlock()
	s.typeCounts[qtype]++
}
// trackClient increments the per-client counter. addr is the
// ResponseWriter's remote address. Everything from the last ':' onward is
// dropped to strip the port, so IPv6 clients keep their brackets (e.g.
// "[::1]"). An addr whose only ':' is at index 0, or that has none, is used
// unchanged.
func (s *DNSServer) trackClient(addr string) {
	host := addr
	if colon := strings.LastIndex(addr, ":"); colon > 0 {
		host = addr[:colon]
	}

	s.clientCountsMu.Lock()
	defer s.clientCountsMu.Unlock()
	s.clientCounts[host]++
}
// logQuery appends one entry to the in-memory query log.
//
// The log is bounded by queryLogMax. Once it is full, the oldest tenth of the
// limit is dropped with a single copy before appending, which amortises the
// cost of shifting over many inserts. The trim is floored at one entry:
// queryLogMax/10 is 0 for limits below 10, and without the floor the log
// would grow without bound. A non-positive queryLogMax (only possible when
// the server was not built by NewDNSServer) disables trimming.
func (s *DNSServer) logQuery(client, name, qtype, rcode string, answers int, latency time.Duration, source string) {
	entry := QueryLogEntry{
		Timestamp: time.Now().UTC().Format(time.RFC3339Nano),
		Client:    client,
		Name:      name,
		Type:      qtype,
		Rcode:     rcode,
		Answers:   answers,
		Latency:   latency.String(),
		Source:    source,
	}

	s.queryLogMu.Lock()
	defer s.queryLogMu.Unlock()

	if n := len(s.queryLog); s.queryLogMax > 0 && n >= s.queryLogMax {
		trim := s.queryLogMax / 10
		if trim < 1 {
			trim = 1
		}
		// trim <= queryLogMax <= n, so the slice bounds below are valid.
		copy(s.queryLog, s.queryLog[trim:])
		s.queryLog = s.queryLog[:n-trim]
	}
	s.queryLog = append(s.queryLog, entry)
}
// ── Cache ────────────────────────────────────────────────────────────

// cacheKeyFor builds the response-cache key "<name>/<TYPE>" for q.
//
// The name is lower-cased because DNS names compare case-insensitively
// (RFC 4343), so "Example.COM." and "example.com." share one entry. The type
// comes from dns.Type.String, which falls back to "TYPE<n>" for codes absent
// from dns.TypeToString. Indexing that map directly would give "" for every
// unknown type, letting distinct unknown types overwrite each other's cached
// answers. The query class is not part of the key.
func cacheKeyFor(q dns.Question) string {
	return strings.ToLower(q.Name) + "/" + dns.Type(q.Qtype).String()
}
// getCached returns a private copy of the unexpired response stored under
// key, or nil if there is none. The caller may modify the copy freely.
func (s *DNSServer) getCached(key string) *dns.Msg {
	s.cacheMu.RLock()
	defer s.cacheMu.RUnlock()

	if entry, ok := s.cache[key]; ok && !time.Now().After(entry.expiresAt) {
		return entry.msg.Copy()
	}
	return nil
}
// putCache stores a copy of msg under key for cfg.CacheTTL seconds. The
// lifetime is this fixed config value; the TTLs of the records in msg are not
// consulted. A non-positive CacheTTL disables caching entirely.
func (s *DNSServer) putCache(key string, msg *dns.Msg) {
	lifetime := time.Duration(s.cfg.CacheTTL) * time.Second
	if lifetime > 0 {
		entry := &CacheEntry{msg: msg.Copy(), expiresAt: time.Now().Add(lifetime)}
		s.cacheMu.Lock()
		s.cache[key] = entry
		s.cacheMu.Unlock()
	}
}
// cacheCleanup deletes expired responses from the cache once a minute.
//
// It has no stop signal and loops forever, so the goroutine Start launches
// outlives Stop, and calling Start twice runs two sweepers. NOTE(review):
// consider a done channel if servers are restarted in-process.
func (s *DNSServer) cacheCleanup() {
	ticker := time.NewTicker(60 * time.Second)
	defer ticker.Stop()

	for {
		<-ticker.C
		s.cacheMu.Lock()
		cutoff := time.Now()
		for key, entry := range s.cache {
			if cutoff.After(entry.expiresAt) {
				delete(s.cache, key) // deleting during range is safe in Go
			}
		}
		s.cacheMu.Unlock()
	}
}