Port node implementation details to go

This commit is contained in:
Peter Steinberger 2025-06-20 15:11:26 +02:00
parent 014bbb9e1e
commit f96b63c77d
10 changed files with 1042 additions and 121 deletions

View file

@ -15,6 +15,7 @@ require (
github.com/spf13/cobra v1.9.1
github.com/spf13/pflag v1.0.6
golang.ngrok.com/ngrok v1.13.0
golang.org/x/sys v0.33.0
golang.org/x/term v0.32.0
gopkg.in/yaml.v3 v3.0.1
)
@ -51,7 +52,6 @@ require (
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/tools v0.34.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect

View file

@ -42,25 +42,27 @@ type StreamEvent struct {
}
type StreamWriter struct {
writer io.Writer
header *AsciinemaHeader
startTime time.Time
mutex sync.Mutex
closed bool
buffer []byte
lastWrite time.Time
flushTimer *time.Timer
syncTimer *time.Timer
needsSync bool
writer io.Writer
header *AsciinemaHeader
startTime time.Time
mutex sync.Mutex
closed bool
buffer []byte
escapeParser *EscapeParser
lastWrite time.Time
flushTimer *time.Timer
syncTimer *time.Timer
needsSync bool
}
func NewStreamWriter(writer io.Writer, header *AsciinemaHeader) *StreamWriter {
return &StreamWriter{
writer: writer,
header: header,
startTime: time.Now(),
buffer: make([]byte, 0, 4096),
lastWrite: time.Now(),
writer: writer,
header: header,
startTime: time.Now(),
buffer: make([]byte, 0, 4096),
escapeParser: NewEscapeParser(),
lastWrite: time.Now(),
}
}
@ -106,22 +108,24 @@ func (w *StreamWriter) writeEvent(eventType EventType, data []byte) error {
return fmt.Errorf("stream writer closed")
}
w.buffer = append(w.buffer, data...)
w.lastWrite = time.Now()
completeData, remaining := extractCompleteUTF8(w.buffer)
// Use escape parser to ensure escape sequences are not split
processedData, remaining := w.escapeParser.ProcessData(data)
// Update buffer with any remaining incomplete sequences
w.buffer = remaining
if len(completeData) == 0 {
// If we have incomplete UTF-8 data, set up a timer to flush it after a short delay
if len(w.buffer) > 0 {
if len(processedData) == 0 {
// If we have incomplete data, set up a timer to flush it after a short delay
if len(w.buffer) > 0 || w.escapeParser.BufferSize() > 0 {
w.scheduleFlush()
}
return nil
}
elapsed := time.Since(w.startTime).Seconds()
event := []interface{}{elapsed, string(eventType), string(completeData)}
event := []interface{}{elapsed, string(eventType), string(processedData)}
eventData, err := json.Marshal(event)
if err != nil {
@ -151,13 +155,25 @@ func (w *StreamWriter) scheduleFlush() {
w.mutex.Lock()
defer w.mutex.Unlock()
if w.closed || len(w.buffer) == 0 {
if w.closed {
return
}
// Force flush incomplete UTF-8 data for real-time streaming
// Flush any buffered data from escape parser
flushedData := w.escapeParser.Flush()
if len(flushedData) == 0 && len(w.buffer) == 0 {
return
}
// Combine flushed data with any remaining buffer
dataToWrite := append(flushedData, w.buffer...)
if len(dataToWrite) == 0 {
return
}
// Force flush incomplete data for real-time streaming
elapsed := time.Since(w.startTime).Seconds()
event := []interface{}{elapsed, string(EventOutput), string(w.buffer)}
event := []interface{}{elapsed, string(EventOutput), string(dataToWrite)}
eventData, err := json.Marshal(event)
if err != nil {
@ -218,9 +234,13 @@ func (w *StreamWriter) Close() error {
w.syncTimer.Stop()
}
if len(w.buffer) > 0 {
// Flush any remaining data from escape parser
flushedData := w.escapeParser.Flush()
finalData := append(flushedData, w.buffer...)
if len(finalData) > 0 {
elapsed := time.Since(w.startTime).Seconds()
event := []interface{}{elapsed, string(EventOutput), string(w.buffer)}
event := []interface{}{elapsed, string(EventOutput), string(finalData)}
eventData, _ := json.Marshal(event)
if _, err := fmt.Fprintf(w.writer, "%s\n", eventData); err != nil {
// Write failed during close - log to stderr to avoid deadlock

View file

@ -0,0 +1,244 @@
package protocol
import (
"unicode/utf8"
)
// EscapeParser handles parsing of terminal escape sequences and UTF-8 data
// This ensures escape sequences are not split across chunks
type EscapeParser struct {
buffer []byte
}
// NewEscapeParser creates a new escape sequence parser
func NewEscapeParser() *EscapeParser {
return &EscapeParser{
buffer: make([]byte, 0, 4096),
}
}
// ProcessData processes terminal data ensuring escape sequences and UTF-8 are not split
// Returns processed data and any remaining incomplete sequences
func (p *EscapeParser) ProcessData(data []byte) (processed []byte, remaining []byte) {
// Combine buffered data with new data
combined := append(p.buffer, data...)
p.buffer = p.buffer[:0] // Clear buffer without reallocating
result := make([]byte, 0, len(combined))
pos := 0
for pos < len(combined) {
// Check for escape sequence
if combined[pos] == 0x1b { // ESC character
seqEnd := p.findEscapeSequenceEnd(combined[pos:])
if seqEnd == -1 {
// Incomplete escape sequence, save for next time
p.buffer = append(p.buffer, combined[pos:]...)
break
}
// Include complete escape sequence
result = append(result, combined[pos:pos+seqEnd]...)
pos += seqEnd
continue
}
// Process UTF-8 character
r, size := utf8.DecodeRune(combined[pos:])
if r == utf8.RuneError {
if size == 0 {
// No more data
break
}
if size == 1 && pos+4 > len(combined) {
// Might be incomplete UTF-8 at end of buffer
if p.mightBeIncompleteUTF8(combined[pos:]) {
p.buffer = append(p.buffer, combined[pos:]...)
break
}
}
// Invalid UTF-8, skip byte
result = append(result, combined[pos])
pos++
continue
}
// Valid UTF-8 character
result = append(result, combined[pos:pos+size]...)
pos += size
}
return result, p.buffer
}
// findEscapeSequenceEnd finds the end of an ANSI escape sequence
// Returns -1 if sequence is incomplete
func (p *EscapeParser) findEscapeSequenceEnd(data []byte) int {
if len(data) == 0 || data[0] != 0x1b {
return -1
}
if len(data) < 2 {
return -1 // Need more data
}
switch data[1] {
case '[': // CSI sequence: ESC [ ... final_char
pos := 2
for pos < len(data) {
b := data[pos]
if b >= 0x20 && b <= 0x3f {
// Parameter and intermediate characters
pos++
} else if b >= 0x40 && b <= 0x7e {
// Final character found
return pos + 1
} else {
// Invalid sequence
return pos
}
}
return -1 // Incomplete
case ']': // OSC sequence: ESC ] ... (ST or BEL)
pos := 2
for pos < len(data) {
if data[pos] == 0x07 { // BEL terminator
return pos + 1
}
if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' {
// ESC \ (ST) terminator
return pos + 2
}
pos++
}
return -1 // Incomplete
case '(', ')', '*', '+': // Charset selection
if len(data) < 3 {
return -1
}
return 3
case 'P', 'X', '^', '_': // DCS, SOS, PM, APC sequences
// These need special termination sequences
pos := 2
for pos < len(data) {
if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' {
// ESC \ (ST) terminator
return pos + 2
}
pos++
}
return -1 // Incomplete
default:
// Simple two-character sequences
return 2
}
}
// mightBeIncompleteUTF8 checks if data might be an incomplete UTF-8 sequence
func (p *EscapeParser) mightBeIncompleteUTF8(data []byte) bool {
if len(data) == 0 {
return false
}
b := data[0]
// Single byte (ASCII)
if b < 0x80 {
return false
}
// Multi-byte sequence starters
if b >= 0xc0 {
if b < 0xe0 {
// 2-byte sequence
return len(data) < 2
}
if b < 0xf0 {
// 3-byte sequence
return len(data) < 3
}
if b < 0xf8 {
// 4-byte sequence
return len(data) < 4
}
}
return false
}
// Flush returns any buffered data (for use when closing)
func (p *EscapeParser) Flush() []byte {
if len(p.buffer) == 0 {
return nil
}
// Return buffered data as-is when flushing
result := make([]byte, len(p.buffer))
copy(result, p.buffer)
p.buffer = p.buffer[:0]
return result
}
// Reset clears the parser state
func (p *EscapeParser) Reset() {
p.buffer = p.buffer[:0]
}
// BufferSize returns the current buffer size
func (p *EscapeParser) BufferSize() int {
return len(p.buffer)
}
// SplitEscapeSequences splits data at escape sequence boundaries
// This is useful for processing data in chunks without splitting sequences
func SplitEscapeSequences(data []byte) [][]byte {
if len(data) == 0 {
return nil
}
var chunks [][]byte
parser := NewEscapeParser()
processed, remaining := parser.ProcessData(data)
if len(processed) > 0 {
chunks = append(chunks, processed)
}
if len(remaining) > 0 {
chunks = append(chunks, remaining)
}
return chunks
}
// IsCompleteEscapeSequence checks if data contains a complete escape sequence
func IsCompleteEscapeSequence(data []byte) bool {
if len(data) == 0 || data[0] != 0x1b {
return false
}
parser := NewEscapeParser()
end := parser.findEscapeSequenceEnd(data)
return end > 0 && end == len(data)
}
// StripEscapeSequences removes all ANSI escape sequences from data
func StripEscapeSequences(data []byte) []byte {
result := make([]byte, 0, len(data))
pos := 0
parser := NewEscapeParser()
for pos < len(data) {
if data[pos] == 0x1b {
seqEnd := parser.findEscapeSequenceEnd(data[pos:])
if seqEnd > 0 {
pos += seqEnd
continue
}
}
result = append(result, data[pos])
pos++
}
return result
}

165
linux/pkg/session/errors.go Normal file
View file

@ -0,0 +1,165 @@
package session
import (
"fmt"
)
// ErrorCode represents standardized error codes matching Node.js implementation
type ErrorCode string
const (
// Session-related errors
ErrSessionNotFound ErrorCode = "SESSION_NOT_FOUND"
ErrSessionAlreadyExists ErrorCode = "SESSION_ALREADY_EXISTS"
ErrSessionStartFailed ErrorCode = "SESSION_START_FAILED"
ErrSessionNotRunning ErrorCode = "SESSION_NOT_RUNNING"
// Process-related errors
ErrProcessNotFound ErrorCode = "PROCESS_NOT_FOUND"
ErrProcessSignalFailed ErrorCode = "PROCESS_SIGNAL_FAILED"
ErrProcessTerminateFailed ErrorCode = "PROCESS_TERMINATE_FAILED"
// I/O related errors
ErrStdinNotFound ErrorCode = "STDIN_NOT_FOUND"
ErrStdinWriteFailed ErrorCode = "STDIN_WRITE_FAILED"
ErrStreamReadFailed ErrorCode = "STREAM_READ_FAILED"
ErrStreamWriteFailed ErrorCode = "STREAM_WRITE_FAILED"
// PTY-related errors
ErrPTYCreationFailed ErrorCode = "PTY_CREATION_FAILED"
ErrPTYConfigFailed ErrorCode = "PTY_CONFIG_FAILED"
ErrPTYResizeFailed ErrorCode = "PTY_RESIZE_FAILED"
// Control-related errors
ErrControlPathNotFound ErrorCode = "CONTROL_PATH_NOT_FOUND"
ErrControlFileCorrupted ErrorCode = "CONTROL_FILE_CORRUPTED"
// Input-related errors
ErrUnknownKey ErrorCode = "UNKNOWN_KEY"
ErrInvalidInput ErrorCode = "INVALID_INPUT"
// General errors
ErrInvalidArgument ErrorCode = "INVALID_ARGUMENT"
ErrPermissionDenied ErrorCode = "PERMISSION_DENIED"
ErrTimeout ErrorCode = "TIMEOUT"
ErrInternal ErrorCode = "INTERNAL_ERROR"
)
// SessionError represents an error with context, matching Node.js PtyError
type SessionError struct {
Message string
Code ErrorCode
SessionID string
Cause error
}
// Error implements the error interface
func (e *SessionError) Error() string {
if e.SessionID != "" {
return fmt.Sprintf("%s (session: %s, code: %s)", e.Message, e.SessionID[:8], e.Code)
}
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
}
// Unwrap returns the underlying cause
func (e *SessionError) Unwrap() error {
return e.Cause
}
// NewSessionError creates a new SessionError
func NewSessionError(message string, code ErrorCode, sessionID string) *SessionError {
return &SessionError{
Message: message,
Code: code,
SessionID: sessionID,
}
}
// NewSessionErrorWithCause creates a new SessionError with an underlying cause
func NewSessionErrorWithCause(message string, code ErrorCode, sessionID string, cause error) *SessionError {
return &SessionError{
Message: message,
Code: code,
SessionID: sessionID,
Cause: cause,
}
}
// WrapError wraps an existing error with session context
func WrapError(err error, code ErrorCode, sessionID string) *SessionError {
if err == nil {
return nil
}
// If it's already a SessionError, preserve the original but add context
if se, ok := err.(*SessionError); ok {
return &SessionError{
Message: se.Message,
Code: code,
SessionID: sessionID,
Cause: se,
}
}
return &SessionError{
Message: err.Error(),
Code: code,
SessionID: sessionID,
Cause: err,
}
}
// IsSessionError checks if an error is a SessionError with a specific code
func IsSessionError(err error, code ErrorCode) bool {
se, ok := err.(*SessionError)
return ok && se.Code == code
}
// GetSessionID extracts the session ID from an error if it's a SessionError
func GetSessionID(err error) string {
if se, ok := err.(*SessionError); ok {
return se.SessionID
}
return ""
}
// Common error constructors for convenience
// ErrSessionNotFoundError creates a session not found error
func ErrSessionNotFoundError(sessionID string) *SessionError {
return NewSessionError(
fmt.Sprintf("Session %s not found", sessionID[:8]),
ErrSessionNotFound,
sessionID,
)
}
// ErrProcessSignalError creates a process signal error
func ErrProcessSignalError(sessionID string, signal string, cause error) *SessionError {
return NewSessionErrorWithCause(
fmt.Sprintf("Failed to send signal %s to session", signal),
ErrProcessSignalFailed,
sessionID,
cause,
)
}
// ErrPTYCreationError creates a PTY creation error
func ErrPTYCreationError(sessionID string, cause error) *SessionError {
return NewSessionErrorWithCause(
"Failed to create PTY",
ErrPTYCreationFailed,
sessionID,
cause,
)
}
// ErrStdinWriteError creates a stdin write error
func ErrStdinWriteError(sessionID string, cause error) *SessionError {
return NewSessionErrorWithCause(
"Failed to write to stdin",
ErrStdinWriteFailed,
sessionID,
cause,
)
}

View file

@ -0,0 +1,157 @@
package session
import (
"log"
"os"
"time"
)
// ProcessTerminator provides graceful process termination with timeout
// Matches the Node.js implementation behavior
type ProcessTerminator struct {
session *Session
gracefulTimeout time.Duration
checkInterval time.Duration
}
// NewProcessTerminator creates a new process terminator
func NewProcessTerminator(session *Session) *ProcessTerminator {
return &ProcessTerminator{
session: session,
gracefulTimeout: 3 * time.Second, // Match Node.js 3 second timeout
checkInterval: 500 * time.Millisecond, // Match Node.js 500ms check interval
}
}
// TerminateGracefully attempts graceful termination with escalation to SIGKILL
// This matches the Node.js implementation behavior:
// 1. Send SIGTERM
// 2. Wait up to 3 seconds for graceful termination
// 3. Send SIGKILL if process is still alive
func (pt *ProcessTerminator) TerminateGracefully() error {
sessionID := pt.session.ID[:8]
pid := pt.session.info.Pid
// Check if already exited
if pt.session.info.Status == string(StatusExited) {
debugLog("[DEBUG] ProcessTerminator: Session %s already exited", sessionID)
pt.session.cleanup()
return nil
}
if pid == 0 {
return NewSessionError("no process to terminate", ErrProcessNotFound, pt.session.ID)
}
log.Printf("[INFO] Terminating session %s (PID: %d) with SIGTERM...", sessionID, pid)
// Send SIGTERM first
if err := pt.session.Signal("SIGTERM"); err != nil {
// If process doesn't exist, that's fine
if !pt.session.IsAlive() {
log.Printf("[INFO] Session %s already terminated", sessionID)
pt.session.cleanup()
return nil
}
// If it's already a SessionError, return as-is
if se, ok := err.(*SessionError); ok {
return se
}
return NewSessionErrorWithCause("failed to send SIGTERM", ErrProcessTerminateFailed, pt.session.ID, err)
}
// Wait for graceful termination
startTime := time.Now()
checkCount := 0
maxChecks := int(pt.gracefulTimeout / pt.checkInterval)
for checkCount < maxChecks {
// Wait for check interval
time.Sleep(pt.checkInterval)
checkCount++
// Check if process is still alive
if !pt.session.IsAlive() {
elapsed := time.Since(startTime)
log.Printf("[INFO] Session %s terminated gracefully after %dms", sessionID, elapsed.Milliseconds())
pt.session.cleanup()
return nil
}
// Log progress
elapsed := time.Since(startTime)
log.Printf("[INFO] Session %s still alive after %dms...", sessionID, elapsed.Milliseconds())
}
// Process didn't terminate gracefully, force kill
log.Printf("[INFO] Session %s didn't terminate gracefully, sending SIGKILL...", sessionID)
if err := pt.session.Signal("SIGKILL"); err != nil {
// If process doesn't exist anymore, that's fine
if !pt.session.IsAlive() {
log.Printf("[INFO] Session %s terminated before SIGKILL", sessionID)
pt.session.cleanup()
return nil
}
// If it's already a SessionError, return as-is
if se, ok := err.(*SessionError); ok {
return se
}
return NewSessionErrorWithCause("failed to send SIGKILL", ErrProcessTerminateFailed, pt.session.ID, err)
}
// Wait a bit for SIGKILL to take effect
time.Sleep(100 * time.Millisecond)
if pt.session.IsAlive() {
log.Printf("[WARN] Session %s may still be alive after SIGKILL", sessionID)
} else {
log.Printf("[INFO] Session %s forcefully terminated with SIGKILL", sessionID)
}
pt.session.cleanup()
return nil
}
// waitForProcessExit waits for a process to exit with timeout
// Returns true if process exited within timeout, false otherwise
func waitForProcessExit(pid int, timeout time.Duration) bool {
startTime := time.Now()
checkInterval := 100 * time.Millisecond
for time.Since(startTime) < timeout {
// Try to find the process
proc, err := os.FindProcess(pid)
if err != nil {
// Process doesn't exist
return true
}
// Check if process is alive using signal 0
if err := proc.Signal(os.Signal(nil)); err != nil {
// Process doesn't exist or we don't have permission
return true
}
time.Sleep(checkInterval)
}
return false
}
// isProcessRunning checks if a process is running by PID
// Uses platform-appropriate methods
func isProcessRunning(pid int) bool {
if pid <= 0 {
return false
}
proc, err := os.FindProcess(pid)
if err != nil {
return false
}
// On Unix, signal 0 checks if process exists
err = proc.Signal(os.Signal(nil))
return err == nil
}

View file

@ -22,13 +22,14 @@ import (
const useSelectPolling = true
type PTY struct {
session *Session
cmd *exec.Cmd
pty *os.File
oldState *term.State
streamWriter *protocol.StreamWriter
stdinPipe *os.File
resizeMutex sync.Mutex
session *Session
cmd *exec.Cmd
pty *os.File
oldState *term.State
streamWriter *protocol.StreamWriter
stdinPipe *os.File
useEventDrivenStdin bool
resizeMutex sync.Mutex
}
func NewPTY(session *Session) (*PTY, error) {
@ -53,7 +54,12 @@ func NewPTY(session *Session) (*PTY, error) {
// Verify the directory exists and is accessible
if _, err := os.Stat(session.info.Cwd); err != nil {
log.Printf("[ERROR] NewPTY: Working directory '%s' not accessible: %v", session.info.Cwd, err)
return nil, fmt.Errorf("working directory '%s' not accessible: %w", session.info.Cwd, err)
return nil, NewSessionErrorWithCause(
fmt.Sprintf("working directory '%s' not accessible", session.info.Cwd),
ErrInvalidArgument,
session.ID,
err,
)
}
cmd.Dir = session.info.Cwd
debugLog("[DEBUG] NewPTY: Set working directory to: %s", session.info.Cwd)
@ -101,7 +107,7 @@ func NewPTY(session *Session) (*PTY, error) {
ptmx, err := pty.Start(cmd)
if err != nil {
log.Printf("[ERROR] NewPTY: Failed to start PTY: %v", err)
return nil, fmt.Errorf("failed to start PTY: %w", err)
return nil, ErrPTYCreationError(session.ID, err)
}
debugLog("[DEBUG] NewPTY: PTY started successfully, PID: %d", cmd.Process.Pid)
@ -110,10 +116,15 @@ func NewPTY(session *Session) (*PTY, error) {
debugLog("[DEBUG] NewPTY: Executing command: %v in directory: %s", cmdline, cmd.Dir)
debugLog("[DEBUG] NewPTY: Environment has %d variables", len(cmd.Env))
if err := pty.Setsize(ptmx, &pty.Winsize{
Rows: uint16(session.info.Height),
Cols: uint16(session.info.Width),
}); err != nil {
// Configure terminal attributes to match node-pty behavior
// This must be done before setting size and after the process starts
if err := configurePTYTerminal(ptmx); err != nil {
log.Printf("[ERROR] NewPTY: Failed to configure PTY terminal: %v", err)
// Don't fail on terminal configuration errors, just log them
}
// Set PTY size using our enhanced function
if err := setPTYSize(ptmx, uint16(session.info.Width), uint16(session.info.Height)); err != nil {
log.Printf("[ERROR] NewPTY: Failed to set PTY size: %v", err)
if err := ptmx.Close(); err != nil {
log.Printf("[ERROR] NewPTY: Failed to close PTY: %v", err)
@ -121,13 +132,15 @@ func NewPTY(session *Session) (*PTY, error) {
if err := cmd.Process.Kill(); err != nil {
log.Printf("[ERROR] NewPTY: Failed to kill process: %v", err)
}
return nil, fmt.Errorf("failed to set PTY size: %w", err)
return nil, NewSessionErrorWithCause(
"failed to set PTY size",
ErrPTYResizeFailed,
session.ID,
err,
)
}
// Configure terminal modes for proper interactive shell behavior
// The creack/pty library handles basic setup, but we ensure the terminal
// is in the correct mode for interactive use (not raw mode)
debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode")
debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode with flow control")
streamOut, err := os.Create(session.StreamOutPath())
if err != nil {
@ -209,19 +222,32 @@ func (p *PTY) Run() error {
debugLog("[DEBUG] PTY.Run: Starting PTY run for session %s, PID %d", p.session.ID[:8], p.cmd.Process.Pid)
stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0)
// Use event-driven stdin handling like Node.js
stdinWatcher, err := NewStdinWatcher(p.session.StdinPath(), p.pty)
if err != nil {
log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err)
return fmt.Errorf("failed to open stdin pipe: %w", err)
}
defer func() {
if err := stdinPipe.Close(); err != nil {
log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err)
// Fall back to polling if watcher fails
log.Printf("[WARN] PTY.Run: Failed to create stdin watcher, falling back to polling: %v", err)
stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0)
if err != nil {
log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err)
return fmt.Errorf("failed to open stdin pipe: %w", err)
}
}()
p.stdinPipe = stdinPipe
defer func() {
if err := stdinPipe.Close(); err != nil {
log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err)
}
}()
p.stdinPipe = stdinPipe
} else {
// Start the watcher
stdinWatcher.Start()
defer stdinWatcher.Stop()
p.useEventDrivenStdin = true
debugLog("[DEBUG] PTY.Run: Using event-driven stdin handling")
}
debugLog("[DEBUG] PTY.Run: Stdin pipe opened successfully")
debugLog("[DEBUG] PTY.Run: Stdin handling initialized")
// Set up SIGWINCH handling for terminal resize
winchCh := make(chan os.Signal, 1)
@ -236,10 +262,7 @@ func (p *PTY) Run() error {
width, height, err := term.GetSize(int(os.Stdin.Fd()))
if err == nil {
debugLog("[DEBUG] PTY.Run: Received SIGWINCH, resizing to %dx%d", width, height)
if err := pty.Setsize(p.pty, &pty.Winsize{
Rows: uint16(height),
Cols: uint16(width),
}); err != nil {
if err := setPTYSize(p.pty, uint16(width), uint16(height)); err != nil {
log.Printf("[ERROR] PTY.Run: Failed to resize PTY: %v", err)
} else {
// Update session info
@ -301,45 +324,48 @@ func (p *PTY) Run() error {
}
}()
go func() {
debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine")
buf := make([]byte, 4096)
for {
n, err := stdinPipe.Read(buf)
if n > 0 {
debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n)
if _, err := p.pty.Write(buf[:n]); err != nil {
log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err)
// Only exit if the PTY is really broken, not on temporary errors
if err != syscall.EPIPE && err != syscall.ECONNRESET {
errCh <- fmt.Errorf("failed to write to PTY: %w", err)
return
// Only start stdin goroutine if not using event-driven mode
if !p.useEventDrivenStdin && p.stdinPipe != nil {
go func() {
debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine")
buf := make([]byte, 4096)
for {
n, err := p.stdinPipe.Read(buf)
if n > 0 {
debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n)
if _, err := p.pty.Write(buf[:n]); err != nil {
log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err)
// Only exit if the PTY is really broken, not on temporary errors
if err != syscall.EPIPE && err != syscall.ECONNRESET {
errCh <- fmt.Errorf("failed to write to PTY: %w", err)
return
}
// For broken pipe, just continue - the PTY might be closing
debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...")
time.Sleep(10 * time.Millisecond)
}
// For broken pipe, just continue - the PTY might be closing
debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...")
time.Sleep(10 * time.Millisecond)
// Continue immediately after successful write
continue
}
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
// No data available, longer pause to prevent excessive CPU usage
time.Sleep(10 * time.Millisecond)
continue
}
if err == io.EOF {
// No writers to the FIFO yet, longer pause before retry
time.Sleep(50 * time.Millisecond)
continue
}
if err != nil {
// Log other errors but don't crash the session - stdin issues shouldn't kill the PTY
log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err)
time.Sleep(10 * time.Millisecond)
continue
}
// Continue immediately after successful write
continue
}
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
// No data available, longer pause to prevent excessive CPU usage
time.Sleep(10 * time.Millisecond)
continue
}
if err == io.EOF {
// No writers to the FIFO yet, longer pause before retry
time.Sleep(50 * time.Millisecond)
continue
}
if err != nil {
// Log other errors but don't crash the session - stdin issues shouldn't kill the PTY
log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err)
time.Sleep(10 * time.Millisecond)
continue
}
}
}()
}()
}
go func() {
debugLog("[DEBUG] PTY.Run: Starting process wait goroutine for PID %d", p.cmd.Process.Pid)

View file

@ -74,7 +74,12 @@ func (p *PTY) pollWithSelect() error {
// Get file descriptors
ptyFd := int(p.pty.Fd())
stdinFd := int(p.stdinPipe.Fd())
var stdinFd int = -1
// Only include stdin in polling if not using event-driven mode
if !p.useEventDrivenStdin && p.stdinPipe != nil {
stdinFd = int(p.stdinPipe.Fd())
}
// Open control FIFO in non-blocking mode
controlPath := filepath.Join(p.session.Path(), "control")
@ -93,7 +98,10 @@ func (p *PTY) pollWithSelect() error {
for {
// Build FD list
fds := []int{ptyFd, stdinFd}
fds := []int{ptyFd}
if stdinFd >= 0 {
fds = append(fds, stdinFd)
}
if controlFd >= 0 {
fds = append(fds, controlFd)
}

View file

@ -353,7 +353,7 @@ func (s *Session) proxyInputToNodeJS(data []byte) error {
func (s *Session) Signal(sig string) error {
if s.info.Pid == 0 {
return fmt.Errorf("no process to signal")
return NewSessionError("no process to signal", ErrProcessNotFound, s.ID)
}
// Check if process is still alive before signaling
@ -370,21 +370,27 @@ func (s *Session) Signal(sig string) error {
proc, err := os.FindProcess(s.info.Pid)
if err != nil {
return err
return ErrProcessSignalError(s.ID, sig, err)
}
switch sig {
case "SIGTERM":
return proc.Signal(os.Interrupt)
if err := proc.Signal(os.Interrupt); err != nil {
return ErrProcessSignalError(s.ID, sig, err)
}
return nil
case "SIGKILL":
err = proc.Kill()
// If kill fails with "process already finished", that's okay
if err != nil && strings.Contains(err.Error(), "process already finished") {
return nil
}
return err
if err != nil {
return ErrProcessSignalError(s.ID, sig, err)
}
return nil
default:
return fmt.Errorf("unsupported signal: %s", sig)
return NewSessionError(fmt.Sprintf("unsupported signal: %s", sig), ErrInvalidArgument, s.ID)
}
}
@ -393,24 +399,29 @@ func (s *Session) Stop() error {
}
func (s *Session) Kill() error {
// First check if the session is already dead
if s.info.Status == string(StatusExited) {
// Already exited, just cleanup and return success
// Use graceful termination like Node.js
terminator := NewProcessTerminator(s)
return terminator.TerminateGracefully()
}
// KillWithSignal kills the session with the specified signal
// If signal is SIGKILL, it sends it immediately without graceful termination
func (s *Session) KillWithSignal(signal string) error {
// If SIGKILL is explicitly requested, send it immediately
if signal == "SIGKILL" || signal == "9" {
err := s.Signal("SIGKILL")
s.cleanup()
return nil
// If the error is because the process doesn't exist, that's fine
if err != nil && (strings.Contains(err.Error(), "no such process") ||
strings.Contains(err.Error(), "process already finished")) {
return nil
}
return err
}
// Try to kill the process
err := s.Signal("SIGKILL")
s.cleanup()
// If the error is because the process doesn't exist, that's fine
if err != nil && (strings.Contains(err.Error(), "no such process") ||
strings.Contains(err.Error(), "process already finished")) {
return nil
}
return err
// For other signals, use graceful termination
return s.Kill()
}
func (s *Session) cleanup() {
@ -427,17 +438,21 @@ func (s *Session) cleanup() {
func (s *Session) Resize(width, height int) error {
if s.pty == nil {
return fmt.Errorf("session not started")
return NewSessionError("session not started", ErrSessionNotRunning, s.ID)
}
// Check if session is still alive
if s.info.Status == string(StatusExited) {
return fmt.Errorf("cannot resize exited session")
return NewSessionError("cannot resize exited session", ErrSessionNotRunning, s.ID)
}
// Validate dimensions
if width <= 0 || height <= 0 {
return fmt.Errorf("invalid dimensions: width=%d, height=%d", width, height)
return NewSessionError(
fmt.Sprintf("invalid dimensions: width=%d, height=%d", width, height),
ErrInvalidArgument,
s.ID,
)
}
// Update session info

View file

@ -0,0 +1,153 @@
package session
import (
"fmt"
"io"
"log"
"os"
"sync"
"syscall"
"github.com/fsnotify/fsnotify"
)
// StdinWatcher provides event-driven stdin handling like Node.js
type StdinWatcher struct {
stdinPath string
ptyFile *os.File
watcher *fsnotify.Watcher
stdinFile *os.File
buffer []byte
mu sync.Mutex
stopChan chan struct{}
stoppedChan chan struct{}
}
// NewStdinWatcher creates a new stdin watcher
func NewStdinWatcher(stdinPath string, ptyFile *os.File) (*StdinWatcher, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("failed to create fsnotify watcher: %w", err)
}
sw := &StdinWatcher{
stdinPath: stdinPath,
ptyFile: ptyFile,
watcher: watcher,
buffer: make([]byte, 4096),
stopChan: make(chan struct{}),
stoppedChan: make(chan struct{}),
}
// Open stdin pipe for reading
stdinFile, err := os.OpenFile(stdinPath, os.O_RDONLY|syscall.O_NONBLOCK, 0)
if err != nil {
watcher.Close()
return nil, fmt.Errorf("failed to open stdin pipe: %w", err)
}
sw.stdinFile = stdinFile
// Add stdin path to watcher
if err := watcher.Add(stdinPath); err != nil {
stdinFile.Close()
watcher.Close()
return nil, fmt.Errorf("failed to watch stdin pipe: %w", err)
}
return sw, nil
}
// Start begins watching for stdin input
func (sw *StdinWatcher) Start() {
go sw.watchLoop()
}
// Stop stops the watcher
func (sw *StdinWatcher) Stop() {
close(sw.stopChan)
<-sw.stoppedChan
sw.cleanup()
}
// watchLoop is the main event loop
func (sw *StdinWatcher) watchLoop() {
defer close(sw.stoppedChan)
for {
select {
case <-sw.stopChan:
debugLog("[DEBUG] StdinWatcher: Stopping watch loop")
return
case event, ok := <-sw.watcher.Events:
if !ok {
debugLog("[DEBUG] StdinWatcher: Watcher events channel closed")
return
}
// Handle write events (new data available)
if event.Op&fsnotify.Write == fsnotify.Write {
sw.handleStdinData()
}
case err, ok := <-sw.watcher.Errors:
if !ok {
debugLog("[DEBUG] StdinWatcher: Watcher errors channel closed")
return
}
log.Printf("[ERROR] StdinWatcher: Watcher error: %v", err)
}
}
}
// handleStdinData reads available data and forwards it to the PTY
func (sw *StdinWatcher) handleStdinData() {
sw.mu.Lock()
defer sw.mu.Unlock()
for {
n, err := sw.stdinFile.Read(sw.buffer)
if n > 0 {
// Forward data to PTY immediately
if _, writeErr := sw.ptyFile.Write(sw.buffer[:n]); writeErr != nil {
log.Printf("[ERROR] StdinWatcher: Failed to write to PTY: %v", writeErr)
return
}
debugLog("[DEBUG] StdinWatcher: Forwarded %d bytes to PTY", n)
}
if err != nil {
if err == io.EOF || isEAGAIN(err) {
// No more data available right now
break
}
log.Printf("[ERROR] StdinWatcher: Failed to read from stdin: %v", err)
return
}
// If we read a full buffer, there might be more data
if n == len(sw.buffer) {
continue
}
break
}
}
// cleanup releases resources
func (sw *StdinWatcher) cleanup() {
if sw.watcher != nil {
sw.watcher.Close()
}
if sw.stdinFile != nil {
sw.stdinFile.Close()
}
}
// isEAGAIN checks if the error is EAGAIN (resource temporarily unavailable)
func isEAGAIN(err error) bool {
if err == nil {
return false
}
// Check for EAGAIN in the error string
return err.Error() == "resource temporarily unavailable"
}

View file

@ -0,0 +1,133 @@
package session
import (
"fmt"
"os"
"syscall"
"golang.org/x/sys/unix"
)
// TerminalMode represents terminal mode settings
type TerminalMode struct {
Raw bool
Echo bool
LineMode bool
FlowControl bool
}
// configurePTYTerminal configures the PTY terminal attributes to match node-pty behavior
// This ensures proper terminal behavior with flow control, signal handling, and line editing
func configurePTYTerminal(ptyFile *os.File) error {
fd := int(ptyFile.Fd())
// Get current terminal attributes
termios, err := unix.IoctlGetTermios(fd, unix.TIOCGETA)
if err != nil {
return fmt.Errorf("failed to get terminal attributes: %w", err)
}
// Configure input flags (similar to node-pty's handleFlowControl)
// IXON: Enable start/stop output control (Ctrl+S/Ctrl+Q)
// IXOFF: Enable start/stop input control
// IXANY: Any character will restart output after stop
// ICRNL: Map CR to NL on input
termios.Iflag |= unix.IXON | unix.IXOFF | unix.IXANY | unix.ICRNL
// Configure output flags
// OPOST: Enable output processing
// ONLCR: Map NL to CR-NL on output
termios.Oflag |= unix.OPOST | unix.ONLCR
// Configure control flags
// CS8: 8-bit characters
// CREAD: Enable receiver
// HUPCL: Hang up on last close
termios.Cflag |= unix.CS8 | unix.CREAD | unix.HUPCL
termios.Cflag &^= unix.PARENB // Disable parity
// Configure local flags
// ISIG: Enable signal generation (SIGINT on Ctrl+C, etc)
// ICANON: Enable canonical mode (line editing)
// ECHO: Enable echo
// ECHOE: Echo erase character as BS-SP-BS
// ECHOK: Echo kill character
// ECHONL: Echo NL even if ECHO is off
// IEXTEN: Enable extended functions
termios.Lflag |= unix.ISIG | unix.ICANON | unix.ECHO | unix.ECHOE | unix.ECHOK | unix.ECHONL | unix.IEXTEN
// Set control characters
termios.Cc[unix.VEOF] = 4 // Ctrl+D
termios.Cc[unix.VEOL] = 0 // Additional end-of-line
termios.Cc[unix.VERASE] = 127 // DEL
termios.Cc[unix.VINTR] = 3 // Ctrl+C
termios.Cc[unix.VKILL] = 21 // Ctrl+U
termios.Cc[unix.VMIN] = 1 // Minimum characters for read
termios.Cc[unix.VQUIT] = 28 // Ctrl+\
termios.Cc[unix.VSTART] = 17 // Ctrl+Q
termios.Cc[unix.VSTOP] = 19 // Ctrl+S
termios.Cc[unix.VSUSP] = 26 // Ctrl+Z
termios.Cc[unix.VTIME] = 0 // Timeout for read
// Apply the terminal attributes
if err := unix.IoctlSetTermios(fd, unix.TIOCSETA, termios); err != nil {
return fmt.Errorf("failed to set terminal attributes: %w", err)
}
debugLog("[DEBUG] PTY terminal configured with proper flow control and signal handling")
return nil
}
// setPTYSize sets the window size of the PTY
func setPTYSize(ptyFile *os.File, cols, rows uint16) error {
fd := int(ptyFile.Fd())
ws := &unix.Winsize{
Row: rows,
Col: cols,
Xpixel: 0,
Ypixel: 0,
}
if err := unix.IoctlSetWinsize(fd, unix.TIOCSWINSZ, ws); err != nil {
return fmt.Errorf("failed to set PTY size: %w", err)
}
return nil
}
// getPTYSize gets the current window size of the PTY
func getPTYSize(ptyFile *os.File) (cols, rows uint16, err error) {
fd := int(ptyFile.Fd())
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, fmt.Errorf("failed to get PTY size: %w", err)
}
return ws.Col, ws.Row, nil
}
// sendSignalToPTY sends a signal to the PTY process group
func sendSignalToPTY(ptyFile *os.File, signal syscall.Signal) error {
fd := int(ptyFile.Fd())
// Get the process group ID of the PTY
pgid, err := unix.IoctlGetInt(fd, unix.TIOCGPGRP)
if err != nil {
return fmt.Errorf("failed to get PTY process group: %w", err)
}
// Send signal to the process group
if err := syscall.Kill(-pgid, signal); err != nil {
return fmt.Errorf("failed to send signal to PTY process group: %w", err)
}
return nil
}
// isTerminal checks if a file descriptor is a terminal
func isTerminal(fd int) bool {
_, err := unix.IoctlGetTermios(fd, unix.TIOCGETA)
return err == nil
}