From f96b63c77dbcbb6ae53549bdd7a691ee3fe7e39f Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 20 Jun 2025 15:11:26 +0200 Subject: [PATCH] Port node implementation details to go --- linux/go.mod | 2 +- linux/pkg/protocol/asciinema.go | 72 +++++--- linux/pkg/protocol/escape_parser.go | 244 ++++++++++++++++++++++++++++ linux/pkg/session/errors.go | 165 +++++++++++++++++++ linux/pkg/session/process.go | 157 ++++++++++++++++++ linux/pkg/session/pty.go | 162 ++++++++++-------- linux/pkg/session/select.go | 12 +- linux/pkg/session/session.go | 63 ++++--- linux/pkg/session/stdin_watcher.go | 153 +++++++++++++++++ linux/pkg/session/terminal.go | 133 +++++++++++++++ 10 files changed, 1042 insertions(+), 121 deletions(-) create mode 100644 linux/pkg/protocol/escape_parser.go create mode 100644 linux/pkg/session/errors.go create mode 100644 linux/pkg/session/process.go create mode 100644 linux/pkg/session/stdin_watcher.go create mode 100644 linux/pkg/session/terminal.go diff --git a/linux/go.mod b/linux/go.mod index 5c2a9dc1..cecb1cac 100644 --- a/linux/go.mod +++ b/linux/go.mod @@ -15,6 +15,7 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 golang.ngrok.com/ngrok v1.13.0 + golang.org/x/sys v0.33.0 golang.org/x/term v0.32.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -51,7 +52,6 @@ require ( golang.org/x/mod v0.25.0 // indirect golang.org/x/net v0.41.0 // indirect golang.org/x/sync v0.15.0 // indirect - golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.26.0 // indirect golang.org/x/tools v0.34.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/linux/pkg/protocol/asciinema.go b/linux/pkg/protocol/asciinema.go index bed8d535..1b762237 100644 --- a/linux/pkg/protocol/asciinema.go +++ b/linux/pkg/protocol/asciinema.go @@ -42,25 +42,27 @@ type StreamEvent struct { } type StreamWriter struct { - writer io.Writer - header *AsciinemaHeader - startTime time.Time - mutex sync.Mutex - closed bool - buffer []byte - lastWrite time.Time - flushTimer *time.Timer - syncTimer *time.Timer - needsSync bool + writer io.Writer + header *AsciinemaHeader + startTime time.Time + mutex sync.Mutex + closed bool + buffer []byte + escapeParser *EscapeParser + lastWrite time.Time + flushTimer *time.Timer + syncTimer *time.Timer + needsSync bool } func NewStreamWriter(writer io.Writer, header *AsciinemaHeader) *StreamWriter { return &StreamWriter{ - writer: writer, - header: header, - startTime: time.Now(), - buffer: make([]byte, 0, 4096), - lastWrite: time.Now(), + writer: writer, + header: header, + startTime: time.Now(), + buffer: make([]byte, 0, 4096), + escapeParser: NewEscapeParser(), + lastWrite: time.Now(), } } @@ -106,22 +108,24 @@ func (w *StreamWriter) writeEvent(eventType EventType, data []byte) error { return fmt.Errorf("stream writer closed") } - w.buffer = append(w.buffer, data...) w.lastWrite = time.Now() - completeData, remaining := extractCompleteUTF8(w.buffer) + // Use escape parser to ensure escape sequences are not split + processedData, remaining := w.escapeParser.ProcessData(data) + + // Update buffer with any remaining incomplete sequences w.buffer = remaining - if len(completeData) == 0 { - // If we have incomplete UTF-8 data, set up a timer to flush it after a short delay - if len(w.buffer) > 0 { + if len(processedData) == 0 { + // If we have incomplete data, set up a timer to flush it after a short delay + if len(w.buffer) > 0 || w.escapeParser.BufferSize() > 0 { w.scheduleFlush() } return nil } elapsed := time.Since(w.startTime).Seconds() - event := []interface{}{elapsed, string(eventType), string(completeData)} + event := []interface{}{elapsed, string(eventType), string(processedData)} eventData, err := json.Marshal(event) if err != nil { @@ -151,13 +155,25 @@ func (w *StreamWriter) scheduleFlush() { w.mutex.Lock() defer w.mutex.Unlock() - if w.closed || len(w.buffer) == 0 { + if w.closed { return } - // Force flush incomplete UTF-8 data for real-time streaming + // Flush any buffered data from escape parser + flushedData := w.escapeParser.Flush() + if len(flushedData) == 0 && len(w.buffer) == 0 { + return + } + + // Combine flushed data with any remaining buffer + dataToWrite := append(flushedData, w.buffer...) + if len(dataToWrite) == 0 { + return + } + + // Force flush incomplete data for real-time streaming elapsed := time.Since(w.startTime).Seconds() - event := []interface{}{elapsed, string(EventOutput), string(w.buffer)} + event := []interface{}{elapsed, string(EventOutput), string(dataToWrite)} eventData, err := json.Marshal(event) if err != nil { @@ -218,9 +234,13 @@ func (w *StreamWriter) Close() error { w.syncTimer.Stop() } - if len(w.buffer) > 0 { + // Flush any remaining data from escape parser + flushedData := w.escapeParser.Flush() + finalData := append(flushedData, w.buffer...) + + if len(finalData) > 0 { elapsed := time.Since(w.startTime).Seconds() - event := []interface{}{elapsed, string(EventOutput), string(w.buffer)} + event := []interface{}{elapsed, string(EventOutput), string(finalData)} eventData, _ := json.Marshal(event) if _, err := fmt.Fprintf(w.writer, "%s\n", eventData); err != nil { // Write failed during close - log to stderr to avoid deadlock diff --git a/linux/pkg/protocol/escape_parser.go b/linux/pkg/protocol/escape_parser.go new file mode 100644 index 00000000..c36b0ac7 --- /dev/null +++ b/linux/pkg/protocol/escape_parser.go @@ -0,0 +1,244 @@ +package protocol + +import ( + "unicode/utf8" +) + +// EscapeParser handles parsing of terminal escape sequences and UTF-8 data +// This ensures escape sequences are not split across chunks +type EscapeParser struct { + buffer []byte +} + +// NewEscapeParser creates a new escape sequence parser +func NewEscapeParser() *EscapeParser { + return &EscapeParser{ + buffer: make([]byte, 0, 4096), + } +} + +// ProcessData processes terminal data ensuring escape sequences and UTF-8 are not split +// Returns processed data and any remaining incomplete sequences +func (p *EscapeParser) ProcessData(data []byte) (processed []byte, remaining []byte) { + // Combine buffered data with new data + combined := append(p.buffer, data...) + p.buffer = p.buffer[:0] // Clear buffer without reallocating + + result := make([]byte, 0, len(combined)) + pos := 0 + + for pos < len(combined) { + // Check for escape sequence + if combined[pos] == 0x1b { // ESC character + seqEnd := p.findEscapeSequenceEnd(combined[pos:]) + if seqEnd == -1 { + // Incomplete escape sequence, save for next time + p.buffer = append(p.buffer, combined[pos:]...) + break + } + // Include complete escape sequence + result = append(result, combined[pos:pos+seqEnd]...) + pos += seqEnd + continue + } + + // Process UTF-8 character + r, size := utf8.DecodeRune(combined[pos:]) + if r == utf8.RuneError { + if size == 0 { + // No more data + break + } + if size == 1 && pos+4 > len(combined) { + // Might be incomplete UTF-8 at end of buffer + if p.mightBeIncompleteUTF8(combined[pos:]) { + p.buffer = append(p.buffer, combined[pos:]...) + break + } + } + // Invalid UTF-8, skip byte + result = append(result, combined[pos]) + pos++ + continue + } + + // Valid UTF-8 character + result = append(result, combined[pos:pos+size]...) + pos += size + } + + return result, p.buffer +} + +// findEscapeSequenceEnd finds the end of an ANSI escape sequence +// Returns -1 if sequence is incomplete +func (p *EscapeParser) findEscapeSequenceEnd(data []byte) int { + if len(data) == 0 || data[0] != 0x1b { + return -1 + } + + if len(data) < 2 { + return -1 // Need more data + } + + switch data[1] { + case '[': // CSI sequence: ESC [ ... final_char + pos := 2 + for pos < len(data) { + b := data[pos] + if b >= 0x20 && b <= 0x3f { + // Parameter and intermediate characters + pos++ + } else if b >= 0x40 && b <= 0x7e { + // Final character found + return pos + 1 + } else { + // Invalid sequence + return pos + } + } + return -1 // Incomplete + + case ']': // OSC sequence: ESC ] ... (ST or BEL) + pos := 2 + for pos < len(data) { + if data[pos] == 0x07 { // BEL terminator + return pos + 1 + } + if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' { + // ESC \ (ST) terminator + return pos + 2 + } + pos++ + } + return -1 // Incomplete + + case '(', ')', '*', '+': // Charset selection + if len(data) < 3 { + return -1 + } + return 3 + + case 'P', 'X', '^', '_': // DCS, SOS, PM, APC sequences + // These need special termination sequences + pos := 2 + for pos < len(data) { + if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' { + // ESC \ (ST) terminator + return pos + 2 + } + pos++ + } + return -1 // Incomplete + + default: + // Simple two-character sequences + return 2 + } +} + +// mightBeIncompleteUTF8 checks if data might be an incomplete UTF-8 sequence +func (p *EscapeParser) mightBeIncompleteUTF8(data []byte) bool { + if len(data) == 0 { + return false + } + + b := data[0] + + // Single byte (ASCII) + if b < 0x80 { + return false + } + + // Multi-byte sequence starters + if b >= 0xc0 { + if b < 0xe0 { + // 2-byte sequence + return len(data) < 2 + } + if b < 0xf0 { + // 3-byte sequence + return len(data) < 3 + } + if b < 0xf8 { + // 4-byte sequence + return len(data) < 4 + } + } + + return false +} + +// Flush returns any buffered data (for use when closing) +func (p *EscapeParser) Flush() []byte { + if len(p.buffer) == 0 { + return nil + } + // Return buffered data as-is when flushing + result := make([]byte, len(p.buffer)) + copy(result, p.buffer) + p.buffer = p.buffer[:0] + return result +} + +// Reset clears the parser state +func (p *EscapeParser) Reset() { + p.buffer = p.buffer[:0] +} + +// BufferSize returns the current buffer size +func (p *EscapeParser) BufferSize() int { + return len(p.buffer) +} + +// SplitEscapeSequences splits data at escape sequence boundaries +// This is useful for processing data in chunks without splitting sequences +func SplitEscapeSequences(data []byte) [][]byte { + if len(data) == 0 { + return nil + } + + var chunks [][]byte + parser := NewEscapeParser() + + processed, remaining := parser.ProcessData(data) + if len(processed) > 0 { + chunks = append(chunks, processed) + } + if len(remaining) > 0 { + chunks = append(chunks, remaining) + } + + return chunks +} + +// IsCompleteEscapeSequence checks if data contains a complete escape sequence +func IsCompleteEscapeSequence(data []byte) bool { + if len(data) == 0 || data[0] != 0x1b { + return false + } + parser := NewEscapeParser() + end := parser.findEscapeSequenceEnd(data) + return end > 0 && end == len(data) +} + +// StripEscapeSequences removes all ANSI escape sequences from data +func StripEscapeSequences(data []byte) []byte { + result := make([]byte, 0, len(data)) + pos := 0 + + parser := NewEscapeParser() + for pos < len(data) { + if data[pos] == 0x1b { + seqEnd := parser.findEscapeSequenceEnd(data[pos:]) + if seqEnd > 0 { + pos += seqEnd + continue + } + } + result = append(result, data[pos]) + pos++ + } + + return result +} \ No newline at end of file diff --git a/linux/pkg/session/errors.go b/linux/pkg/session/errors.go new file mode 100644 index 00000000..85c9cf1a --- /dev/null +++ b/linux/pkg/session/errors.go @@ -0,0 +1,165 @@ +package session + +import ( + "fmt" +) + +// ErrorCode represents standardized error codes matching Node.js implementation +type ErrorCode string + +const ( + // Session-related errors + ErrSessionNotFound ErrorCode = "SESSION_NOT_FOUND" + ErrSessionAlreadyExists ErrorCode = "SESSION_ALREADY_EXISTS" + ErrSessionStartFailed ErrorCode = "SESSION_START_FAILED" + ErrSessionNotRunning ErrorCode = "SESSION_NOT_RUNNING" + + // Process-related errors + ErrProcessNotFound ErrorCode = "PROCESS_NOT_FOUND" + ErrProcessSignalFailed ErrorCode = "PROCESS_SIGNAL_FAILED" + ErrProcessTerminateFailed ErrorCode = "PROCESS_TERMINATE_FAILED" + + // I/O related errors + ErrStdinNotFound ErrorCode = "STDIN_NOT_FOUND" + ErrStdinWriteFailed ErrorCode = "STDIN_WRITE_FAILED" + ErrStreamReadFailed ErrorCode = "STREAM_READ_FAILED" + ErrStreamWriteFailed ErrorCode = "STREAM_WRITE_FAILED" + + // PTY-related errors + ErrPTYCreationFailed ErrorCode = "PTY_CREATION_FAILED" + ErrPTYConfigFailed ErrorCode = "PTY_CONFIG_FAILED" + ErrPTYResizeFailed ErrorCode = "PTY_RESIZE_FAILED" + + // Control-related errors + ErrControlPathNotFound ErrorCode = "CONTROL_PATH_NOT_FOUND" + ErrControlFileCorrupted ErrorCode = "CONTROL_FILE_CORRUPTED" + + // Input-related errors + ErrUnknownKey ErrorCode = "UNKNOWN_KEY" + ErrInvalidInput ErrorCode = "INVALID_INPUT" + + // General errors + ErrInvalidArgument ErrorCode = "INVALID_ARGUMENT" + ErrPermissionDenied ErrorCode = "PERMISSION_DENIED" + ErrTimeout ErrorCode = "TIMEOUT" + ErrInternal ErrorCode = "INTERNAL_ERROR" +) + +// SessionError represents an error with context, matching Node.js PtyError +type SessionError struct { + Message string + Code ErrorCode + SessionID string + Cause error +} + +// Error implements the error interface +func (e *SessionError) Error() string { + if e.SessionID != "" { + return fmt.Sprintf("%s (session: %s, code: %s)", e.Message, e.SessionID[:8], e.Code) + } + return fmt.Sprintf("%s (code: %s)", e.Message, e.Code) +} + +// Unwrap returns the underlying cause +func (e *SessionError) Unwrap() error { + return e.Cause +} + +// NewSessionError creates a new SessionError +func NewSessionError(message string, code ErrorCode, sessionID string) *SessionError { + return &SessionError{ + Message: message, + Code: code, + SessionID: sessionID, + } +} + +// NewSessionErrorWithCause creates a new SessionError with an underlying cause +func NewSessionErrorWithCause(message string, code ErrorCode, sessionID string, cause error) *SessionError { + return &SessionError{ + Message: message, + Code: code, + SessionID: sessionID, + Cause: cause, + } +} + +// WrapError wraps an existing error with session context +func WrapError(err error, code ErrorCode, sessionID string) *SessionError { + if err == nil { + return nil + } + + // If it's already a SessionError, preserve the original but add context + if se, ok := err.(*SessionError); ok { + return &SessionError{ + Message: se.Message, + Code: code, + SessionID: sessionID, + Cause: se, + } + } + + return &SessionError{ + Message: err.Error(), + Code: code, + SessionID: sessionID, + Cause: err, + } +} + +// IsSessionError checks if an error is a SessionError with a specific code +func IsSessionError(err error, code ErrorCode) bool { + se, ok := err.(*SessionError) + return ok && se.Code == code +} + +// GetSessionID extracts the session ID from an error if it's a SessionError +func GetSessionID(err error) string { + if se, ok := err.(*SessionError); ok { + return se.SessionID + } + return "" +} + +// Common error constructors for convenience + +// ErrSessionNotFoundError creates a session not found error +func ErrSessionNotFoundError(sessionID string) *SessionError { + return NewSessionError( + fmt.Sprintf("Session %s not found", sessionID[:8]), + ErrSessionNotFound, + sessionID, + ) +} + +// ErrProcessSignalError creates a process signal error +func ErrProcessSignalError(sessionID string, signal string, cause error) *SessionError { + return NewSessionErrorWithCause( + fmt.Sprintf("Failed to send signal %s to session", signal), + ErrProcessSignalFailed, + sessionID, + cause, + ) +} + +// ErrPTYCreationError creates a PTY creation error +func ErrPTYCreationError(sessionID string, cause error) *SessionError { + return NewSessionErrorWithCause( + "Failed to create PTY", + ErrPTYCreationFailed, + sessionID, + cause, + ) +} + +// ErrStdinWriteError creates a stdin write error +func ErrStdinWriteError(sessionID string, cause error) *SessionError { + return NewSessionErrorWithCause( + "Failed to write to stdin", + ErrStdinWriteFailed, + sessionID, + cause, + ) +} \ No newline at end of file diff --git a/linux/pkg/session/process.go b/linux/pkg/session/process.go new file mode 100644 index 00000000..14cfa94f --- /dev/null +++ b/linux/pkg/session/process.go @@ -0,0 +1,157 @@ +package session + +import ( + "log" + "os" + "time" +) + +// ProcessTerminator provides graceful process termination with timeout +// Matches the Node.js implementation behavior +type ProcessTerminator struct { + session *Session + gracefulTimeout time.Duration + checkInterval time.Duration +} + +// NewProcessTerminator creates a new process terminator +func NewProcessTerminator(session *Session) *ProcessTerminator { + return &ProcessTerminator{ + session: session, + gracefulTimeout: 3 * time.Second, // Match Node.js 3 second timeout + checkInterval: 500 * time.Millisecond, // Match Node.js 500ms check interval + } +} + +// TerminateGracefully attempts graceful termination with escalation to SIGKILL +// This matches the Node.js implementation behavior: +// 1. Send SIGTERM +// 2. Wait up to 3 seconds for graceful termination +// 3. Send SIGKILL if process is still alive +func (pt *ProcessTerminator) TerminateGracefully() error { + sessionID := pt.session.ID[:8] + pid := pt.session.info.Pid + + // Check if already exited + if pt.session.info.Status == string(StatusExited) { + debugLog("[DEBUG] ProcessTerminator: Session %s already exited", sessionID) + pt.session.cleanup() + return nil + } + + if pid == 0 { + return NewSessionError("no process to terminate", ErrProcessNotFound, pt.session.ID) + } + + log.Printf("[INFO] Terminating session %s (PID: %d) with SIGTERM...", sessionID, pid) + + // Send SIGTERM first + if err := pt.session.Signal("SIGTERM"); err != nil { + // If process doesn't exist, that's fine + if !pt.session.IsAlive() { + log.Printf("[INFO] Session %s already terminated", sessionID) + pt.session.cleanup() + return nil + } + // If it's already a SessionError, return as-is + if se, ok := err.(*SessionError); ok { + return se + } + return NewSessionErrorWithCause("failed to send SIGTERM", ErrProcessTerminateFailed, pt.session.ID, err) + } + + // Wait for graceful termination + startTime := time.Now() + checkCount := 0 + maxChecks := int(pt.gracefulTimeout / pt.checkInterval) + + for checkCount < maxChecks { + // Wait for check interval + time.Sleep(pt.checkInterval) + checkCount++ + + // Check if process is still alive + if !pt.session.IsAlive() { + elapsed := time.Since(startTime) + log.Printf("[INFO] Session %s terminated gracefully after %dms", sessionID, elapsed.Milliseconds()) + pt.session.cleanup() + return nil + } + + // Log progress + elapsed := time.Since(startTime) + log.Printf("[INFO] Session %s still alive after %dms...", sessionID, elapsed.Milliseconds()) + } + + // Process didn't terminate gracefully, force kill + log.Printf("[INFO] Session %s didn't terminate gracefully, sending SIGKILL...", sessionID) + + if err := pt.session.Signal("SIGKILL"); err != nil { + // If process doesn't exist anymore, that's fine + if !pt.session.IsAlive() { + log.Printf("[INFO] Session %s terminated before SIGKILL", sessionID) + pt.session.cleanup() + return nil + } + // If it's already a SessionError, return as-is + if se, ok := err.(*SessionError); ok { + return se + } + return NewSessionErrorWithCause("failed to send SIGKILL", ErrProcessTerminateFailed, pt.session.ID, err) + } + + // Wait a bit for SIGKILL to take effect + time.Sleep(100 * time.Millisecond) + + if pt.session.IsAlive() { + log.Printf("[WARN] Session %s may still be alive after SIGKILL", sessionID) + } else { + log.Printf("[INFO] Session %s forcefully terminated with SIGKILL", sessionID) + } + + pt.session.cleanup() + return nil +} + +// waitForProcessExit waits for a process to exit with timeout +// Returns true if process exited within timeout, false otherwise +func waitForProcessExit(pid int, timeout time.Duration) bool { + startTime := time.Now() + checkInterval := 100 * time.Millisecond + + for time.Since(startTime) < timeout { + // Try to find the process + proc, err := os.FindProcess(pid) + if err != nil { + // Process doesn't exist + return true + } + + // Check if process is alive using signal 0 + if err := proc.Signal(os.Signal(nil)); err != nil { + // Process doesn't exist or we don't have permission + return true + } + + time.Sleep(checkInterval) + } + + return false +} + +// isProcessRunning checks if a process is running by PID +// Uses platform-appropriate methods +func isProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + + proc, err := os.FindProcess(pid) + if err != nil { + return false + } + + // On Unix, signal 0 checks if process exists + err = proc.Signal(os.Signal(nil)) + return err == nil +} \ No newline at end of file diff --git a/linux/pkg/session/pty.go b/linux/pkg/session/pty.go index d8372007..6d04f3c9 100644 --- a/linux/pkg/session/pty.go +++ b/linux/pkg/session/pty.go @@ -22,13 +22,14 @@ import ( const useSelectPolling = true type PTY struct { - session *Session - cmd *exec.Cmd - pty *os.File - oldState *term.State - streamWriter *protocol.StreamWriter - stdinPipe *os.File - resizeMutex sync.Mutex + session *Session + cmd *exec.Cmd + pty *os.File + oldState *term.State + streamWriter *protocol.StreamWriter + stdinPipe *os.File + useEventDrivenStdin bool + resizeMutex sync.Mutex } func NewPTY(session *Session) (*PTY, error) { @@ -53,7 +54,12 @@ func NewPTY(session *Session) (*PTY, error) { // Verify the directory exists and is accessible if _, err := os.Stat(session.info.Cwd); err != nil { log.Printf("[ERROR] NewPTY: Working directory '%s' not accessible: %v", session.info.Cwd, err) - return nil, fmt.Errorf("working directory '%s' not accessible: %w", session.info.Cwd, err) + return nil, NewSessionErrorWithCause( + fmt.Sprintf("working directory '%s' not accessible", session.info.Cwd), + ErrInvalidArgument, + session.ID, + err, + ) } cmd.Dir = session.info.Cwd debugLog("[DEBUG] NewPTY: Set working directory to: %s", session.info.Cwd) @@ -101,7 +107,7 @@ func NewPTY(session *Session) (*PTY, error) { ptmx, err := pty.Start(cmd) if err != nil { log.Printf("[ERROR] NewPTY: Failed to start PTY: %v", err) - return nil, fmt.Errorf("failed to start PTY: %w", err) + return nil, ErrPTYCreationError(session.ID, err) } debugLog("[DEBUG] NewPTY: PTY started successfully, PID: %d", cmd.Process.Pid) @@ -110,10 +116,15 @@ func NewPTY(session *Session) (*PTY, error) { debugLog("[DEBUG] NewPTY: Executing command: %v in directory: %s", cmdline, cmd.Dir) debugLog("[DEBUG] NewPTY: Environment has %d variables", len(cmd.Env)) - if err := pty.Setsize(ptmx, &pty.Winsize{ - Rows: uint16(session.info.Height), - Cols: uint16(session.info.Width), - }); err != nil { + // Configure terminal attributes to match node-pty behavior + // This must be done before setting size and after the process starts + if err := configurePTYTerminal(ptmx); err != nil { + log.Printf("[ERROR] NewPTY: Failed to configure PTY terminal: %v", err) + // Don't fail on terminal configuration errors, just log them + } + + // Set PTY size using our enhanced function + if err := setPTYSize(ptmx, uint16(session.info.Width), uint16(session.info.Height)); err != nil { log.Printf("[ERROR] NewPTY: Failed to set PTY size: %v", err) if err := ptmx.Close(); err != nil { log.Printf("[ERROR] NewPTY: Failed to close PTY: %v", err) @@ -121,13 +132,15 @@ func NewPTY(session *Session) (*PTY, error) { if err := cmd.Process.Kill(); err != nil { log.Printf("[ERROR] NewPTY: Failed to kill process: %v", err) } - return nil, fmt.Errorf("failed to set PTY size: %w", err) + return nil, NewSessionErrorWithCause( + "failed to set PTY size", + ErrPTYResizeFailed, + session.ID, + err, + ) } - // Configure terminal modes for proper interactive shell behavior - // The creack/pty library handles basic setup, but we ensure the terminal - // is in the correct mode for interactive use (not raw mode) - debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode") + debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode with flow control") streamOut, err := os.Create(session.StreamOutPath()) if err != nil { @@ -209,19 +222,32 @@ func (p *PTY) Run() error { debugLog("[DEBUG] PTY.Run: Starting PTY run for session %s, PID %d", p.session.ID[:8], p.cmd.Process.Pid) - stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0) + // Use event-driven stdin handling like Node.js + stdinWatcher, err := NewStdinWatcher(p.session.StdinPath(), p.pty) if err != nil { - log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err) - return fmt.Errorf("failed to open stdin pipe: %w", err) - } - defer func() { - if err := stdinPipe.Close(); err != nil { - log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err) + // Fall back to polling if watcher fails + log.Printf("[WARN] PTY.Run: Failed to create stdin watcher, falling back to polling: %v", err) + + stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err) + return fmt.Errorf("failed to open stdin pipe: %w", err) } - }() - p.stdinPipe = stdinPipe + defer func() { + if err := stdinPipe.Close(); err != nil { + log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err) + } + }() + p.stdinPipe = stdinPipe + } else { + // Start the watcher + stdinWatcher.Start() + defer stdinWatcher.Stop() + p.useEventDrivenStdin = true + debugLog("[DEBUG] PTY.Run: Using event-driven stdin handling") + } - debugLog("[DEBUG] PTY.Run: Stdin pipe opened successfully") + debugLog("[DEBUG] PTY.Run: Stdin handling initialized") // Set up SIGWINCH handling for terminal resize winchCh := make(chan os.Signal, 1) @@ -236,10 +262,7 @@ func (p *PTY) Run() error { width, height, err := term.GetSize(int(os.Stdin.Fd())) if err == nil { debugLog("[DEBUG] PTY.Run: Received SIGWINCH, resizing to %dx%d", width, height) - if err := pty.Setsize(p.pty, &pty.Winsize{ - Rows: uint16(height), - Cols: uint16(width), - }); err != nil { + if err := setPTYSize(p.pty, uint16(width), uint16(height)); err != nil { log.Printf("[ERROR] PTY.Run: Failed to resize PTY: %v", err) } else { // Update session info @@ -301,45 +324,48 @@ func (p *PTY) Run() error { } }() - go func() { - debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine") - buf := make([]byte, 4096) - for { - n, err := stdinPipe.Read(buf) - if n > 0 { - debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n) - if _, err := p.pty.Write(buf[:n]); err != nil { - log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err) - // Only exit if the PTY is really broken, not on temporary errors - if err != syscall.EPIPE && err != syscall.ECONNRESET { - errCh <- fmt.Errorf("failed to write to PTY: %w", err) - return + // Only start stdin goroutine if not using event-driven mode + if !p.useEventDrivenStdin && p.stdinPipe != nil { + go func() { + debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine") + buf := make([]byte, 4096) + for { + n, err := p.stdinPipe.Read(buf) + if n > 0 { + debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n) + if _, err := p.pty.Write(buf[:n]); err != nil { + log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err) + // Only exit if the PTY is really broken, not on temporary errors + if err != syscall.EPIPE && err != syscall.ECONNRESET { + errCh <- fmt.Errorf("failed to write to PTY: %w", err) + return + } + // For broken pipe, just continue - the PTY might be closing + debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...") + time.Sleep(10 * time.Millisecond) } - // For broken pipe, just continue - the PTY might be closing - debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...") - time.Sleep(10 * time.Millisecond) + // Continue immediately after successful write + continue + } + if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { + // No data available, longer pause to prevent excessive CPU usage + time.Sleep(10 * time.Millisecond) + continue + } + if err == io.EOF { + // No writers to the FIFO yet, longer pause before retry + time.Sleep(50 * time.Millisecond) + continue + } + if err != nil { + // Log other errors but don't crash the session - stdin issues shouldn't kill the PTY + log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err) + time.Sleep(10 * time.Millisecond) + continue } - // Continue immediately after successful write - continue } - if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { - // No data available, longer pause to prevent excessive CPU usage - time.Sleep(10 * time.Millisecond) - continue - } - if err == io.EOF { - // No writers to the FIFO yet, longer pause before retry - time.Sleep(50 * time.Millisecond) - continue - } - if err != nil { - // Log other errors but don't crash the session - stdin issues shouldn't kill the PTY - log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err) - time.Sleep(10 * time.Millisecond) - continue - } - } - }() + }() + } go func() { debugLog("[DEBUG] PTY.Run: Starting process wait goroutine for PID %d", p.cmd.Process.Pid) diff --git a/linux/pkg/session/select.go b/linux/pkg/session/select.go index 61c4ea6b..3a712fb8 100644 --- a/linux/pkg/session/select.go +++ b/linux/pkg/session/select.go @@ -74,7 +74,12 @@ func (p *PTY) pollWithSelect() error { // Get file descriptors ptyFd := int(p.pty.Fd()) - stdinFd := int(p.stdinPipe.Fd()) + var stdinFd int = -1 + + // Only include stdin in polling if not using event-driven mode + if !p.useEventDrivenStdin && p.stdinPipe != nil { + stdinFd = int(p.stdinPipe.Fd()) + } // Open control FIFO in non-blocking mode controlPath := filepath.Join(p.session.Path(), "control") @@ -93,7 +98,10 @@ func (p *PTY) pollWithSelect() error { for { // Build FD list - fds := []int{ptyFd, stdinFd} + fds := []int{ptyFd} + if stdinFd >= 0 { + fds = append(fds, stdinFd) + } if controlFd >= 0 { fds = append(fds, controlFd) } diff --git a/linux/pkg/session/session.go b/linux/pkg/session/session.go index 1fb26d71..7332820f 100644 --- a/linux/pkg/session/session.go +++ b/linux/pkg/session/session.go @@ -353,7 +353,7 @@ func (s *Session) proxyInputToNodeJS(data []byte) error { func (s *Session) Signal(sig string) error { if s.info.Pid == 0 { - return fmt.Errorf("no process to signal") + return NewSessionError("no process to signal", ErrProcessNotFound, s.ID) } // Check if process is still alive before signaling @@ -370,21 +370,27 @@ func (s *Session) Signal(sig string) error { proc, err := os.FindProcess(s.info.Pid) if err != nil { - return err + return ErrProcessSignalError(s.ID, sig, err) } switch sig { case "SIGTERM": - return proc.Signal(os.Interrupt) + if err := proc.Signal(os.Interrupt); err != nil { + return ErrProcessSignalError(s.ID, sig, err) + } + return nil case "SIGKILL": err = proc.Kill() // If kill fails with "process already finished", that's okay if err != nil && strings.Contains(err.Error(), "process already finished") { return nil } - return err + if err != nil { + return ErrProcessSignalError(s.ID, sig, err) + } + return nil default: - return fmt.Errorf("unsupported signal: %s", sig) + return NewSessionError(fmt.Sprintf("unsupported signal: %s", sig), ErrInvalidArgument, s.ID) } } @@ -393,24 +399,29 @@ func (s *Session) Stop() error { } func (s *Session) Kill() error { - // First check if the session is already dead - if s.info.Status == string(StatusExited) { - // Already exited, just cleanup and return success + // Use graceful termination like Node.js + terminator := NewProcessTerminator(s) + return terminator.TerminateGracefully() +} + +// KillWithSignal kills the session with the specified signal +// If signal is SIGKILL, it sends it immediately without graceful termination +func (s *Session) KillWithSignal(signal string) error { + // If SIGKILL is explicitly requested, send it immediately + if signal == "SIGKILL" || signal == "9" { + err := s.Signal("SIGKILL") s.cleanup() - return nil + + // If the error is because the process doesn't exist, that's fine + if err != nil && (strings.Contains(err.Error(), "no such process") || + strings.Contains(err.Error(), "process already finished")) { + return nil + } + return err } - - // Try to kill the process - err := s.Signal("SIGKILL") - s.cleanup() - - // If the error is because the process doesn't exist, that's fine - if err != nil && (strings.Contains(err.Error(), "no such process") || - strings.Contains(err.Error(), "process already finished")) { - return nil - } - - return err + + // For other signals, use graceful termination + return s.Kill() } func (s *Session) cleanup() { @@ -427,17 +438,21 @@ func (s *Session) cleanup() { func (s *Session) Resize(width, height int) error { if s.pty == nil { - return fmt.Errorf("session not started") + return NewSessionError("session not started", ErrSessionNotRunning, s.ID) } // Check if session is still alive if s.info.Status == string(StatusExited) { - return fmt.Errorf("cannot resize exited session") + return NewSessionError("cannot resize exited session", ErrSessionNotRunning, s.ID) } // Validate dimensions if width <= 0 || height <= 0 { - return fmt.Errorf("invalid dimensions: width=%d, height=%d", width, height) + return NewSessionError( + fmt.Sprintf("invalid dimensions: width=%d, height=%d", width, height), + ErrInvalidArgument, + s.ID, + ) } // Update session info diff --git a/linux/pkg/session/stdin_watcher.go b/linux/pkg/session/stdin_watcher.go new file mode 100644 index 00000000..d9972c67 --- /dev/null +++ b/linux/pkg/session/stdin_watcher.go @@ -0,0 +1,153 @@ +package session + +import ( + "fmt" + "io" + "log" + "os" + "sync" + "syscall" + + "github.com/fsnotify/fsnotify" +) + +// StdinWatcher provides event-driven stdin handling like Node.js +type StdinWatcher struct { + stdinPath string + ptyFile *os.File + watcher *fsnotify.Watcher + stdinFile *os.File + buffer []byte + mu sync.Mutex + stopChan chan struct{} + stoppedChan chan struct{} +} + +// NewStdinWatcher creates a new stdin watcher +func NewStdinWatcher(stdinPath string, ptyFile *os.File) (*StdinWatcher, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create fsnotify watcher: %w", err) + } + + sw := &StdinWatcher{ + stdinPath: stdinPath, + ptyFile: ptyFile, + watcher: watcher, + buffer: make([]byte, 4096), + stopChan: make(chan struct{}), + stoppedChan: make(chan struct{}), + } + + // Open stdin pipe for reading + stdinFile, err := os.OpenFile(stdinPath, os.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + watcher.Close() + return nil, fmt.Errorf("failed to open stdin pipe: %w", err) + } + sw.stdinFile = stdinFile + + // Add stdin path to watcher + if err := watcher.Add(stdinPath); err != nil { + stdinFile.Close() + watcher.Close() + return nil, fmt.Errorf("failed to watch stdin pipe: %w", err) + } + + return sw, nil +} + +// Start begins watching for stdin input +func (sw *StdinWatcher) Start() { + go sw.watchLoop() +} + +// Stop stops the watcher +func (sw *StdinWatcher) Stop() { + close(sw.stopChan) + <-sw.stoppedChan + sw.cleanup() +} + +// watchLoop is the main event loop +func (sw *StdinWatcher) watchLoop() { + defer close(sw.stoppedChan) + + for { + select { + case <-sw.stopChan: + debugLog("[DEBUG] StdinWatcher: Stopping watch loop") + return + + case event, ok := <-sw.watcher.Events: + if !ok { + debugLog("[DEBUG] StdinWatcher: Watcher events channel closed") + return + } + + // Handle write events (new data available) + if event.Op&fsnotify.Write == fsnotify.Write { + sw.handleStdinData() + } + + case err, ok := <-sw.watcher.Errors: + if !ok { + debugLog("[DEBUG] StdinWatcher: Watcher errors channel closed") + return + } + log.Printf("[ERROR] StdinWatcher: Watcher error: %v", err) + } + } +} + +// handleStdinData reads available data and forwards it to the PTY +func (sw *StdinWatcher) handleStdinData() { + sw.mu.Lock() + defer sw.mu.Unlock() + + for { + n, err := sw.stdinFile.Read(sw.buffer) + if n > 0 { + // Forward data to PTY immediately + if _, writeErr := sw.ptyFile.Write(sw.buffer[:n]); writeErr != nil { + log.Printf("[ERROR] StdinWatcher: Failed to write to PTY: %v", writeErr) + return + } + debugLog("[DEBUG] StdinWatcher: Forwarded %d bytes to PTY", n) + } + + if err != nil { + if err == io.EOF || isEAGAIN(err) { + // No more data available right now + break + } + log.Printf("[ERROR] StdinWatcher: Failed to read from stdin: %v", err) + return + } + + // If we read a full buffer, there might be more data + if n == len(sw.buffer) { + continue + } + break + } +} + +// cleanup releases resources +func (sw *StdinWatcher) cleanup() { + if sw.watcher != nil { + sw.watcher.Close() + } + if sw.stdinFile != nil { + sw.stdinFile.Close() + } +} + +// isEAGAIN checks if the error is EAGAIN (resource temporarily unavailable) +func isEAGAIN(err error) bool { + if err == nil { + return false + } + // Check for EAGAIN in the error string + return err.Error() == "resource temporarily unavailable" +} \ No newline at end of file diff --git a/linux/pkg/session/terminal.go b/linux/pkg/session/terminal.go new file mode 100644 index 00000000..acc3348e --- /dev/null +++ b/linux/pkg/session/terminal.go @@ -0,0 +1,133 @@ +package session + +import ( + "fmt" + "os" + "syscall" + + "golang.org/x/sys/unix" +) + +// TerminalMode represents terminal mode settings +type TerminalMode struct { + Raw bool + Echo bool + LineMode bool + FlowControl bool +} + +// configurePTYTerminal configures the PTY terminal attributes to match node-pty behavior +// This ensures proper terminal behavior with flow control, signal handling, and line editing +func configurePTYTerminal(ptyFile *os.File) error { + fd := int(ptyFile.Fd()) + + // Get current terminal attributes + termios, err := unix.IoctlGetTermios(fd, unix.TIOCGETA) + if err != nil { + return fmt.Errorf("failed to get terminal attributes: %w", err) + } + + // Configure input flags (similar to node-pty's handleFlowControl) + // IXON: Enable start/stop output control (Ctrl+S/Ctrl+Q) + // IXOFF: Enable start/stop input control + // IXANY: Any character will restart output after stop + // ICRNL: Map CR to NL on input + termios.Iflag |= unix.IXON | unix.IXOFF | unix.IXANY | unix.ICRNL + + // Configure output flags + // OPOST: Enable output processing + // ONLCR: Map NL to CR-NL on output + termios.Oflag |= unix.OPOST | unix.ONLCR + + // Configure control flags + // CS8: 8-bit characters + // CREAD: Enable receiver + // HUPCL: Hang up on last close + termios.Cflag |= unix.CS8 | unix.CREAD | unix.HUPCL + termios.Cflag &^= unix.PARENB // Disable parity + + // Configure local flags + // ISIG: Enable signal generation (SIGINT on Ctrl+C, etc) + // ICANON: Enable canonical mode (line editing) + // ECHO: Enable echo + // ECHOE: Echo erase character as BS-SP-BS + // ECHOK: Echo kill character + // ECHONL: Echo NL even if ECHO is off + // IEXTEN: Enable extended functions + termios.Lflag |= unix.ISIG | unix.ICANON | unix.ECHO | unix.ECHOE | unix.ECHOK | unix.ECHONL | unix.IEXTEN + + // Set control characters + termios.Cc[unix.VEOF] = 4 // Ctrl+D + termios.Cc[unix.VEOL] = 0 // Additional end-of-line + termios.Cc[unix.VERASE] = 127 // DEL + termios.Cc[unix.VINTR] = 3 // Ctrl+C + termios.Cc[unix.VKILL] = 21 // Ctrl+U + termios.Cc[unix.VMIN] = 1 // Minimum characters for read + termios.Cc[unix.VQUIT] = 28 // Ctrl+\ + termios.Cc[unix.VSTART] = 17 // Ctrl+Q + termios.Cc[unix.VSTOP] = 19 // Ctrl+S + termios.Cc[unix.VSUSP] = 26 // Ctrl+Z + termios.Cc[unix.VTIME] = 0 // Timeout for read + + // Apply the terminal attributes + if err := unix.IoctlSetTermios(fd, unix.TIOCSETA, termios); err != nil { + return fmt.Errorf("failed to set terminal attributes: %w", err) + } + + debugLog("[DEBUG] PTY terminal configured with proper flow control and signal handling") + return nil +} + +// setPTYSize sets the window size of the PTY +func setPTYSize(ptyFile *os.File, cols, rows uint16) error { + fd := int(ptyFile.Fd()) + + ws := &unix.Winsize{ + Row: rows, + Col: cols, + Xpixel: 0, + Ypixel: 0, + } + + if err := unix.IoctlSetWinsize(fd, unix.TIOCSWINSZ, ws); err != nil { + return fmt.Errorf("failed to set PTY size: %w", err) + } + + return nil +} + +// getPTYSize gets the current window size of the PTY +func getPTYSize(ptyFile *os.File) (cols, rows uint16, err error) { + fd := int(ptyFile.Fd()) + + ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) + if err != nil { + return 0, 0, fmt.Errorf("failed to get PTY size: %w", err) + } + + return ws.Col, ws.Row, nil +} + +// sendSignalToPTY sends a signal to the PTY process group +func sendSignalToPTY(ptyFile *os.File, signal syscall.Signal) error { + fd := int(ptyFile.Fd()) + + // Get the process group ID of the PTY + pgid, err := unix.IoctlGetInt(fd, unix.TIOCGPGRP) + if err != nil { + return fmt.Errorf("failed to get PTY process group: %w", err) + } + + // Send signal to the process group + if err := syscall.Kill(-pgid, signal); err != nil { + return fmt.Errorf("failed to send signal to PTY process group: %w", err) + } + + return nil +} + +// isTerminal checks if a file descriptor is a terminal +func isTerminal(fd int) bool { + _, err := unix.IoctlGetTermios(fd, unix.TIOCGETA) + return err == nil +} \ No newline at end of file