2026-03-12 20:51:38 -07:00

249 lines
6.2 KiB
Go

package float
import (
"fmt"
"log"
"sync"
"time"
"setec-manager/internal/db"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// Session is one active Float Mode session: the persisted session attributes
// plus a reference to the live WebSocket connection, once one is attached.
type Session struct {
	// Exported attributes; these are what gets serialized to JSON.
	ID          string     `json:"id"`
	UserID      int64      `json:"user_id"`
	ClientIP    string     `json:"client_ip"`
	ClientAgent string     `json:"client_agent"`
	USBBridge   bool       `json:"usb_bridge"`
	ConnectedAt time.Time  `json:"connected_at"`
	ExpiresAt   time.Time  `json:"expires_at"`
	LastPing    *time.Time `json:"last_ping,omitempty"` // nil until a ping is recorded

	// conn is the live WebSocket connection. It is unexported, so it is never
	// part of the JSON form, and it stays nil until SetConn attaches one.
	conn *websocket.Conn
}
// SessionManager provides in-memory + database-backed session lifecycle
// management for Float Mode connections.
type SessionManager struct {
sessions map[string]*Session
mu sync.RWMutex
db *db.DB
}
// NewSessionManager returns a SessionManager that starts with an empty
// in-memory session map and uses database as its backing store.
func NewSessionManager(database *db.DB) *SessionManager {
	sm := new(SessionManager)
	sm.db = database
	sm.sessions = make(map[string]*Session)
	return sm
}
// Create registers a new Float session for userID and returns its ID, a
// freshly generated UUID string. The session expires ttl after the moment of
// creation.
//
// The database row is written first; if that insert fails the error is
// returned wrapped and nothing is added to the in-memory map.
func (sm *SessionManager) Create(userID int64, clientIP, agent string, ttl time.Duration) (string, error) {
	sessionID := uuid.New().String()
	connectedAt := time.Now()
	expiry := connectedAt.Add(ttl)

	// Persist before caching so a failed insert leaves no in-memory entry.
	err := sm.db.CreateFloatSession(sessionID, userID, clientIP, agent, expiry)
	if err != nil {
		return "", fmt.Errorf("create session: db insert: %w", err)
	}

	entry := &Session{
		ID:          sessionID,
		UserID:      userID,
		ClientIP:    clientIP,
		ClientAgent: agent,
		ConnectedAt: connectedAt,
		ExpiresAt:   expiry,
	}

	sm.mu.Lock()
	sm.sessions[sessionID] = entry
	sm.mu.Unlock()

	log.Printf("[float/session] created session %s for user %d from %s (expires %s)",
		sessionID, userID, clientIP, expiry.Format(time.RFC3339))

	return sessionID, nil
}
// Get retrieves a session by ID, checking the in-memory map first and falling
// back to the database on a miss.
//
// A session found in the database is added to the in-memory map before being
// returned. If another goroutine cached the same ID in the meantime, that
// existing entry is returned instead, so a connection already attached to it
// via SetConn is not lost.
//
// An expired session is removed (best effort) and reported as an error. Get
// returns a nil session together with a non-nil error when the session is
// expired or the database lookup fails.
func (sm *SessionManager) Get(id string) (*Session, error) {
	// Fast path: already cached.
	sm.mu.RLock()
	sess, ok := sm.sessions[id]
	sm.mu.RUnlock()

	if ok {
		if time.Now().After(sess.ExpiresAt) {
			// Cleanup is best effort: the caller gets the expiry error either
			// way, so a failed delete is only logged.
			if err := sm.Delete(id); err != nil {
				log.Printf("[float/session] removing expired session %s: %v", id, err)
			}
			return nil, fmt.Errorf("session %s has expired", id)
		}
		return sess, nil
	}

	// Slow path: fall back to the database.
	dbSess, err := sm.db.GetFloatSession(id)
	if err != nil {
		return nil, fmt.Errorf("get session: %w", err)
	}

	if time.Now().After(dbSess.ExpiresAt) {
		// Best effort, as above.
		if err := sm.db.DeleteFloatSession(id); err != nil {
			log.Printf("[float/session] removing expired session %s from db: %v", id, err)
		}
		return nil, fmt.Errorf("session %s has expired", id)
	}

	session := &Session{
		ID:          dbSess.ID,
		UserID:      dbSess.UserID,
		ClientIP:    dbSess.ClientIP,
		ClientAgent: dbSess.ClientAgent,
		USBBridge:   dbSess.USBBridge,
		ConnectedAt: dbSess.ConnectedAt,
		ExpiresAt:   dbSess.ExpiresAt,
		LastPing:    dbSess.LastPing,
	}

	sm.mu.Lock()
	defer sm.mu.Unlock()

	// Re-check under the write lock: the lock was released during the database
	// round trip, so another goroutine may have cached this ID already.
	// Overwriting its entry would discard any connection attached to it.
	if existing, ok := sm.sessions[id]; ok {
		return existing, nil
	}
	sm.sessions[id] = session
	return session, nil
}
// Delete removes a session from both the in-memory map and the database.
//
// If the in-memory session has a WebSocket connection attached, a normal-closure
// close frame is sent (best effort, with a 5 second write deadline) and the
// connection is closed. This network I/O happens after the manager's lock has
// been released, so a slow or unresponsive peer cannot stall other session
// operations.
//
// The database delete is attempted whether or not the session was in memory.
// A failure there is returned wrapped; the in-memory entry is gone regardless.
func (sm *SessionManager) Delete(id string) error {
	// Detach the entry under the lock, keeping only the connection reference.
	var conn *websocket.Conn
	sm.mu.Lock()
	if sess, ok := sm.sessions[id]; ok {
		conn = sess.conn
		delete(sm.sessions, id)
	}
	sm.mu.Unlock()

	if conn != nil {
		// Errors are deliberately ignored: the peer may already be gone, and
		// the connection is being torn down regardless.
		_ = conn.WriteControl(
			websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, "session deleted"),
			time.Now().Add(5*time.Second),
		)
		_ = conn.Close()
	}

	if err := sm.db.DeleteFloatSession(id); err != nil {
		return fmt.Errorf("delete session: db delete: %w", err)
	}

	log.Printf("[float/session] deleted session %s", id)
	return nil
}
// Ping records a ping for the session: it stamps the in-memory entry, if one
// exists, with the current time and then asks the database to record the ping.
//
// The in-memory timestamp is written before the database call, so it remains
// updated even when the database call fails and an error is returned.
func (sm *SessionManager) Ping(id string) error {
	pingedAt := time.Now()

	sm.mu.Lock()
	sess, cached := sm.sessions[id]
	if cached {
		sess.LastPing = &pingedAt
	}
	sm.mu.Unlock()

	err := sm.db.PingFloatSession(id)
	if err == nil {
		return nil
	}
	return fmt.Errorf("ping session: %w", err)
}
// CleanExpired removes all sessions that have passed their expiry time, first
// from the in-memory map and then from the database.
//
// Attached WebSocket connections of expired sessions are sent a normal-closure
// close frame (best effort) and closed. That network I/O runs after the
// manager's lock has been released, so slow peers cannot stall other session
// operations. All close frames share one write deadline, 5 seconds after the
// start of the call.
//
// On success the returned count is the number reported by the database. If the
// database cleanup fails, the count of sessions removed from memory is
// returned together with the wrapped error.
func (sm *SessionManager) CleanExpired() (int, error) {
	now := time.Now()

	// Phase 1: under the lock, drop expired entries and collect their
	// connections. Deleting map entries while ranging over the map is safe.
	var (
		removed int
		conns   []*websocket.Conn
	)
	sm.mu.Lock()
	for id, sess := range sm.sessions {
		if !now.After(sess.ExpiresAt) {
			continue
		}
		if sess.conn != nil {
			conns = append(conns, sess.conn)
		}
		delete(sm.sessions, id)
		removed++
	}
	sm.mu.Unlock()

	// Phase 2: close the collected connections without holding the lock.
	if len(conns) > 0 {
		closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "session expired")
		deadline := now.Add(5 * time.Second)
		for _, conn := range conns {
			// Errors are deliberately ignored: the peer may already be gone,
			// and the connection is being torn down regardless.
			_ = conn.WriteControl(websocket.CloseMessage, closeMsg, deadline)
			_ = conn.Close()
		}
	}

	// Phase 3: clean the database.
	count, err := sm.db.CleanExpiredFloatSessions()
	if err != nil {
		return removed, fmt.Errorf("clean expired: db: %w", err)
	}

	total := int(count)
	if total > 0 {
		log.Printf("[float/session] cleaned %d expired sessions", total)
	}
	return total, nil
}
// ActiveCount reports how many sessions are held in the in-memory map at the
// moment of the call. No expiry check is made, so sessions that have expired
// but have not yet been removed are included.
func (sm *SessionManager) ActiveCount() int {
	sm.mu.RLock()
	count := len(sm.sessions)
	sm.mu.RUnlock()
	return count
}
// SetConn attaches conn to the in-memory session with the given ID and sets
// that session's USBBridge flag to true. It does nothing when the ID is not in
// the in-memory map.
func (sm *SessionManager) SetConn(id string, conn *websocket.Conn) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	sess, found := sm.sessions[id]
	if !found {
		return
	}
	sess.conn = conn
	sess.USBBridge = true
}
// List returns all non-expired sessions recorded in the database.
//
// The in-memory map is not consulted, and the returned values are copies that
// carry no WebSocket connection. Expiry is judged against a single timestamp
// taken once after the database query, so every row is compared with the same
// cutoff. The result is an empty, non-nil slice when nothing qualifies.
func (sm *SessionManager) List() ([]Session, error) {
	dbSessions, err := sm.db.ListFloatSessions()
	if err != nil {
		return nil, fmt.Errorf("list sessions: %w", err)
	}

	// One clock read for the whole pass: a consistent cutoff, and no per-row
	// time.Now() call.
	now := time.Now()

	sessions := make([]Session, 0, len(dbSessions))
	for _, dbs := range dbSessions {
		if now.After(dbs.ExpiresAt) {
			continue
		}
		sessions = append(sessions, Session{
			ID:          dbs.ID,
			UserID:      dbs.UserID,
			ClientIP:    dbs.ClientIP,
			ClientAgent: dbs.ClientAgent,
			USBBridge:   dbs.USBBridge,
			ConnectedAt: dbs.ConnectedAt,
			ExpiresAt:   dbs.ExpiresAt,
			LastPing:    dbs.LastPing,
		})
	}
	return sessions, nil
}