200 lines
5.7 KiB
Go
200 lines
5.7 KiB
Go
package server
|
|
|
|
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"unicode"
)
|
|
|
|
// ── Security Headers Middleware ──────────────────────────────────────
|
|
|
|
// securityHeaders is middleware that stamps a fixed set of browser-hardening
// response headers onto every response and then hands the request to next.
// The headers are written before next runs, so next can still override them.
func securityHeaders(next http.Handler) http.Handler {
	// Content-Security-Policy value, assembled once from its directives.
	const csp = "default-src 'self'; " +
		"script-src 'self' 'unsafe-inline'; " +
		"style-src 'self' 'unsafe-inline'; " +
		"img-src 'self' data:; " +
		"font-src 'self'; " +
		"connect-src 'self'; " +
		"frame-ancestors 'none'"

	// Header name/value pairs, applied in this order on every request.
	fixed := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=()"},
		{"Content-Security-Policy", csp},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		for _, kv := range fixed {
			w.Header().Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// ── Request Body Limit ──────────────────────────────────────────────
|
|
|
|
// maxBodySize builds middleware that caps the request body at maxBytes by
// replacing r.Body with an http.MaxBytesReader before calling the wrapped
// handler. Reads past the cap fail inside the downstream handler.
func maxBodySize(maxBytes int64) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		limited := func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limited)
	}
	return wrap
}
|
|
|
|
// ── CSRF Protection ─────────────────────────────────────────────────
|
|
|
|
// CSRF double-submit settings.
const (
	// csrfTokenLength is the number of random bytes in a token (before hex encoding).
	csrfTokenLength = 32
	// csrfCookieName is the cookie that carries the token to the browser.
	csrfCookieName = "setec_csrf"
	// csrfHeaderName is the request header checked first for the echoed token.
	csrfHeaderName = "X-CSRF-Token"
	// csrfFormField is the form field checked when the header is absent.
	csrfFormField = "csrf_token"
)
|
|
|
|
func generateCSRFToken() string {
|
|
b := make([]byte, csrfTokenLength)
|
|
rand.Read(b)
|
|
return hex.EncodeToString(b)
|
|
}
|
|
|
|
func csrfProtection(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Safe methods don't need CSRF validation
|
|
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
|
// Ensure a CSRF cookie exists for forms to use
|
|
if _, err := r.Cookie(csrfCookieName); err != nil {
|
|
token := generateCSRFToken()
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: csrfCookieName,
|
|
Value: token,
|
|
Path: "/",
|
|
HttpOnly: false, // JS needs to read this
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
MaxAge: 86400,
|
|
})
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// For mutating requests, validate CSRF token
|
|
cookie, err := r.Cookie(csrfCookieName)
|
|
if err != nil {
|
|
http.Error(w, "CSRF token missing", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Check header first, then form field
|
|
token := r.Header.Get(csrfHeaderName)
|
|
if token == "" {
|
|
token = r.FormValue(csrfFormField)
|
|
}
|
|
|
|
// API requests with JSON Content-Type + Bearer auth skip CSRF
|
|
// (they're not vulnerable to CSRF since browsers don't send custom headers)
|
|
contentType := r.Header.Get("Content-Type")
|
|
authHeader := r.Header.Get("Authorization")
|
|
if strings.Contains(contentType, "application/json") && strings.HasPrefix(authHeader, "Bearer ") {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if token != cookie.Value {
|
|
http.Error(w, "CSRF token invalid", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// ── Password Policy ─────────────────────────────────────────────────
|
|
|
|
// passwordPolicy describes the rules a password has to satisfy.
type passwordPolicy struct {
	// MinLength is the minimum accepted password length.
	MinLength int
	// Each flag, when true, demands at least one character of that class.
	RequireUpper, RequireLower, RequireDigit bool
}
|
|
|
|
var defaultPasswordPolicy = passwordPolicy{
|
|
MinLength: 8,
|
|
RequireUpper: true,
|
|
RequireLower: true,
|
|
RequireDigit: true,
|
|
}
|
|
|
|
func validatePassword(password string) error {
|
|
p := defaultPasswordPolicy
|
|
|
|
if len(password) < p.MinLength {
|
|
return fmt.Errorf("password must be at least %d characters", p.MinLength)
|
|
}
|
|
|
|
hasUpper, hasLower, hasDigit := false, false, false
|
|
for _, c := range password {
|
|
if unicode.IsUpper(c) {
|
|
hasUpper = true
|
|
}
|
|
if unicode.IsLower(c) {
|
|
hasLower = true
|
|
}
|
|
if unicode.IsDigit(c) {
|
|
hasDigit = true
|
|
}
|
|
}
|
|
|
|
if p.RequireUpper && !hasUpper {
|
|
return fmt.Errorf("password must contain at least one uppercase letter")
|
|
}
|
|
if p.RequireLower && !hasLower {
|
|
return fmt.Errorf("password must contain at least one lowercase letter")
|
|
}
|
|
if p.RequireDigit && !hasDigit {
|
|
return fmt.Errorf("password must contain at least one digit")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ── Persistent JWT Key ──────────────────────────────────────────────
|
|
|
|
// LoadOrCreateJWTKey returns the 32-byte JWT signing key kept in the file
// ".jwt_key" inside dataDir.
//
// If the file exists and holds exactly 32 bytes, those bytes are returned.
// If the file does not exist, or exists with a different size, a new random
// key is generated, dataDir is created if necessary (mode 0700), the key is
// written with mode 0600, and the new key is returned.
//
// Any other failure to read the file (for example a permission error) is
// returned to the caller instead of silently replacing the key, since
// overwriting a key that merely could not be read would invalidate every
// token signed with it.
func LoadOrCreateJWTKey(dataDir string) ([]byte, error) {
	const keySize = 32
	keyPath := filepath.Join(dataDir, ".jwt_key")

	// Try to load an existing key.
	data, err := os.ReadFile(keyPath)
	switch {
	case err == nil && len(data) == keySize:
		return data, nil
	case err != nil && !os.IsNotExist(err):
		return nil, fmt.Errorf("reading JWT key %q: %w", keyPath, err)
	}

	// Missing or wrong-sized file: generate a new key.
	key := make([]byte, keySize)
	if _, err := rand.Read(key); err != nil {
		return nil, fmt.Errorf("generating JWT key: %w", err)
	}

	// Save with restrictive permissions. An empty dataDir resolves to the
	// current directory, which needs no creation (and MkdirAll("") fails).
	if dataDir != "" {
		if err := os.MkdirAll(dataDir, 0700); err != nil {
			return nil, fmt.Errorf("creating data directory %q: %w", dataDir, err)
		}
	}
	if err := os.WriteFile(keyPath, key, 0600); err != nil {
		return nil, fmt.Errorf("writing JWT key %q: %w", keyPath, err)
	}

	return key, nil
}
|
|
|
|
// ── Audit Logger ────────────────────────────────────────────────────
|
|
|
|
func (s *Server) logAudit(r *http.Request, action, detail string) {
|
|
claims := getClaimsFromContext(r.Context())
|
|
username := "anonymous"
|
|
if claims != nil {
|
|
username = claims.Username
|
|
}
|
|
ip := r.RemoteAddr
|
|
|
|
// Insert into audit log table
|
|
s.DB.Conn().Exec(`INSERT INTO audit_log (username, ip, action, detail) VALUES (?, ?, ?, ?)`,
|
|
username, ip, action, detail)
|
|
}
|