vibetunnel/linux/pkg/server/services/buffer_aggregator.go
2025-06-21 02:49:38 +02:00

406 lines
No EOL
11 KiB
Go

package services
import (
"encoding/json"
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// BufferAggregatorConfig holds configuration for BufferAggregator.
type BufferAggregatorConfig struct {
	// TerminalManager serves local sessions: it provides buffer-change
	// subscriptions and buffer snapshots.
	TerminalManager *TerminalManager
	// RemoteRegistry resolves remote servers by ID and by session ID. The
	// subscribe/unsubscribe paths consult it only when IsHQMode is true.
	RemoteRegistry *RemoteRegistry
	// IsHQMode enables proxying subscriptions for sessions that the
	// RemoteRegistry attributes to a remote server.
	IsHQMode bool
}
// BufferAggregator manages WebSocket connections and buffer distribution.
// Browser clients connect via HandleClientConnection and subscribe to
// sessions; buffer updates for local sessions come from the TerminalManager,
// and in HQ mode updates for remote sessions are relayed from per-remote
// WebSocket connections.
type BufferAggregator struct {
	config *BufferAggregatorConfig
	// clientSubscriptions maps each connected client to its subscriptions:
	// sessionID -> function that cancels that subscription.
	clientSubscriptions map[*websocket.Conn]map[string]func() // conn -> sessionID -> unsubscribe func
	// remoteConnections holds at most one outgoing connection per remote ID.
	remoteConnections map[string]*RemoteWebSocketConnection
	// mu is taken around accesses to clientSubscriptions and remoteConnections.
	mu sync.RWMutex
	// upgrader turns incoming HTTP requests into client WebSocket connections.
	upgrader websocket.Upgrader
}
// RemoteWebSocketConnection represents a connection to a remote server.
type RemoteWebSocketConnection struct {
	// WS is the dialed WebSocket to the remote server.
	WS *websocket.Conn
	// RemoteID and RemoteName identify the remote (the name is used in logs).
	RemoteID   string
	RemoteName string
	// Subscriptions is the set of session IDs this aggregator has asked the
	// remote to stream.
	Subscriptions map[string]bool
}
// NewBufferAggregator builds a BufferAggregator around config with empty
// client and remote connection tables.
//
// NOTE(review): the upgrader accepts WebSocket upgrades from any Origin, so
// any web page can open a connection if it can reach this server — confirm
// this is intended (e.g. auth is enforced elsewhere).
func NewBufferAggregator(config *BufferAggregatorConfig) *BufferAggregator {
	allowAnyOrigin := func(*http.Request) bool { return true }

	ba := &BufferAggregator{config: config}
	ba.clientSubscriptions = make(map[*websocket.Conn]map[string]func())
	ba.remoteConnections = make(map[string]*RemoteWebSocketConnection)
	ba.upgrader = websocket.Upgrader{CheckOrigin: allowAnyOrigin}
	return ba
}
// HandleClientConnection upgrades an HTTP request to a WebSocket, registers
// the client with an empty subscription set, greets it with a "connected"
// message, and starts a goroutine that serves the client's messages until it
// disconnects.
//
// If the greeting cannot be written the connection is treated as dead: the
// client is unregistered and the socket closed instead of starting a reader.
func (ba *BufferAggregator) HandleClientConnection(w http.ResponseWriter, r *http.Request) {
	conn, err := ba.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("[BufferAggregator] Failed to upgrade connection: %v", err)
		return
	}
	log.Printf("[BufferAggregator] New client connected")

	// Register before greeting so subscribe requests always find the client.
	ba.mu.Lock()
	ba.clientSubscriptions[conn] = make(map[string]func())
	ba.mu.Unlock()

	if err := conn.WriteJSON(map[string]interface{}{
		"type":    "connected",
		"version": "1.0",
	}); err != nil {
		log.Printf("[BufferAggregator] Failed to send welcome message: %v", err)
		ba.handleClientDisconnect(conn)
		conn.Close()
		return
	}

	go ba.handleClientMessages(conn)
}
// handleClientMessages reads JSON control messages from a client until the
// connection fails, then unregisters the client (cancelling its
// subscriptions) and closes the connection.
//
// Recognized messages (by their "type" field):
//   - "subscribe" / "unsubscribe" with a non-empty "sessionId"
//   - "ping", answered with {"type":"pong","timestamp":<unix millis>}
//
// A payload that is not valid JSON is logged and skipped rather than
// dropping the client. Unknown types are ignored.
func (ba *BufferAggregator) handleClientMessages(conn *websocket.Conn) {
	defer func() {
		ba.handleClientDisconnect(conn)
		conn.Close()
	}()

	for {
		_, payload, err := conn.ReadMessage()
		if err != nil {
			// Normal closes and going-away/abnormal closures are routine.
			if websocket.IsUnexpectedCloseError(err,
				websocket.CloseNormalClosure,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure) {
				log.Printf("[BufferAggregator] WebSocket error: %v", err)
			}
			return
		}

		var msg map[string]interface{}
		if err := json.Unmarshal(payload, &msg); err != nil {
			// The connection itself is still healthy; only this message is bad.
			log.Printf("[BufferAggregator] Ignoring malformed client message: %v", err)
			continue
		}

		msgType, _ := msg["type"].(string)
		sessionID, _ := msg["sessionId"].(string)

		switch msgType {
		case "subscribe":
			if sessionID != "" {
				ba.handleSubscribe(conn, sessionID)
			}
		case "unsubscribe":
			if sessionID != "" {
				ba.handleUnsubscribe(conn, sessionID)
			}
		case "ping":
			if err := conn.WriteJSON(map[string]interface{}{
				"type":      "pong",
				"timestamp": time.Now().UnixMilli(),
			}); err != nil {
				log.Printf("[BufferAggregator] Failed to send pong: %v", err)
				return
			}
		}
	}
}
// handleSubscribe subscribes conn to sessionID and acknowledges with a
// "subscribed" message. Any subscription the client already holds for that
// session is cancelled first. Clients that are no longer registered (already
// disconnected) are ignored.
//
// In HQ mode, sessions the RemoteRegistry attributes to a remote server are
// proxied through that remote; all other sessions are served by the local
// TerminalManager.
func (ba *BufferAggregator) handleSubscribe(conn *websocket.Conn, sessionID string) {
	// The per-client map is also read by forwardBufferToClients on other
	// goroutines, so detach the previous subscription while holding the lock.
	ba.mu.Lock()
	subscriptions := ba.clientSubscriptions[conn]
	var previous func()
	if subscriptions != nil {
		previous = subscriptions[sessionID]
		delete(subscriptions, sessionID)
	}
	ba.mu.Unlock()
	if subscriptions == nil {
		return
	}

	// Cancel outside the lock: the function was produced by the
	// TerminalManager (or is a no-op for remote sessions).
	if previous != nil {
		previous()
	}

	var remote *RemoteServer
	if ba.config.IsHQMode && ba.config.RemoteRegistry != nil {
		remote = ba.config.RemoteRegistry.GetRemoteBySessionID(sessionID)
	}
	if remote != nil {
		ba.subscribeToRemoteSession(conn, sessionID, remote.ID)
	} else {
		ba.subscribeToLocalSession(conn, sessionID)
	}

	if err := conn.WriteJSON(map[string]interface{}{
		"type":      "subscribed",
		"sessionId": sessionID,
	}); err != nil {
		log.Printf("[BufferAggregator] Failed to acknowledge subscription to session %s: %v", sessionID, err)
	}
	log.Printf("[BufferAggregator] Client subscribed to session %s", sessionID)
}
// subscribeToLocalSession subscribes conn to buffer changes of a local
// session via the TerminalManager, records the returned unsubscribe function
// for the client, and then sends the current buffer snapshot if one is
// available.
//
// If the client disconnected between the request and registration, the
// fresh TerminalManager subscription is cancelled immediately. Otherwise
// nothing would ever cancel it, and its callback would keep writing to a
// dead connection.
func (ba *BufferAggregator) subscribeToLocalSession(conn *websocket.Conn, sessionID string) {
	unsubscribe := ba.config.TerminalManager.SubscribeToBufferChanges(sessionID, func(data []byte) {
		ba.sendBufferToClient(conn, sessionID, data)
	})

	ba.mu.Lock()
	subscriptions, registered := ba.clientSubscriptions[conn]
	if registered {
		subscriptions[sessionID] = unsubscribe
	}
	ba.mu.Unlock()

	if !registered {
		if unsubscribe != nil {
			unsubscribe()
		}
		return
	}

	if buffer, err := ba.config.TerminalManager.GetBufferSnapshot(sessionID); err == nil {
		ba.sendBufferToClient(conn, sessionID, buffer)
	}
}
// subscribeToRemoteSession proxies a client's subscription to a session owned
// by the remote server remoteID. It ensures a connection to the remote exists
// and sends it a "subscribe" request. It then records the session in the
// remote's Subscriptions set and in the client's subscription map.
//
// Buffer frames the remote sends back are delivered by forwardBufferToClients.
// The function stored for the client is a no-op. handleUnsubscribe
// notifies the remote explicitly.
//
// On failure the client receives {"type":"error", ...} and nothing is
// registered.
func (ba *BufferAggregator) subscribeToRemoteSession(conn *websocket.Conn, sessionID, remoteID string) {
	remoteConn := ba.ensureRemoteConnection(remoteID)
	if remoteConn == nil {
		_ = conn.WriteJSON(map[string]interface{}{
			"type":    "error",
			"message": "Failed to connect to remote server",
		}) // best-effort notification
		return
	}

	if err := remoteConn.WS.WriteJSON(map[string]interface{}{
		"type":      "subscribe",
		"sessionId": sessionID,
	}); err != nil {
		log.Printf("[BufferAggregator] Failed to subscribe to session %s on remote %s: %v",
			sessionID, remoteConn.RemoteName, err)
		_ = conn.WriteJSON(map[string]interface{}{
			"type":    "error",
			"message": "Failed to subscribe to remote session",
		}) // best-effort notification
		return
	}

	// Subscriptions is shared with other goroutines; mutate it under ba.mu.
	ba.mu.Lock()
	remoteConn.Subscriptions[sessionID] = true
	if subscriptions, ok := ba.clientSubscriptions[conn]; ok {
		subscriptions[sessionID] = func() {
			// Remote teardown is handled in handleUnsubscribe.
		}
	}
	ba.mu.Unlock()
}
// ensureRemoteConnection returns the open WebSocket connection to the remote
// server remoteID, dialing one if none is registered. It returns nil if the
// remote is unknown to the RemoteRegistry or the dial fails.
//
// The remote's URL has its "http" prefix rewritten to "ws" ("https" becomes
// "wss"), and the remote's token is sent as a Bearer Authorization header.
// Concurrent callers may both dial. The first to register wins, and a losing
// caller closes its own socket and returns the winner. This keeps exactly one
// reader goroutine per registered connection.
func (ba *BufferAggregator) ensureRemoteConnection(remoteID string) *RemoteWebSocketConnection {
	ba.mu.RLock()
	existing := ba.remoteConnections[remoteID]
	ba.mu.RUnlock()
	if existing != nil && existing.WS != nil {
		return existing
	}

	remote := ba.config.RemoteRegistry.GetRemote(remoteID)
	if remote == nil {
		return nil
	}

	wsURL := remote.URL
	if len(wsURL) > 4 && wsURL[:4] == "http" {
		wsURL = "ws" + wsURL[4:]
	}

	header := http.Header{}
	header.Set("Authorization", "Bearer "+remote.Token)
	dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
	ws, resp, err := dialer.Dial(wsURL, header)
	if err != nil {
		if resp != nil {
			// Handshake reached the server; the status helps diagnose auth/routing.
			log.Printf("[BufferAggregator] Failed to connect to remote %s: %v (HTTP %s)", remote.Name, err, resp.Status)
		} else {
			log.Printf("[BufferAggregator] Failed to connect to remote %s: %v", remote.Name, err)
		}
		return nil
	}

	remoteConn := &RemoteWebSocketConnection{
		WS:            ws,
		RemoteID:      remote.ID,
		RemoteName:    remote.Name,
		Subscriptions: make(map[string]bool),
	}

	ba.mu.Lock()
	if current := ba.remoteConnections[remoteID]; current != nil && current.WS != nil {
		// Another caller connected while we were dialing; keep theirs.
		ba.mu.Unlock()
		ws.Close()
		return current
	}
	ba.remoteConnections[remoteID] = remoteConn
	ba.mu.Unlock()

	go ba.handleRemoteMessages(remoteConn)
	log.Printf("[BufferAggregator] Connected to remote %s", remote.Name)
	return remoteConn
}
// handleRemoteMessages reads from a remote server's WebSocket until a read
// fails.
//
// Binary frames whose first byte is the 0xbf magic are buffer updates and are
// relayed to subscribed clients via forwardBufferToClients. Text frames are
// decoded as JSON, and only their "type" is logged.
//
// On exit the socket is closed. The connection is removed from
// remoteConnections only if it is still the registered entry for its remote,
// so a newer connection to the same remote is not evicted.
func (ba *BufferAggregator) handleRemoteMessages(remoteConn *RemoteWebSocketConnection) {
	defer func() {
		ba.mu.Lock()
		if ba.remoteConnections[remoteConn.RemoteID] == remoteConn {
			delete(ba.remoteConnections, remoteConn.RemoteID)
		}
		ba.mu.Unlock()
		remoteConn.WS.Close()
	}()

	for {
		messageType, data, err := remoteConn.WS.ReadMessage()
		if err != nil {
			log.Printf("[BufferAggregator] Remote %s disconnected: %v", remoteConn.RemoteName, err)
			return
		}

		switch messageType {
		case websocket.BinaryMessage:
			if len(data) > 0 && data[0] == 0xbf {
				ba.forwardBufferToClients(data)
			}
		case websocket.TextMessage:
			var msg map[string]interface{}
			if err := json.Unmarshal(data, &msg); err == nil {
				log.Printf("[BufferAggregator] Remote %s message: %v", remoteConn.RemoteName, msg["type"])
			}
		}
	}
}
// sendBufferToClient writes buffer to conn as a single binary frame with this
// layout (the same one forwardBufferToClients parses from remotes):
//
//	byte 0       magic 0xbf
//	bytes 1..4   len(sessionID) as a little-endian uint32
//	next         the session ID bytes
//	rest         buffer, verbatim
//
// Delivery is best-effort: the write error is discarded.
//
// NOTE(review): this runs from TerminalManager callbacks, possibly
// concurrently with other writes to the same conn. gorilla/websocket permits
// only one concurrent writer per connection, so confirm writes are serialized.
func (ba *BufferAggregator) sendBufferToClient(conn *websocket.Conn, sessionID string, buffer []byte) {
	idLen := len(sessionID)

	frame := make([]byte, 0, 1+4+idLen+len(buffer))
	frame = append(frame,
		0xbf,
		byte(idLen), byte(idLen>>8), byte(idLen>>16), byte(idLen>>24),
	)
	frame = append(frame, sessionID...)
	frame = append(frame, buffer...)

	_ = conn.WriteMessage(websocket.BinaryMessage, frame) // best-effort delivery
}
// forwardBufferToClients relays a binary buffer frame received from a remote
// server, unchanged, to every client subscribed to the frame's session.
//
// The frame layout is: magic byte, little-endian uint32 session-ID length,
// session ID, payload. Frames too short to hold the declared session ID are
// dropped.
//
// The length is decoded as an unsigned value, so a hostile or corrupt prefix
// cannot go negative on 32-bit platforms and panic the slice expression.
// Targets are gathered under the read lock, but network writes happen after
// it is released, so a slow client cannot stall subscribe/unsubscribe.
func (ba *BufferAggregator) forwardBufferToClients(data []byte) {
	const headerLen = 1 + 4 // magic byte + uint32 length
	if len(data) < headerLen {
		return
	}
	idLen := uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3])<<16 | uint64(data[4])<<24
	if uint64(len(data)-headerLen) < idLen {
		return
	}
	sessionID := string(data[headerLen : headerLen+int(idLen)])

	ba.mu.RLock()
	var targets []*websocket.Conn
	for conn, subscriptions := range ba.clientSubscriptions {
		if _, subscribed := subscriptions[sessionID]; subscribed {
			targets = append(targets, conn)
		}
	}
	ba.mu.RUnlock()

	for _, conn := range targets {
		// Best-effort: a dead client is cleaned up by its own read loop.
		_ = conn.WriteMessage(websocket.BinaryMessage, data)
	}
}
// handleUnsubscribe removes conn's subscription to sessionID and runs the
// stored unsubscribe function, if any. Clients that are no longer registered
// are ignored.
//
// In HQ mode, if the session belongs to a remote server, the remote is told
// to stop streaming it. This happens only when no other client is still
// subscribed to the session, because forwardBufferToClients serves every
// subscriber from the same remote stream. The removal is logged even if
// conn was not subscribed to sessionID.
func (ba *BufferAggregator) handleUnsubscribe(conn *websocket.Conn, sessionID string) {
	// Detach under the lock: other goroutines read the per-client map.
	ba.mu.Lock()
	subscriptions := ba.clientSubscriptions[conn]
	var unsubscribe func()
	if subscriptions != nil {
		unsubscribe = subscriptions[sessionID]
		delete(subscriptions, sessionID)
	}
	ba.mu.Unlock()
	if subscriptions == nil {
		return
	}
	if unsubscribe != nil {
		unsubscribe()
	}

	if ba.config.IsHQMode && ba.config.RemoteRegistry != nil {
		if remote := ba.config.RemoteRegistry.GetRemoteBySessionID(sessionID); remote != nil {
			ba.mu.Lock()
			remoteConn := ba.remoteConnections[remote.ID]
			stillWatched := false
			for _, subs := range ba.clientSubscriptions {
				if _, ok := subs[sessionID]; ok {
					stillWatched = true
					break
				}
			}
			release := remoteConn != nil && !stillWatched
			if release {
				delete(remoteConn.Subscriptions, sessionID)
			}
			ba.mu.Unlock()

			if release {
				if err := remoteConn.WS.WriteJSON(map[string]interface{}{
					"type":      "unsubscribe",
					"sessionId": sessionID,
				}); err != nil {
					log.Printf("[BufferAggregator] Failed to unsubscribe from session %s on remote %s: %v",
						sessionID, remoteConn.RemoteName, err)
				}
			}
		}
	}

	log.Printf("[BufferAggregator] Client unsubscribed from session %s", sessionID)
}
// handleClientDisconnect unregisters conn and runs every unsubscribe
// function recorded for it. The functions are called after the lock is
// released.
//
// For remote sessions the recorded function is a no-op, so this alone does
// not cancel the subscription on the remote server.
func (ba *BufferAggregator) handleClientDisconnect(conn *websocket.Conn) {
	ba.mu.Lock()
	subs := ba.clientSubscriptions[conn]
	delete(ba.clientSubscriptions, conn)
	ba.mu.Unlock()

	for _, stop := range subs {
		if stop == nil {
			continue
		}
		stop()
	}

	log.Printf("[BufferAggregator] Client disconnected")
}
// Stop shuts the aggregator down.
//
// Under the lock it snapshots all clients, their unsubscribe functions and
// all remote connections, then resets both tables. Outside the lock it
// cancels every client subscription, closes every client connection, and
// closes every remote connection.
//
// The subscriptions must be cancelled here because the tables are reset
// first. The clients' read loops therefore find nothing to clean up in
// handleClientDisconnect. Without this, TerminalManager subscriptions would
// leak.
func (ba *BufferAggregator) Stop() {
	ba.mu.Lock()
	clients := make([]*websocket.Conn, 0, len(ba.clientSubscriptions))
	var unsubscribes []func()
	for conn, subscriptions := range ba.clientSubscriptions {
		clients = append(clients, conn)
		for _, unsubscribe := range subscriptions {
			if unsubscribe != nil {
				unsubscribes = append(unsubscribes, unsubscribe)
			}
		}
	}
	remotes := make([]*RemoteWebSocketConnection, 0, len(ba.remoteConnections))
	for _, remoteConn := range ba.remoteConnections {
		remotes = append(remotes, remoteConn)
	}
	ba.clientSubscriptions = make(map[*websocket.Conn]map[string]func())
	ba.remoteConnections = make(map[string]*RemoteWebSocketConnection)
	ba.mu.Unlock()

	// Stop buffer callbacks before closing the sockets they write to.
	for _, unsubscribe := range unsubscribes {
		unsubscribe()
	}
	for _, conn := range clients {
		conn.Close()
	}
	for _, remoteConn := range remotes {
		if remoteConn.WS != nil {
			remoteConn.WS.Close()
		}
	}
}