2026-03-12 20:51:38 -07:00

136 lines
3.3 KiB
Go

package server
import (
	"context"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
// contextKey is a private string type used for request-context keys. Because
// context lookups compare both type and value, a contextKey never matches a
// plain string key holding the same text.
type contextKey string

// claimsKey is the context key under which authRequired stores the verified
// *Claims of the current request.
const claimsKey = contextKey("claims")
// authRequired is middleware that authenticates a request with a JWT before
// handing it to next.
//
// The token is taken from the "setec_token" cookie; when that cookie is
// missing or empty, the Authorization header is consulted for a Bearer
// credential. The token is parsed into a Claims value and verified with
// s.JWTKey. On success the claims are stored in the request context under
// claimsKey and next is invoked with the derived request.
//
// When no token is supplied, or the token fails verification, clients whose
// Accept header mentions text/html are redirected to /login (303 See Other);
// every other client receives 401 Unauthorized with a short message.
func (s *Server) authRequired(next http.Handler) http.Handler {
	// reject ends an unauthenticated request: browsers are sent to the login
	// page, API clients get a 401 carrying msg.
	reject := func(w http.ResponseWriter, r *http.Request, msg string) {
		if acceptsHTML(r) {
			http.Redirect(w, r, "/login", http.StatusSeeOther)
			return
		}
		http.Error(w, msg, http.StatusUnauthorized)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var tokenStr string

		// The session cookie takes precedence over the header.
		if cookie, err := r.Cookie("setec_token"); err == nil {
			tokenStr = cookie.Value
		}

		// Fall back to "Authorization: Bearer <token>". Authentication scheme
		// names are case-insensitive and may be followed by more than one
		// space, so compare with EqualFold and trim the credential.
		if tokenStr == "" {
			const scheme = "Bearer "
			auth := r.Header.Get("Authorization")
			if len(auth) > len(scheme) && strings.EqualFold(auth[:len(scheme)], scheme) {
				tokenStr = strings.TrimSpace(auth[len(scheme):])
			}
		}

		if tokenStr == "" {
			reject(w, r, "Authentication required")
			return
		}

		claims := &Claims{}
		// NOTE(review): the key function returns s.JWTKey without looking at
		// the token's signing algorithm. Once the intended algorithm is
		// confirmed, consider pinning it with the jwt.WithValidMethods parser
		// option.
		token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (interface{}, error) {
			return s.JWTKey, nil
		})
		if err != nil || !token.Valid {
			reject(w, r, "Invalid or expired token")
			return
		}

		ctx := context.WithValue(r.Context(), claimsKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// adminRequired is middleware that lets a request through only when the
// claims found in its context have Role "admin".
//
// It reads the claims with getClaimsFromContext, so it must run after
// middleware that stores them (authRequired in this file). When the context
// holds no claims, or the role is anything other than "admin", it answers
// 403 Forbidden and does not call next.
func (s *Server) adminRequired(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		claims := getClaimsFromContext(r.Context())
		if claims == nil || claims.Role != "admin" {
			http.Error(w, "Admin access required", http.StatusForbidden)
			return
		}
		// r already carries its own context; re-attaching it would only
		// allocate a shallow copy of the request, so pass r through as is.
		next.ServeHTTP(w, r)
	})
}
// getClaimsFromContext returns the *Claims stored in ctx under claimsKey. It
// returns nil when ctx has no value for that key or the stored value is not
// a *Claims.
func getClaimsFromContext(ctx context.Context) *Claims {
	if stored, ok := ctx.Value(claimsKey).(*Claims); ok {
		return stored
	}
	return nil
}
// acceptsHTML reports whether the request's Accept header mentions the
// text/html media type. The check is a plain, case-sensitive substring match;
// quality values are not interpreted.
func acceptsHTML(r *http.Request) bool {
	const htmlType = "text/html"
	accept := r.Header.Get("Accept")
	return strings.Contains(accept, htmlType)
}
// ── Rate Limiter ────────────────────────────────────────────────────
// rateLimiter counts recent attempts per string key so callers can cap how
// many are admitted within a time window. All access to attempts goes through
// mu, so a rateLimiter must be shared by pointer and never copied.
type rateLimiter struct {
	mu       sync.Mutex             // guards attempts
	attempts map[string][]time.Time // per-key timestamps of admitted attempts
	limit    int                    // maximum attempts admitted per key within window
	window   time.Duration          // how far back an attempt still counts
}
// newRateLimiter returns a rateLimiter configured with the given limit and
// window and an empty, ready-to-use attempts map.
func newRateLimiter(limit int, window time.Duration) *rateLimiter {
	rl := new(rateLimiter)
	rl.limit = limit
	rl.window = window
	rl.attempts = map[string][]time.Time{}
	return rl
}
// Allow reports whether another attempt for key is permitted right now and,
// if it is, records the attempt.
//
// An attempt is permitted when fewer than rl.limit attempts for key were
// recorded during the last rl.window. Denied attempts are not recorded. Allow
// is safe for concurrent use; it holds rl.mu for the whole call.
//
// Every call also evicts a few idle keys (keys whose newest attempt has left
// the window), so keys that never call again do not stay in the map forever.
func (rl *rateLimiter) Allow(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-rl.window)

	// Opportunistic cleanup: look at a handful of entries and drop the idle
	// ones. Map iteration order is unspecified and starts at a varying
	// position in practice, so successive calls sample different keys; the
	// fixed budget keeps the cost per call constant. Timestamps are appended
	// in time order, so the last element of each slice is the newest.
	const sweepBudget = 4
	examined := 0
	for k, hits := range rl.attempts {
		if examined == sweepBudget {
			break
		}
		examined++
		if len(hits) == 0 || !hits[len(hits)-1].After(cutoff) {
			delete(rl.attempts, k)
		}
	}

	// Keep only this key's attempts that are still inside the window,
	// filtering in place to reuse the existing backing array.
	hits := rl.attempts[key]
	valid := hits[:0]
	for _, t := range hits {
		if t.After(cutoff) {
			valid = append(valid, t)
		}
	}

	if len(valid) >= rl.limit {
		if len(valid) == 0 {
			// Only reachable with a non-positive limit; there is nothing
			// worth remembering for this key.
			delete(rl.attempts, key)
		} else {
			rl.attempts[key] = valid
		}
		return false
	}

	rl.attempts[key] = append(valid, now)
	return true
}
// loginRateLimit is middleware that caps login attempts at 5 per minute per
// client host. Requests over the limit get 429 Too Many Requests and never
// reach next.
//
// Each call to loginRateLimit creates its own limiter, so the cap applies per
// wrapped handler. The client is identified by the host part of
// r.RemoteAddr; the port is dropped because it changes with every new
// connection and would otherwise let a client reset its count by
// reconnecting. If RemoteAddr has no host:port form it is used unchanged.
//
// NOTE(review): behind a reverse proxy RemoteAddr is presumably the proxy's
// address, which would make all clients share one bucket — confirm the
// deployment before relying on this.
func (s *Server) loginRateLimit(next http.Handler) http.Handler {
	limiter := newRateLimiter(5, time.Minute)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		client := r.RemoteAddr
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			client = host
		}
		if !limiter.Allow(client) {
			http.Error(w, "Too many login attempts. Try again in a minute.", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}