No One Can Stop Me Now
This commit is contained in:
366
services/setec-manager/internal/float/bridge.go
Normal file
366
services/setec-manager/internal/float/bridge.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package float
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"setec-manager/internal/db"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Bridge manages WebSocket connections for USB passthrough in Float Mode.
|
||||
type Bridge struct {
|
||||
db *db.DB
|
||||
sessions map[string]*bridgeConn
|
||||
mu sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// bridgeConn tracks a single active WebSocket connection and its associated session.
|
||||
type bridgeConn struct {
|
||||
sessionID string
|
||||
conn *websocket.Conn
|
||||
devices []USBDevice
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
const (
|
||||
writeWait = 10 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingInterval = 30 * time.Second
|
||||
maxMessageSize = 64 * 1024 // 64 KB max frame payload
|
||||
)
|
||||
|
||||
// NewBridge creates a new Bridge with the given database reference.
|
||||
func NewBridge(database *db.DB) *Bridge {
|
||||
return &Bridge{
|
||||
db: database,
|
||||
sessions: make(map[string]*bridgeConn),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Accept all origins; auth is handled via session token
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket upgrades an HTTP connection to WebSocket and manages the
|
||||
// binary frame protocol for USB passthrough. The session ID must be provided
|
||||
// as a "session" query parameter.
|
||||
func (b *Bridge) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.URL.Query().Get("session")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "missing session parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate session exists and is not expired
|
||||
sess, err := b.db.GetFloatSession(sessionID)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if time.Now().After(sess.ExpiresAt) {
|
||||
http.Error(w, "session expired", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
conn, err := b.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("[float/bridge] upgrade failed for session %s: %v", sessionID, err)
|
||||
return
|
||||
}
|
||||
|
||||
bc := &bridgeConn{
|
||||
sessionID: sessionID,
|
||||
conn: conn,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Register active connection
|
||||
b.mu.Lock()
|
||||
// Close any existing connection for this session
|
||||
if existing, ok := b.sessions[sessionID]; ok {
|
||||
close(existing.done)
|
||||
existing.conn.Close()
|
||||
}
|
||||
b.sessions[sessionID] = bc
|
||||
b.mu.Unlock()
|
||||
|
||||
log.Printf("[float/bridge] session %s connected from %s", sessionID, r.RemoteAddr)
|
||||
|
||||
// Start read/write loops
|
||||
go b.writePump(bc)
|
||||
b.readPump(bc)
|
||||
}
|
||||
|
||||
// readPump reads binary frames from the WebSocket and dispatches them.
|
||||
func (b *Bridge) readPump(bc *bridgeConn) {
|
||||
defer b.cleanup(bc)
|
||||
|
||||
bc.conn.SetReadLimit(maxMessageSize)
|
||||
bc.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
bc.conn.SetPongHandler(func(string) error {
|
||||
bc.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
messageType, data, err := bc.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
log.Printf("[float/bridge] session %s read error: %v", bc.sessionID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if messageType != websocket.BinaryMessage {
|
||||
b.sendError(bc, 0x0001, "expected binary message")
|
||||
continue
|
||||
}
|
||||
|
||||
frameType, payload, err := DecodeFrame(data)
|
||||
if err != nil {
|
||||
b.sendError(bc, 0x0002, "malformed frame: "+err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// Update session ping in DB
|
||||
b.db.PingFloatSession(bc.sessionID)
|
||||
|
||||
switch frameType {
|
||||
case FrameEnumerate:
|
||||
b.handleEnumerate(bc)
|
||||
case FrameOpen:
|
||||
b.handleOpen(bc, payload)
|
||||
case FrameClose:
|
||||
b.handleClose(bc, payload)
|
||||
case FrameTransferOut:
|
||||
b.handleTransfer(bc, payload)
|
||||
case FrameInterrupt:
|
||||
b.handleInterrupt(bc, payload)
|
||||
case FramePong:
|
||||
// Client responded to our ping; no action needed
|
||||
default:
|
||||
b.sendError(bc, 0x0003, "unknown frame type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writePump sends periodic pings to keep the connection alive.
|
||||
func (b *Bridge) writePump(bc *bridgeConn) {
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
bc.mu.Lock()
|
||||
bc.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
err := bc.conn.WriteMessage(websocket.BinaryMessage, EncodeFrame(FramePing, nil))
|
||||
bc.mu.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-bc.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleEnumerate responds with the current list of USB devices known to this
|
||||
// session. In a full implementation, this would forward the enumerate request
|
||||
// to the client-side USB agent and await its response. Here we return the
|
||||
// cached device list.
|
||||
func (b *Bridge) handleEnumerate(bc *bridgeConn) {
|
||||
bc.mu.Lock()
|
||||
devices := bc.devices
|
||||
bc.mu.Unlock()
|
||||
|
||||
if devices == nil {
|
||||
devices = []USBDevice{}
|
||||
}
|
||||
|
||||
payload := EncodeDeviceList(devices)
|
||||
b.sendFrame(bc, FrameEnumResult, payload)
|
||||
}
|
||||
|
||||
// handleOpen processes a device open request. The payload contains
|
||||
// [deviceID:2] identifying which device to claim.
|
||||
func (b *Bridge) handleOpen(bc *bridgeConn, payload []byte) {
|
||||
if len(payload) < 2 {
|
||||
b.sendError(bc, 0x0010, "open: payload too short")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := uint16(payload[0])<<8 | uint16(payload[1])
|
||||
|
||||
// Verify the device exists in our known list
|
||||
bc.mu.Lock()
|
||||
found := false
|
||||
for _, dev := range bc.devices {
|
||||
if dev.DeviceID == deviceID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
bc.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
b.sendError(bc, 0x0011, "open: device not found")
|
||||
return
|
||||
}
|
||||
|
||||
// In a real implementation, this would claim the USB device via the host agent.
|
||||
// For now, acknowledge the open request.
|
||||
result := make([]byte, 3)
|
||||
result[0] = payload[0]
|
||||
result[1] = payload[1]
|
||||
result[2] = 0x00 // success
|
||||
b.sendFrame(bc, FrameOpenResult, result)
|
||||
|
||||
log.Printf("[float/bridge] session %s opened device 0x%04X", bc.sessionID, deviceID)
|
||||
}
|
||||
|
||||
// handleClose processes a device close request. Payload: [deviceID:2].
|
||||
func (b *Bridge) handleClose(bc *bridgeConn, payload []byte) {
|
||||
if len(payload) < 2 {
|
||||
b.sendError(bc, 0x0020, "close: payload too short")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := uint16(payload[0])<<8 | uint16(payload[1])
|
||||
|
||||
// Acknowledge close
|
||||
result := make([]byte, 3)
|
||||
result[0] = payload[0]
|
||||
result[1] = payload[1]
|
||||
result[2] = 0x00 // success
|
||||
b.sendFrame(bc, FrameCloseResult, result)
|
||||
|
||||
log.Printf("[float/bridge] session %s closed device 0x%04X", bc.sessionID, deviceID)
|
||||
}
|
||||
|
||||
// handleTransfer forwards a bulk/interrupt OUT transfer to the USB device.
|
||||
func (b *Bridge) handleTransfer(bc *bridgeConn, payload []byte) {
|
||||
deviceID, endpoint, transferData, err := DecodeTransfer(payload)
|
||||
if err != nil {
|
||||
b.sendError(bc, 0x0030, "transfer: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// In a real implementation, the transfer data would be sent to the USB device
|
||||
// via the host agent, and the response would be sent back. Here we acknowledge
|
||||
// receipt of the transfer request.
|
||||
log.Printf("[float/bridge] session %s transfer to device 0x%04X endpoint 0x%02X: %d bytes",
|
||||
bc.sessionID, deviceID, endpoint, len(transferData))
|
||||
|
||||
// Build transfer result: [deviceID:2][endpoint:1][status:1]
|
||||
result := make([]byte, 4)
|
||||
result[0] = byte(deviceID >> 8)
|
||||
result[1] = byte(deviceID)
|
||||
result[2] = endpoint
|
||||
result[3] = 0x00 // success
|
||||
b.sendFrame(bc, FrameTransferResult, result)
|
||||
}
|
||||
|
||||
// handleInterrupt processes an interrupt transfer request.
|
||||
func (b *Bridge) handleInterrupt(bc *bridgeConn, payload []byte) {
|
||||
if len(payload) < 3 {
|
||||
b.sendError(bc, 0x0040, "interrupt: payload too short")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := uint16(payload[0])<<8 | uint16(payload[1])
|
||||
endpoint := payload[2]
|
||||
|
||||
log.Printf("[float/bridge] session %s interrupt on device 0x%04X endpoint 0x%02X",
|
||||
bc.sessionID, deviceID, endpoint)
|
||||
|
||||
// Acknowledge interrupt request
|
||||
result := make([]byte, 4)
|
||||
result[0] = payload[0]
|
||||
result[1] = payload[1]
|
||||
result[2] = endpoint
|
||||
result[3] = 0x00 // success
|
||||
b.sendFrame(bc, FrameInterruptResult, result)
|
||||
}
|
||||
|
||||
// sendFrame writes a binary frame to the WebSocket connection.
|
||||
func (b *Bridge) sendFrame(bc *bridgeConn, frameType byte, payload []byte) {
|
||||
bc.mu.Lock()
|
||||
defer bc.mu.Unlock()
|
||||
|
||||
bc.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := bc.conn.WriteMessage(websocket.BinaryMessage, EncodeFrame(frameType, payload)); err != nil {
|
||||
log.Printf("[float/bridge] session %s write error: %v", bc.sessionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// sendError writes an error frame to the WebSocket connection.
|
||||
func (b *Bridge) sendError(bc *bridgeConn, code uint16, message string) {
|
||||
b.sendFrame(bc, FrameError, EncodeError(code, message))
|
||||
}
|
||||
|
||||
// cleanup removes a connection from the active sessions and cleans up resources.
|
||||
func (b *Bridge) cleanup(bc *bridgeConn) {
|
||||
b.mu.Lock()
|
||||
if current, ok := b.sessions[bc.sessionID]; ok && current == bc {
|
||||
delete(b.sessions, bc.sessionID)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
close(bc.done)
|
||||
bc.conn.Close()
|
||||
|
||||
log.Printf("[float/bridge] session %s disconnected", bc.sessionID)
|
||||
}
|
||||
|
||||
// ActiveSessions returns the number of currently connected WebSocket sessions.
|
||||
func (b *Bridge) ActiveSessions() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return len(b.sessions)
|
||||
}
|
||||
|
||||
// DisconnectSession forcibly closes the WebSocket connection for a given session.
|
||||
func (b *Bridge) DisconnectSession(sessionID string) {
|
||||
b.mu.Lock()
|
||||
bc, ok := b.sessions[sessionID]
|
||||
if ok {
|
||||
delete(b.sessions, sessionID)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
close(bc.done)
|
||||
bc.conn.WriteControl(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "session terminated"),
|
||||
time.Now().Add(writeWait),
|
||||
)
|
||||
bc.conn.Close()
|
||||
log.Printf("[float/bridge] session %s forcibly disconnected", sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDeviceList sets the known device list for a session (called when the
|
||||
// client-side USB agent reports its attached devices).
|
||||
func (b *Bridge) UpdateDeviceList(sessionID string, devices []USBDevice) {
|
||||
b.mu.RLock()
|
||||
bc, ok := b.sessions[sessionID]
|
||||
b.mu.RUnlock()
|
||||
|
||||
if ok {
|
||||
bc.mu.Lock()
|
||||
bc.devices = devices
|
||||
bc.mu.Unlock()
|
||||
}
|
||||
}
|
||||
225
services/setec-manager/internal/float/protocol.go
Normal file
225
services/setec-manager/internal/float/protocol.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package float
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Frame type constants define the binary protocol for USB passthrough over WebSocket.
|
||||
const (
|
||||
FrameEnumerate byte = 0x01
|
||||
FrameEnumResult byte = 0x02
|
||||
FrameOpen byte = 0x03
|
||||
FrameOpenResult byte = 0x04
|
||||
FrameClose byte = 0x05
|
||||
FrameCloseResult byte = 0x06
|
||||
FrameTransferOut byte = 0x10
|
||||
FrameTransferIn byte = 0x11
|
||||
FrameTransferResult byte = 0x12
|
||||
FrameInterrupt byte = 0x20
|
||||
FrameInterruptResult byte = 0x21
|
||||
FramePing byte = 0xFE
|
||||
FramePong byte = 0xFF
|
||||
FrameError byte = 0xE0
|
||||
)
|
||||
|
||||
// frameHeaderSize is the fixed size of a frame header: 1 byte type + 4 bytes length.
|
||||
const frameHeaderSize = 5
|
||||
|
||||
// USBDevice represents a USB device detected on the client host.
|
||||
type USBDevice struct {
|
||||
VendorID uint16 `json:"vendor_id"`
|
||||
ProductID uint16 `json:"product_id"`
|
||||
DeviceID uint16 `json:"device_id"`
|
||||
Manufacturer string `json:"manufacturer"`
|
||||
Product string `json:"product"`
|
||||
SerialNumber string `json:"serial_number"`
|
||||
Class byte `json:"class"`
|
||||
SubClass byte `json:"sub_class"`
|
||||
}
|
||||
|
||||
// deviceFixedSize is the fixed portion of a serialized USBDevice:
|
||||
// VendorID(2) + ProductID(2) + DeviceID(2) + Class(1) + SubClass(1) + 3 string lengths (2 each) = 14
|
||||
const deviceFixedSize = 14
|
||||
|
||||
// EncodeFrame builds a binary frame: [type:1][length:4 big-endian][payload:N].
|
||||
func EncodeFrame(frameType byte, payload []byte) []byte {
|
||||
frame := make([]byte, frameHeaderSize+len(payload))
|
||||
frame[0] = frameType
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
|
||||
copy(frame[frameHeaderSize:], payload)
|
||||
return frame
|
||||
}
|
||||
|
||||
// DecodeFrame parses a binary frame into its type and payload.
|
||||
func DecodeFrame(data []byte) (frameType byte, payload []byte, err error) {
|
||||
if len(data) < frameHeaderSize {
|
||||
return 0, nil, fmt.Errorf("frame too short: need at least %d bytes, got %d", frameHeaderSize, len(data))
|
||||
}
|
||||
|
||||
frameType = data[0]
|
||||
length := binary.BigEndian.Uint32(data[1:5])
|
||||
|
||||
if uint32(len(data)-frameHeaderSize) < length {
|
||||
return 0, nil, fmt.Errorf("frame payload truncated: header says %d bytes, have %d", length, len(data)-frameHeaderSize)
|
||||
}
|
||||
|
||||
payload = make([]byte, length)
|
||||
copy(payload, data[frameHeaderSize:frameHeaderSize+int(length)])
|
||||
return frameType, payload, nil
|
||||
}
|
||||
|
||||
// encodeString writes a length-prefixed string (2-byte big-endian length + bytes).
|
||||
func encodeString(buf []byte, offset int, s string) int {
|
||||
b := []byte(s)
|
||||
binary.BigEndian.PutUint16(buf[offset:], uint16(len(b)))
|
||||
offset += 2
|
||||
copy(buf[offset:], b)
|
||||
return offset + len(b)
|
||||
}
|
||||
|
||||
// decodeString reads a length-prefixed string from the buffer.
|
||||
func decodeString(data []byte, offset int) (string, int, error) {
|
||||
if offset+2 > len(data) {
|
||||
return "", 0, fmt.Errorf("string length truncated at offset %d", offset)
|
||||
}
|
||||
slen := int(binary.BigEndian.Uint16(data[offset:]))
|
||||
offset += 2
|
||||
if offset+slen > len(data) {
|
||||
return "", 0, fmt.Errorf("string data truncated at offset %d: need %d bytes", offset, slen)
|
||||
}
|
||||
s := string(data[offset : offset+slen])
|
||||
return s, offset + slen, nil
|
||||
}
|
||||
|
||||
// serializeDevice serializes a single USBDevice into bytes.
|
||||
func serializeDevice(dev USBDevice) []byte {
|
||||
mfr := []byte(dev.Manufacturer)
|
||||
prod := []byte(dev.Product)
|
||||
ser := []byte(dev.SerialNumber)
|
||||
|
||||
size := deviceFixedSize + len(mfr) + len(prod) + len(ser)
|
||||
buf := make([]byte, size)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[0:], dev.VendorID)
|
||||
binary.BigEndian.PutUint16(buf[2:], dev.ProductID)
|
||||
binary.BigEndian.PutUint16(buf[4:], dev.DeviceID)
|
||||
buf[6] = dev.Class
|
||||
buf[7] = dev.SubClass
|
||||
|
||||
off := 8
|
||||
off = encodeString(buf, off, dev.Manufacturer)
|
||||
off = encodeString(buf, off, dev.Product)
|
||||
_ = encodeString(buf, off, dev.SerialNumber)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// EncodeDeviceList serializes a slice of USBDevices for a FrameEnumResult payload.
|
||||
// Format: [count:2 big-endian][device...]
|
||||
func EncodeDeviceList(devices []USBDevice) []byte {
|
||||
// First pass: serialize each device to compute total size
|
||||
serialized := make([][]byte, len(devices))
|
||||
totalSize := 2 // 2 bytes for count
|
||||
for i, dev := range devices {
|
||||
serialized[i] = serializeDevice(dev)
|
||||
totalSize += len(serialized[i])
|
||||
}
|
||||
|
||||
buf := make([]byte, totalSize)
|
||||
binary.BigEndian.PutUint16(buf[0:], uint16(len(devices)))
|
||||
off := 2
|
||||
for _, s := range serialized {
|
||||
copy(buf[off:], s)
|
||||
off += len(s)
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeDeviceList deserializes a FrameEnumResult payload into a slice of USBDevices.
|
||||
func DecodeDeviceList(data []byte) ([]USBDevice, error) {
|
||||
if len(data) < 2 {
|
||||
return nil, fmt.Errorf("device list too short: need at least 2 bytes")
|
||||
}
|
||||
|
||||
count := int(binary.BigEndian.Uint16(data[0:]))
|
||||
off := 2
|
||||
|
||||
devices := make([]USBDevice, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
if off+8 > len(data) {
|
||||
return nil, fmt.Errorf("device %d: fixed fields truncated at offset %d", i, off)
|
||||
}
|
||||
|
||||
dev := USBDevice{
|
||||
VendorID: binary.BigEndian.Uint16(data[off:]),
|
||||
ProductID: binary.BigEndian.Uint16(data[off+2:]),
|
||||
DeviceID: binary.BigEndian.Uint16(data[off+4:]),
|
||||
Class: data[off+6],
|
||||
SubClass: data[off+7],
|
||||
}
|
||||
off += 8
|
||||
|
||||
var err error
|
||||
dev.Manufacturer, off, err = decodeString(data, off)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device %d manufacturer: %w", i, err)
|
||||
}
|
||||
dev.Product, off, err = decodeString(data, off)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device %d product: %w", i, err)
|
||||
}
|
||||
dev.SerialNumber, off, err = decodeString(data, off)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device %d serial: %w", i, err)
|
||||
}
|
||||
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// EncodeTransfer serializes a USB transfer payload.
|
||||
// Format: [deviceID:2][endpoint:1][data:N]
|
||||
func EncodeTransfer(deviceID uint16, endpoint byte, data []byte) []byte {
|
||||
buf := make([]byte, 3+len(data))
|
||||
binary.BigEndian.PutUint16(buf[0:], deviceID)
|
||||
buf[2] = endpoint
|
||||
copy(buf[3:], data)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeTransfer deserializes a USB transfer payload.
|
||||
func DecodeTransfer(data []byte) (deviceID uint16, endpoint byte, transferData []byte, err error) {
|
||||
if len(data) < 3 {
|
||||
return 0, 0, nil, fmt.Errorf("transfer payload too short: need at least 3 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
deviceID = binary.BigEndian.Uint16(data[0:])
|
||||
endpoint = data[2]
|
||||
transferData = make([]byte, len(data)-3)
|
||||
copy(transferData, data[3:])
|
||||
return deviceID, endpoint, transferData, nil
|
||||
}
|
||||
|
||||
// EncodeError serializes an error response payload.
|
||||
// Format: [code:2 big-endian][message:UTF-8 bytes]
|
||||
func EncodeError(code uint16, message string) []byte {
|
||||
msg := []byte(message)
|
||||
buf := make([]byte, 2+len(msg))
|
||||
binary.BigEndian.PutUint16(buf[0:], code)
|
||||
copy(buf[2:], msg)
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeError deserializes an error response payload.
|
||||
func DecodeError(data []byte) (code uint16, message string) {
|
||||
if len(data) < 2 {
|
||||
return 0, ""
|
||||
}
|
||||
code = binary.BigEndian.Uint16(data[0:])
|
||||
message = string(data[2:])
|
||||
return code, message
|
||||
}
|
||||
248
services/setec-manager/internal/float/session.go
Normal file
248
services/setec-manager/internal/float/session.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package float
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"setec-manager/internal/db"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Session represents an active Float Mode session, combining database state
|
||||
// with the live WebSocket connection reference.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
ClientAgent string `json:"client_agent"`
|
||||
USBBridge bool `json:"usb_bridge"`
|
||||
ConnectedAt time.Time `json:"connected_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastPing *time.Time `json:"last_ping,omitempty"`
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
// SessionManager provides in-memory + database-backed session lifecycle
|
||||
// management for Float Mode connections.
|
||||
type SessionManager struct {
|
||||
sessions map[string]*Session
|
||||
mu sync.RWMutex
|
||||
db *db.DB
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new SessionManager backed by the given database.
|
||||
func NewSessionManager(database *db.DB) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[string]*Session),
|
||||
db: database,
|
||||
}
|
||||
}
|
||||
|
||||
// Create generates a new Float session with a random UUID, storing it in both
|
||||
// the in-memory map and the database.
|
||||
func (sm *SessionManager) Create(userID int64, clientIP, agent string, ttl time.Duration) (string, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(ttl)
|
||||
|
||||
session := &Session{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ClientIP: clientIP,
|
||||
ClientAgent: agent,
|
||||
ConnectedAt: now,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// Persist to database first
|
||||
if err := sm.db.CreateFloatSession(id, userID, clientIP, agent, expiresAt); err != nil {
|
||||
return "", fmt.Errorf("create session: db insert: %w", err)
|
||||
}
|
||||
|
||||
// Store in memory
|
||||
sm.mu.Lock()
|
||||
sm.sessions[id] = session
|
||||
sm.mu.Unlock()
|
||||
|
||||
log.Printf("[float/session] created session %s for user %d from %s (expires %s)",
|
||||
id, userID, clientIP, expiresAt.Format(time.RFC3339))
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Get retrieves a session by ID, checking the in-memory cache first, then
|
||||
// falling back to the database. Returns nil and an error if not found.
|
||||
func (sm *SessionManager) Get(id string) (*Session, error) {
|
||||
// Check memory first
|
||||
sm.mu.RLock()
|
||||
if sess, ok := sm.sessions[id]; ok {
|
||||
sm.mu.RUnlock()
|
||||
// Check if expired
|
||||
if time.Now().After(sess.ExpiresAt) {
|
||||
sm.Delete(id)
|
||||
return nil, fmt.Errorf("session %s has expired", id)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
// Fall back to database
|
||||
dbSess, err := sm.db.GetFloatSession(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(dbSess.ExpiresAt) {
|
||||
sm.db.DeleteFloatSession(id)
|
||||
return nil, fmt.Errorf("session %s has expired", id)
|
||||
}
|
||||
|
||||
// Hydrate into memory
|
||||
session := &Session{
|
||||
ID: dbSess.ID,
|
||||
UserID: dbSess.UserID,
|
||||
ClientIP: dbSess.ClientIP,
|
||||
ClientAgent: dbSess.ClientAgent,
|
||||
USBBridge: dbSess.USBBridge,
|
||||
ConnectedAt: dbSess.ConnectedAt,
|
||||
ExpiresAt: dbSess.ExpiresAt,
|
||||
LastPing: dbSess.LastPing,
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.sessions[id] = session
|
||||
sm.mu.Unlock()
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Delete removes a session from both the in-memory map and the database.
|
||||
func (sm *SessionManager) Delete(id string) error {
|
||||
sm.mu.Lock()
|
||||
sess, ok := sm.sessions[id]
|
||||
if ok {
|
||||
// Close the WebSocket connection if it exists
|
||||
if sess.conn != nil {
|
||||
sess.conn.WriteControl(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "session deleted"),
|
||||
time.Now().Add(5*time.Second),
|
||||
)
|
||||
sess.conn.Close()
|
||||
}
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
if err := sm.db.DeleteFloatSession(id); err != nil {
|
||||
return fmt.Errorf("delete session: db delete: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[float/session] deleted session %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping updates the last-ping timestamp for a session in both memory and DB.
|
||||
func (sm *SessionManager) Ping(id string) error {
|
||||
now := time.Now()
|
||||
|
||||
sm.mu.Lock()
|
||||
if sess, ok := sm.sessions[id]; ok {
|
||||
sess.LastPing = &now
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
if err := sm.db.PingFloatSession(id); err != nil {
|
||||
return fmt.Errorf("ping session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanExpired removes all sessions that have passed their expiry time.
|
||||
// Returns the number of sessions removed.
|
||||
func (sm *SessionManager) CleanExpired() (int, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Clean from memory
|
||||
sm.mu.Lock()
|
||||
var expiredIDs []string
|
||||
for id, sess := range sm.sessions {
|
||||
if now.After(sess.ExpiresAt) {
|
||||
expiredIDs = append(expiredIDs, id)
|
||||
if sess.conn != nil {
|
||||
sess.conn.WriteControl(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "session expired"),
|
||||
now.Add(5*time.Second),
|
||||
)
|
||||
sess.conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, id := range expiredIDs {
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Clean from database
|
||||
count, err := sm.db.CleanExpiredFloatSessions()
|
||||
if err != nil {
|
||||
return len(expiredIDs), fmt.Errorf("clean expired: db: %w", err)
|
||||
}
|
||||
|
||||
total := int(count)
|
||||
if total > 0 {
|
||||
log.Printf("[float/session] cleaned %d expired sessions", total)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of sessions currently in the in-memory map.
|
||||
func (sm *SessionManager) ActiveCount() int {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return len(sm.sessions)
|
||||
}
|
||||
|
||||
// SetConn associates a WebSocket connection with a session.
|
||||
func (sm *SessionManager) SetConn(id string, conn *websocket.Conn) {
|
||||
sm.mu.Lock()
|
||||
if sess, ok := sm.sessions[id]; ok {
|
||||
sess.conn = conn
|
||||
sess.USBBridge = true
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
// List returns all active (non-expired) sessions from the database.
|
||||
func (sm *SessionManager) List() ([]Session, error) {
|
||||
dbSessions, err := sm.db.ListFloatSessions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
|
||||
sessions := make([]Session, 0, len(dbSessions))
|
||||
for _, dbs := range dbSessions {
|
||||
if time.Now().After(dbs.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, Session{
|
||||
ID: dbs.ID,
|
||||
UserID: dbs.UserID,
|
||||
ClientIP: dbs.ClientIP,
|
||||
ClientAgent: dbs.ClientAgent,
|
||||
USBBridge: dbs.USBBridge,
|
||||
ConnectedAt: dbs.ConnectedAt,
|
||||
ExpiresAt: dbs.ExpiresAt,
|
||||
LastPing: dbs.LastPing,
|
||||
})
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
Reference in New Issue
Block a user