mirror of
https://github.com/samsonjs/vibetunnel.git
synced 2026-04-12 12:25:53 +00:00
Port node implementation details to go
This commit is contained in:
parent
014bbb9e1e
commit
f96b63c77d
10 changed files with 1042 additions and 121 deletions
|
|
@ -15,6 +15,7 @@ require (
|
|||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/pflag v1.0.6
|
||||
golang.ngrok.com/ngrok v1.13.0
|
||||
golang.org/x/sys v0.33.0
|
||||
golang.org/x/term v0.32.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
|
@ -51,7 +52,6 @@ require (
|
|||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
|
|
|
|||
|
|
@ -42,25 +42,27 @@ type StreamEvent struct {
|
|||
}
|
||||
|
||||
type StreamWriter struct {
|
||||
writer io.Writer
|
||||
header *AsciinemaHeader
|
||||
startTime time.Time
|
||||
mutex sync.Mutex
|
||||
closed bool
|
||||
buffer []byte
|
||||
lastWrite time.Time
|
||||
flushTimer *time.Timer
|
||||
syncTimer *time.Timer
|
||||
needsSync bool
|
||||
writer io.Writer
|
||||
header *AsciinemaHeader
|
||||
startTime time.Time
|
||||
mutex sync.Mutex
|
||||
closed bool
|
||||
buffer []byte
|
||||
escapeParser *EscapeParser
|
||||
lastWrite time.Time
|
||||
flushTimer *time.Timer
|
||||
syncTimer *time.Timer
|
||||
needsSync bool
|
||||
}
|
||||
|
||||
func NewStreamWriter(writer io.Writer, header *AsciinemaHeader) *StreamWriter {
|
||||
return &StreamWriter{
|
||||
writer: writer,
|
||||
header: header,
|
||||
startTime: time.Now(),
|
||||
buffer: make([]byte, 0, 4096),
|
||||
lastWrite: time.Now(),
|
||||
writer: writer,
|
||||
header: header,
|
||||
startTime: time.Now(),
|
||||
buffer: make([]byte, 0, 4096),
|
||||
escapeParser: NewEscapeParser(),
|
||||
lastWrite: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -106,22 +108,24 @@ func (w *StreamWriter) writeEvent(eventType EventType, data []byte) error {
|
|||
return fmt.Errorf("stream writer closed")
|
||||
}
|
||||
|
||||
w.buffer = append(w.buffer, data...)
|
||||
w.lastWrite = time.Now()
|
||||
|
||||
completeData, remaining := extractCompleteUTF8(w.buffer)
|
||||
// Use escape parser to ensure escape sequences are not split
|
||||
processedData, remaining := w.escapeParser.ProcessData(data)
|
||||
|
||||
// Update buffer with any remaining incomplete sequences
|
||||
w.buffer = remaining
|
||||
|
||||
if len(completeData) == 0 {
|
||||
// If we have incomplete UTF-8 data, set up a timer to flush it after a short delay
|
||||
if len(w.buffer) > 0 {
|
||||
if len(processedData) == 0 {
|
||||
// If we have incomplete data, set up a timer to flush it after a short delay
|
||||
if len(w.buffer) > 0 || w.escapeParser.BufferSize() > 0 {
|
||||
w.scheduleFlush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
elapsed := time.Since(w.startTime).Seconds()
|
||||
event := []interface{}{elapsed, string(eventType), string(completeData)}
|
||||
event := []interface{}{elapsed, string(eventType), string(processedData)}
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
|
@ -151,13 +155,25 @@ func (w *StreamWriter) scheduleFlush() {
|
|||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
if w.closed || len(w.buffer) == 0 {
|
||||
if w.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// Force flush incomplete UTF-8 data for real-time streaming
|
||||
// Flush any buffered data from escape parser
|
||||
flushedData := w.escapeParser.Flush()
|
||||
if len(flushedData) == 0 && len(w.buffer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Combine flushed data with any remaining buffer
|
||||
dataToWrite := append(flushedData, w.buffer...)
|
||||
if len(dataToWrite) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Force flush incomplete data for real-time streaming
|
||||
elapsed := time.Since(w.startTime).Seconds()
|
||||
event := []interface{}{elapsed, string(EventOutput), string(w.buffer)}
|
||||
event := []interface{}{elapsed, string(EventOutput), string(dataToWrite)}
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
|
@ -218,9 +234,13 @@ func (w *StreamWriter) Close() error {
|
|||
w.syncTimer.Stop()
|
||||
}
|
||||
|
||||
if len(w.buffer) > 0 {
|
||||
// Flush any remaining data from escape parser
|
||||
flushedData := w.escapeParser.Flush()
|
||||
finalData := append(flushedData, w.buffer...)
|
||||
|
||||
if len(finalData) > 0 {
|
||||
elapsed := time.Since(w.startTime).Seconds()
|
||||
event := []interface{}{elapsed, string(EventOutput), string(w.buffer)}
|
||||
event := []interface{}{elapsed, string(EventOutput), string(finalData)}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if _, err := fmt.Fprintf(w.writer, "%s\n", eventData); err != nil {
|
||||
// Write failed during close - log to stderr to avoid deadlock
|
||||
|
|
|
|||
244
linux/pkg/protocol/escape_parser.go
Normal file
244
linux/pkg/protocol/escape_parser.go
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// EscapeParser handles parsing of terminal escape sequences and UTF-8 data
|
||||
// This ensures escape sequences are not split across chunks
|
||||
type EscapeParser struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// NewEscapeParser creates a new escape sequence parser
|
||||
func NewEscapeParser() *EscapeParser {
|
||||
return &EscapeParser{
|
||||
buffer: make([]byte, 0, 4096),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessData processes terminal data ensuring escape sequences and UTF-8 are not split
|
||||
// Returns processed data and any remaining incomplete sequences
|
||||
func (p *EscapeParser) ProcessData(data []byte) (processed []byte, remaining []byte) {
|
||||
// Combine buffered data with new data
|
||||
combined := append(p.buffer, data...)
|
||||
p.buffer = p.buffer[:0] // Clear buffer without reallocating
|
||||
|
||||
result := make([]byte, 0, len(combined))
|
||||
pos := 0
|
||||
|
||||
for pos < len(combined) {
|
||||
// Check for escape sequence
|
||||
if combined[pos] == 0x1b { // ESC character
|
||||
seqEnd := p.findEscapeSequenceEnd(combined[pos:])
|
||||
if seqEnd == -1 {
|
||||
// Incomplete escape sequence, save for next time
|
||||
p.buffer = append(p.buffer, combined[pos:]...)
|
||||
break
|
||||
}
|
||||
// Include complete escape sequence
|
||||
result = append(result, combined[pos:pos+seqEnd]...)
|
||||
pos += seqEnd
|
||||
continue
|
||||
}
|
||||
|
||||
// Process UTF-8 character
|
||||
r, size := utf8.DecodeRune(combined[pos:])
|
||||
if r == utf8.RuneError {
|
||||
if size == 0 {
|
||||
// No more data
|
||||
break
|
||||
}
|
||||
if size == 1 && pos+4 > len(combined) {
|
||||
// Might be incomplete UTF-8 at end of buffer
|
||||
if p.mightBeIncompleteUTF8(combined[pos:]) {
|
||||
p.buffer = append(p.buffer, combined[pos:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Invalid UTF-8, skip byte
|
||||
result = append(result, combined[pos])
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid UTF-8 character
|
||||
result = append(result, combined[pos:pos+size]...)
|
||||
pos += size
|
||||
}
|
||||
|
||||
return result, p.buffer
|
||||
}
|
||||
|
||||
// findEscapeSequenceEnd finds the end of an ANSI escape sequence
|
||||
// Returns -1 if sequence is incomplete
|
||||
func (p *EscapeParser) findEscapeSequenceEnd(data []byte) int {
|
||||
if len(data) == 0 || data[0] != 0x1b {
|
||||
return -1
|
||||
}
|
||||
|
||||
if len(data) < 2 {
|
||||
return -1 // Need more data
|
||||
}
|
||||
|
||||
switch data[1] {
|
||||
case '[': // CSI sequence: ESC [ ... final_char
|
||||
pos := 2
|
||||
for pos < len(data) {
|
||||
b := data[pos]
|
||||
if b >= 0x20 && b <= 0x3f {
|
||||
// Parameter and intermediate characters
|
||||
pos++
|
||||
} else if b >= 0x40 && b <= 0x7e {
|
||||
// Final character found
|
||||
return pos + 1
|
||||
} else {
|
||||
// Invalid sequence
|
||||
return pos
|
||||
}
|
||||
}
|
||||
return -1 // Incomplete
|
||||
|
||||
case ']': // OSC sequence: ESC ] ... (ST or BEL)
|
||||
pos := 2
|
||||
for pos < len(data) {
|
||||
if data[pos] == 0x07 { // BEL terminator
|
||||
return pos + 1
|
||||
}
|
||||
if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' {
|
||||
// ESC \ (ST) terminator
|
||||
return pos + 2
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return -1 // Incomplete
|
||||
|
||||
case '(', ')', '*', '+': // Charset selection
|
||||
if len(data) < 3 {
|
||||
return -1
|
||||
}
|
||||
return 3
|
||||
|
||||
case 'P', 'X', '^', '_': // DCS, SOS, PM, APC sequences
|
||||
// These need special termination sequences
|
||||
pos := 2
|
||||
for pos < len(data) {
|
||||
if data[pos] == 0x1b && pos+1 < len(data) && data[pos+1] == '\\' {
|
||||
// ESC \ (ST) terminator
|
||||
return pos + 2
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return -1 // Incomplete
|
||||
|
||||
default:
|
||||
// Simple two-character sequences
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// mightBeIncompleteUTF8 checks if data might be an incomplete UTF-8 sequence
|
||||
func (p *EscapeParser) mightBeIncompleteUTF8(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
b := data[0]
|
||||
|
||||
// Single byte (ASCII)
|
||||
if b < 0x80 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Multi-byte sequence starters
|
||||
if b >= 0xc0 {
|
||||
if b < 0xe0 {
|
||||
// 2-byte sequence
|
||||
return len(data) < 2
|
||||
}
|
||||
if b < 0xf0 {
|
||||
// 3-byte sequence
|
||||
return len(data) < 3
|
||||
}
|
||||
if b < 0xf8 {
|
||||
// 4-byte sequence
|
||||
return len(data) < 4
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Flush returns any buffered data (for use when closing)
|
||||
func (p *EscapeParser) Flush() []byte {
|
||||
if len(p.buffer) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Return buffered data as-is when flushing
|
||||
result := make([]byte, len(p.buffer))
|
||||
copy(result, p.buffer)
|
||||
p.buffer = p.buffer[:0]
|
||||
return result
|
||||
}
|
||||
|
||||
// Reset clears the parser state
|
||||
func (p *EscapeParser) Reset() {
|
||||
p.buffer = p.buffer[:0]
|
||||
}
|
||||
|
||||
// BufferSize returns the current buffer size
|
||||
func (p *EscapeParser) BufferSize() int {
|
||||
return len(p.buffer)
|
||||
}
|
||||
|
||||
// SplitEscapeSequences splits data at escape sequence boundaries
|
||||
// This is useful for processing data in chunks without splitting sequences
|
||||
func SplitEscapeSequences(data []byte) [][]byte {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var chunks [][]byte
|
||||
parser := NewEscapeParser()
|
||||
|
||||
processed, remaining := parser.ProcessData(data)
|
||||
if len(processed) > 0 {
|
||||
chunks = append(chunks, processed)
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
chunks = append(chunks, remaining)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// IsCompleteEscapeSequence checks if data contains a complete escape sequence
|
||||
func IsCompleteEscapeSequence(data []byte) bool {
|
||||
if len(data) == 0 || data[0] != 0x1b {
|
||||
return false
|
||||
}
|
||||
parser := NewEscapeParser()
|
||||
end := parser.findEscapeSequenceEnd(data)
|
||||
return end > 0 && end == len(data)
|
||||
}
|
||||
|
||||
// StripEscapeSequences removes all ANSI escape sequences from data
|
||||
func StripEscapeSequences(data []byte) []byte {
|
||||
result := make([]byte, 0, len(data))
|
||||
pos := 0
|
||||
|
||||
parser := NewEscapeParser()
|
||||
for pos < len(data) {
|
||||
if data[pos] == 0x1b {
|
||||
seqEnd := parser.findEscapeSequenceEnd(data[pos:])
|
||||
if seqEnd > 0 {
|
||||
pos += seqEnd
|
||||
continue
|
||||
}
|
||||
}
|
||||
result = append(result, data[pos])
|
||||
pos++
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
165
linux/pkg/session/errors.go
Normal file
165
linux/pkg/session/errors.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ErrorCode represents standardized error codes matching Node.js implementation
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// Session-related errors
|
||||
ErrSessionNotFound ErrorCode = "SESSION_NOT_FOUND"
|
||||
ErrSessionAlreadyExists ErrorCode = "SESSION_ALREADY_EXISTS"
|
||||
ErrSessionStartFailed ErrorCode = "SESSION_START_FAILED"
|
||||
ErrSessionNotRunning ErrorCode = "SESSION_NOT_RUNNING"
|
||||
|
||||
// Process-related errors
|
||||
ErrProcessNotFound ErrorCode = "PROCESS_NOT_FOUND"
|
||||
ErrProcessSignalFailed ErrorCode = "PROCESS_SIGNAL_FAILED"
|
||||
ErrProcessTerminateFailed ErrorCode = "PROCESS_TERMINATE_FAILED"
|
||||
|
||||
// I/O related errors
|
||||
ErrStdinNotFound ErrorCode = "STDIN_NOT_FOUND"
|
||||
ErrStdinWriteFailed ErrorCode = "STDIN_WRITE_FAILED"
|
||||
ErrStreamReadFailed ErrorCode = "STREAM_READ_FAILED"
|
||||
ErrStreamWriteFailed ErrorCode = "STREAM_WRITE_FAILED"
|
||||
|
||||
// PTY-related errors
|
||||
ErrPTYCreationFailed ErrorCode = "PTY_CREATION_FAILED"
|
||||
ErrPTYConfigFailed ErrorCode = "PTY_CONFIG_FAILED"
|
||||
ErrPTYResizeFailed ErrorCode = "PTY_RESIZE_FAILED"
|
||||
|
||||
// Control-related errors
|
||||
ErrControlPathNotFound ErrorCode = "CONTROL_PATH_NOT_FOUND"
|
||||
ErrControlFileCorrupted ErrorCode = "CONTROL_FILE_CORRUPTED"
|
||||
|
||||
// Input-related errors
|
||||
ErrUnknownKey ErrorCode = "UNKNOWN_KEY"
|
||||
ErrInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
|
||||
// General errors
|
||||
ErrInvalidArgument ErrorCode = "INVALID_ARGUMENT"
|
||||
ErrPermissionDenied ErrorCode = "PERMISSION_DENIED"
|
||||
ErrTimeout ErrorCode = "TIMEOUT"
|
||||
ErrInternal ErrorCode = "INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
// SessionError represents an error with context, matching Node.js PtyError
|
||||
type SessionError struct {
|
||||
Message string
|
||||
Code ErrorCode
|
||||
SessionID string
|
||||
Cause error
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *SessionError) Error() string {
|
||||
if e.SessionID != "" {
|
||||
return fmt.Sprintf("%s (session: %s, code: %s)", e.Message, e.SessionID[:8], e.Code)
|
||||
}
|
||||
return fmt.Sprintf("%s (code: %s)", e.Message, e.Code)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying cause
|
||||
func (e *SessionError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// NewSessionError creates a new SessionError
|
||||
func NewSessionError(message string, code ErrorCode, sessionID string) *SessionError {
|
||||
return &SessionError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSessionErrorWithCause creates a new SessionError with an underlying cause
|
||||
func NewSessionErrorWithCause(message string, code ErrorCode, sessionID string, cause error) *SessionError {
|
||||
return &SessionError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
SessionID: sessionID,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// WrapError wraps an existing error with session context
|
||||
func WrapError(err error, code ErrorCode, sessionID string) *SessionError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If it's already a SessionError, preserve the original but add context
|
||||
if se, ok := err.(*SessionError); ok {
|
||||
return &SessionError{
|
||||
Message: se.Message,
|
||||
Code: code,
|
||||
SessionID: sessionID,
|
||||
Cause: se,
|
||||
}
|
||||
}
|
||||
|
||||
return &SessionError{
|
||||
Message: err.Error(),
|
||||
Code: code,
|
||||
SessionID: sessionID,
|
||||
Cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
// IsSessionError checks if an error is a SessionError with a specific code
|
||||
func IsSessionError(err error, code ErrorCode) bool {
|
||||
se, ok := err.(*SessionError)
|
||||
return ok && se.Code == code
|
||||
}
|
||||
|
||||
// GetSessionID extracts the session ID from an error if it's a SessionError
|
||||
func GetSessionID(err error) string {
|
||||
if se, ok := err.(*SessionError); ok {
|
||||
return se.SessionID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Common error constructors for convenience
|
||||
|
||||
// ErrSessionNotFoundError creates a session not found error
|
||||
func ErrSessionNotFoundError(sessionID string) *SessionError {
|
||||
return NewSessionError(
|
||||
fmt.Sprintf("Session %s not found", sessionID[:8]),
|
||||
ErrSessionNotFound,
|
||||
sessionID,
|
||||
)
|
||||
}
|
||||
|
||||
// ErrProcessSignalError creates a process signal error
|
||||
func ErrProcessSignalError(sessionID string, signal string, cause error) *SessionError {
|
||||
return NewSessionErrorWithCause(
|
||||
fmt.Sprintf("Failed to send signal %s to session", signal),
|
||||
ErrProcessSignalFailed,
|
||||
sessionID,
|
||||
cause,
|
||||
)
|
||||
}
|
||||
|
||||
// ErrPTYCreationError creates a PTY creation error
|
||||
func ErrPTYCreationError(sessionID string, cause error) *SessionError {
|
||||
return NewSessionErrorWithCause(
|
||||
"Failed to create PTY",
|
||||
ErrPTYCreationFailed,
|
||||
sessionID,
|
||||
cause,
|
||||
)
|
||||
}
|
||||
|
||||
// ErrStdinWriteError creates a stdin write error
|
||||
func ErrStdinWriteError(sessionID string, cause error) *SessionError {
|
||||
return NewSessionErrorWithCause(
|
||||
"Failed to write to stdin",
|
||||
ErrStdinWriteFailed,
|
||||
sessionID,
|
||||
cause,
|
||||
)
|
||||
}
|
||||
157
linux/pkg/session/process.go
Normal file
157
linux/pkg/session/process.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProcessTerminator provides graceful process termination with timeout
|
||||
// Matches the Node.js implementation behavior
|
||||
type ProcessTerminator struct {
|
||||
session *Session
|
||||
gracefulTimeout time.Duration
|
||||
checkInterval time.Duration
|
||||
}
|
||||
|
||||
// NewProcessTerminator creates a new process terminator
|
||||
func NewProcessTerminator(session *Session) *ProcessTerminator {
|
||||
return &ProcessTerminator{
|
||||
session: session,
|
||||
gracefulTimeout: 3 * time.Second, // Match Node.js 3 second timeout
|
||||
checkInterval: 500 * time.Millisecond, // Match Node.js 500ms check interval
|
||||
}
|
||||
}
|
||||
|
||||
// TerminateGracefully attempts graceful termination with escalation to SIGKILL
|
||||
// This matches the Node.js implementation behavior:
|
||||
// 1. Send SIGTERM
|
||||
// 2. Wait up to 3 seconds for graceful termination
|
||||
// 3. Send SIGKILL if process is still alive
|
||||
func (pt *ProcessTerminator) TerminateGracefully() error {
|
||||
sessionID := pt.session.ID[:8]
|
||||
pid := pt.session.info.Pid
|
||||
|
||||
// Check if already exited
|
||||
if pt.session.info.Status == string(StatusExited) {
|
||||
debugLog("[DEBUG] ProcessTerminator: Session %s already exited", sessionID)
|
||||
pt.session.cleanup()
|
||||
return nil
|
||||
}
|
||||
|
||||
if pid == 0 {
|
||||
return NewSessionError("no process to terminate", ErrProcessNotFound, pt.session.ID)
|
||||
}
|
||||
|
||||
log.Printf("[INFO] Terminating session %s (PID: %d) with SIGTERM...", sessionID, pid)
|
||||
|
||||
// Send SIGTERM first
|
||||
if err := pt.session.Signal("SIGTERM"); err != nil {
|
||||
// If process doesn't exist, that's fine
|
||||
if !pt.session.IsAlive() {
|
||||
log.Printf("[INFO] Session %s already terminated", sessionID)
|
||||
pt.session.cleanup()
|
||||
return nil
|
||||
}
|
||||
// If it's already a SessionError, return as-is
|
||||
if se, ok := err.(*SessionError); ok {
|
||||
return se
|
||||
}
|
||||
return NewSessionErrorWithCause("failed to send SIGTERM", ErrProcessTerminateFailed, pt.session.ID, err)
|
||||
}
|
||||
|
||||
// Wait for graceful termination
|
||||
startTime := time.Now()
|
||||
checkCount := 0
|
||||
maxChecks := int(pt.gracefulTimeout / pt.checkInterval)
|
||||
|
||||
for checkCount < maxChecks {
|
||||
// Wait for check interval
|
||||
time.Sleep(pt.checkInterval)
|
||||
checkCount++
|
||||
|
||||
// Check if process is still alive
|
||||
if !pt.session.IsAlive() {
|
||||
elapsed := time.Since(startTime)
|
||||
log.Printf("[INFO] Session %s terminated gracefully after %dms", sessionID, elapsed.Milliseconds())
|
||||
pt.session.cleanup()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Log progress
|
||||
elapsed := time.Since(startTime)
|
||||
log.Printf("[INFO] Session %s still alive after %dms...", sessionID, elapsed.Milliseconds())
|
||||
}
|
||||
|
||||
// Process didn't terminate gracefully, force kill
|
||||
log.Printf("[INFO] Session %s didn't terminate gracefully, sending SIGKILL...", sessionID)
|
||||
|
||||
if err := pt.session.Signal("SIGKILL"); err != nil {
|
||||
// If process doesn't exist anymore, that's fine
|
||||
if !pt.session.IsAlive() {
|
||||
log.Printf("[INFO] Session %s terminated before SIGKILL", sessionID)
|
||||
pt.session.cleanup()
|
||||
return nil
|
||||
}
|
||||
// If it's already a SessionError, return as-is
|
||||
if se, ok := err.(*SessionError); ok {
|
||||
return se
|
||||
}
|
||||
return NewSessionErrorWithCause("failed to send SIGKILL", ErrProcessTerminateFailed, pt.session.ID, err)
|
||||
}
|
||||
|
||||
// Wait a bit for SIGKILL to take effect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if pt.session.IsAlive() {
|
||||
log.Printf("[WARN] Session %s may still be alive after SIGKILL", sessionID)
|
||||
} else {
|
||||
log.Printf("[INFO] Session %s forcefully terminated with SIGKILL", sessionID)
|
||||
}
|
||||
|
||||
pt.session.cleanup()
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForProcessExit waits for a process to exit with timeout
|
||||
// Returns true if process exited within timeout, false otherwise
|
||||
func waitForProcessExit(pid int, timeout time.Duration) bool {
|
||||
startTime := time.Now()
|
||||
checkInterval := 100 * time.Millisecond
|
||||
|
||||
for time.Since(startTime) < timeout {
|
||||
// Try to find the process
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
// Process doesn't exist
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if process is alive using signal 0
|
||||
if err := proc.Signal(os.Signal(nil)); err != nil {
|
||||
// Process doesn't exist or we don't have permission
|
||||
return true
|
||||
}
|
||||
|
||||
time.Sleep(checkInterval)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isProcessRunning checks if a process is running by PID
|
||||
// Uses platform-appropriate methods
|
||||
func isProcessRunning(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// On Unix, signal 0 checks if process exists
|
||||
err = proc.Signal(os.Signal(nil))
|
||||
return err == nil
|
||||
}
|
||||
|
|
@ -22,13 +22,14 @@ import (
|
|||
const useSelectPolling = true
|
||||
|
||||
type PTY struct {
|
||||
session *Session
|
||||
cmd *exec.Cmd
|
||||
pty *os.File
|
||||
oldState *term.State
|
||||
streamWriter *protocol.StreamWriter
|
||||
stdinPipe *os.File
|
||||
resizeMutex sync.Mutex
|
||||
session *Session
|
||||
cmd *exec.Cmd
|
||||
pty *os.File
|
||||
oldState *term.State
|
||||
streamWriter *protocol.StreamWriter
|
||||
stdinPipe *os.File
|
||||
useEventDrivenStdin bool
|
||||
resizeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewPTY(session *Session) (*PTY, error) {
|
||||
|
|
@ -53,7 +54,12 @@ func NewPTY(session *Session) (*PTY, error) {
|
|||
// Verify the directory exists and is accessible
|
||||
if _, err := os.Stat(session.info.Cwd); err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Working directory '%s' not accessible: %v", session.info.Cwd, err)
|
||||
return nil, fmt.Errorf("working directory '%s' not accessible: %w", session.info.Cwd, err)
|
||||
return nil, NewSessionErrorWithCause(
|
||||
fmt.Sprintf("working directory '%s' not accessible", session.info.Cwd),
|
||||
ErrInvalidArgument,
|
||||
session.ID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
cmd.Dir = session.info.Cwd
|
||||
debugLog("[DEBUG] NewPTY: Set working directory to: %s", session.info.Cwd)
|
||||
|
|
@ -101,7 +107,7 @@ func NewPTY(session *Session) (*PTY, error) {
|
|||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Failed to start PTY: %v", err)
|
||||
return nil, fmt.Errorf("failed to start PTY: %w", err)
|
||||
return nil, ErrPTYCreationError(session.ID, err)
|
||||
}
|
||||
|
||||
debugLog("[DEBUG] NewPTY: PTY started successfully, PID: %d", cmd.Process.Pid)
|
||||
|
|
@ -110,10 +116,15 @@ func NewPTY(session *Session) (*PTY, error) {
|
|||
debugLog("[DEBUG] NewPTY: Executing command: %v in directory: %s", cmdline, cmd.Dir)
|
||||
debugLog("[DEBUG] NewPTY: Environment has %d variables", len(cmd.Env))
|
||||
|
||||
if err := pty.Setsize(ptmx, &pty.Winsize{
|
||||
Rows: uint16(session.info.Height),
|
||||
Cols: uint16(session.info.Width),
|
||||
}); err != nil {
|
||||
// Configure terminal attributes to match node-pty behavior
|
||||
// This must be done before setting size and after the process starts
|
||||
if err := configurePTYTerminal(ptmx); err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Failed to configure PTY terminal: %v", err)
|
||||
// Don't fail on terminal configuration errors, just log them
|
||||
}
|
||||
|
||||
// Set PTY size using our enhanced function
|
||||
if err := setPTYSize(ptmx, uint16(session.info.Width), uint16(session.info.Height)); err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Failed to set PTY size: %v", err)
|
||||
if err := ptmx.Close(); err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Failed to close PTY: %v", err)
|
||||
|
|
@ -121,13 +132,15 @@ func NewPTY(session *Session) (*PTY, error) {
|
|||
if err := cmd.Process.Kill(); err != nil {
|
||||
log.Printf("[ERROR] NewPTY: Failed to kill process: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to set PTY size: %w", err)
|
||||
return nil, NewSessionErrorWithCause(
|
||||
"failed to set PTY size",
|
||||
ErrPTYResizeFailed,
|
||||
session.ID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// Configure terminal modes for proper interactive shell behavior
|
||||
// The creack/pty library handles basic setup, but we ensure the terminal
|
||||
// is in the correct mode for interactive use (not raw mode)
|
||||
debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode")
|
||||
debugLog("[DEBUG] NewPTY: Terminal configured for interactive mode with flow control")
|
||||
|
||||
streamOut, err := os.Create(session.StreamOutPath())
|
||||
if err != nil {
|
||||
|
|
@ -209,19 +222,32 @@ func (p *PTY) Run() error {
|
|||
|
||||
debugLog("[DEBUG] PTY.Run: Starting PTY run for session %s, PID %d", p.session.ID[:8], p.cmd.Process.Pid)
|
||||
|
||||
stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0)
|
||||
// Use event-driven stdin handling like Node.js
|
||||
stdinWatcher, err := NewStdinWatcher(p.session.StdinPath(), p.pty)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err)
|
||||
return fmt.Errorf("failed to open stdin pipe: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err)
|
||||
// Fall back to polling if watcher fails
|
||||
log.Printf("[WARN] PTY.Run: Failed to create stdin watcher, falling back to polling: %v", err)
|
||||
|
||||
stdinPipe, err := os.OpenFile(p.session.StdinPath(), os.O_RDONLY|syscall.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to open stdin pipe: %v", err)
|
||||
return fmt.Errorf("failed to open stdin pipe: %w", err)
|
||||
}
|
||||
}()
|
||||
p.stdinPipe = stdinPipe
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to close stdin pipe: %v", err)
|
||||
}
|
||||
}()
|
||||
p.stdinPipe = stdinPipe
|
||||
} else {
|
||||
// Start the watcher
|
||||
stdinWatcher.Start()
|
||||
defer stdinWatcher.Stop()
|
||||
p.useEventDrivenStdin = true
|
||||
debugLog("[DEBUG] PTY.Run: Using event-driven stdin handling")
|
||||
}
|
||||
|
||||
debugLog("[DEBUG] PTY.Run: Stdin pipe opened successfully")
|
||||
debugLog("[DEBUG] PTY.Run: Stdin handling initialized")
|
||||
|
||||
// Set up SIGWINCH handling for terminal resize
|
||||
winchCh := make(chan os.Signal, 1)
|
||||
|
|
@ -236,10 +262,7 @@ func (p *PTY) Run() error {
|
|||
width, height, err := term.GetSize(int(os.Stdin.Fd()))
|
||||
if err == nil {
|
||||
debugLog("[DEBUG] PTY.Run: Received SIGWINCH, resizing to %dx%d", width, height)
|
||||
if err := pty.Setsize(p.pty, &pty.Winsize{
|
||||
Rows: uint16(height),
|
||||
Cols: uint16(width),
|
||||
}); err != nil {
|
||||
if err := setPTYSize(p.pty, uint16(width), uint16(height)); err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to resize PTY: %v", err)
|
||||
} else {
|
||||
// Update session info
|
||||
|
|
@ -301,45 +324,48 @@ func (p *PTY) Run() error {
|
|||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine")
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := stdinPipe.Read(buf)
|
||||
if n > 0 {
|
||||
debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n)
|
||||
if _, err := p.pty.Write(buf[:n]); err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err)
|
||||
// Only exit if the PTY is really broken, not on temporary errors
|
||||
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||||
errCh <- fmt.Errorf("failed to write to PTY: %w", err)
|
||||
return
|
||||
// Only start stdin goroutine if not using event-driven mode
|
||||
if !p.useEventDrivenStdin && p.stdinPipe != nil {
|
||||
go func() {
|
||||
debugLog("[DEBUG] PTY.Run: Starting stdin reading goroutine")
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := p.stdinPipe.Read(buf)
|
||||
if n > 0 {
|
||||
debugLog("[DEBUG] PTY.Run: Read %d bytes from stdin, writing to PTY", n)
|
||||
if _, err := p.pty.Write(buf[:n]); err != nil {
|
||||
log.Printf("[ERROR] PTY.Run: Failed to write to PTY: %v", err)
|
||||
// Only exit if the PTY is really broken, not on temporary errors
|
||||
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||||
errCh <- fmt.Errorf("failed to write to PTY: %w", err)
|
||||
return
|
||||
}
|
||||
// For broken pipe, just continue - the PTY might be closing
|
||||
debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
// For broken pipe, just continue - the PTY might be closing
|
||||
debugLog("[DEBUG] PTY.Run: PTY write failed with pipe error, continuing...")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// Continue immediately after successful write
|
||||
continue
|
||||
}
|
||||
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
|
||||
// No data available, longer pause to prevent excessive CPU usage
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if err == io.EOF {
|
||||
// No writers to the FIFO yet, longer pause before retry
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
// Log other errors but don't crash the session - stdin issues shouldn't kill the PTY
|
||||
log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
// Continue immediately after successful write
|
||||
continue
|
||||
}
|
||||
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
|
||||
// No data available, longer pause to prevent excessive CPU usage
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if err == io.EOF {
|
||||
// No writers to the FIFO yet, longer pause before retry
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
// Log other errors but don't crash the session - stdin issues shouldn't kill the PTY
|
||||
log.Printf("[WARN] PTY.Run: Stdin read error (non-fatal): %v", err)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
debugLog("[DEBUG] PTY.Run: Starting process wait goroutine for PID %d", p.cmd.Process.Pid)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,12 @@ func (p *PTY) pollWithSelect() error {
|
|||
|
||||
// Get file descriptors
|
||||
ptyFd := int(p.pty.Fd())
|
||||
stdinFd := int(p.stdinPipe.Fd())
|
||||
var stdinFd int = -1
|
||||
|
||||
// Only include stdin in polling if not using event-driven mode
|
||||
if !p.useEventDrivenStdin && p.stdinPipe != nil {
|
||||
stdinFd = int(p.stdinPipe.Fd())
|
||||
}
|
||||
|
||||
// Open control FIFO in non-blocking mode
|
||||
controlPath := filepath.Join(p.session.Path(), "control")
|
||||
|
|
@ -93,7 +98,10 @@ func (p *PTY) pollWithSelect() error {
|
|||
|
||||
for {
|
||||
// Build FD list
|
||||
fds := []int{ptyFd, stdinFd}
|
||||
fds := []int{ptyFd}
|
||||
if stdinFd >= 0 {
|
||||
fds = append(fds, stdinFd)
|
||||
}
|
||||
if controlFd >= 0 {
|
||||
fds = append(fds, controlFd)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -353,7 +353,7 @@ func (s *Session) proxyInputToNodeJS(data []byte) error {
|
|||
|
||||
func (s *Session) Signal(sig string) error {
|
||||
if s.info.Pid == 0 {
|
||||
return fmt.Errorf("no process to signal")
|
||||
return NewSessionError("no process to signal", ErrProcessNotFound, s.ID)
|
||||
}
|
||||
|
||||
// Check if process is still alive before signaling
|
||||
|
|
@ -370,21 +370,27 @@ func (s *Session) Signal(sig string) error {
|
|||
|
||||
proc, err := os.FindProcess(s.info.Pid)
|
||||
if err != nil {
|
||||
return err
|
||||
return ErrProcessSignalError(s.ID, sig, err)
|
||||
}
|
||||
|
||||
switch sig {
|
||||
case "SIGTERM":
|
||||
return proc.Signal(os.Interrupt)
|
||||
if err := proc.Signal(os.Interrupt); err != nil {
|
||||
return ErrProcessSignalError(s.ID, sig, err)
|
||||
}
|
||||
return nil
|
||||
case "SIGKILL":
|
||||
err = proc.Kill()
|
||||
// If kill fails with "process already finished", that's okay
|
||||
if err != nil && strings.Contains(err.Error(), "process already finished") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err != nil {
|
||||
return ErrProcessSignalError(s.ID, sig, err)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported signal: %s", sig)
|
||||
return NewSessionError(fmt.Sprintf("unsupported signal: %s", sig), ErrInvalidArgument, s.ID)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -393,24 +399,29 @@ func (s *Session) Stop() error {
|
|||
}
|
||||
|
||||
func (s *Session) Kill() error {
|
||||
// First check if the session is already dead
|
||||
if s.info.Status == string(StatusExited) {
|
||||
// Already exited, just cleanup and return success
|
||||
// Use graceful termination like Node.js
|
||||
terminator := NewProcessTerminator(s)
|
||||
return terminator.TerminateGracefully()
|
||||
}
|
||||
|
||||
// KillWithSignal kills the session with the specified signal
|
||||
// If signal is SIGKILL, it sends it immediately without graceful termination
|
||||
func (s *Session) KillWithSignal(signal string) error {
|
||||
// If SIGKILL is explicitly requested, send it immediately
|
||||
if signal == "SIGKILL" || signal == "9" {
|
||||
err := s.Signal("SIGKILL")
|
||||
s.cleanup()
|
||||
return nil
|
||||
|
||||
// If the error is because the process doesn't exist, that's fine
|
||||
if err != nil && (strings.Contains(err.Error(), "no such process") ||
|
||||
strings.Contains(err.Error(), "process already finished")) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to kill the process
|
||||
err := s.Signal("SIGKILL")
|
||||
s.cleanup()
|
||||
|
||||
// If the error is because the process doesn't exist, that's fine
|
||||
if err != nil && (strings.Contains(err.Error(), "no such process") ||
|
||||
strings.Contains(err.Error(), "process already finished")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
// For other signals, use graceful termination
|
||||
return s.Kill()
|
||||
}
|
||||
|
||||
func (s *Session) cleanup() {
|
||||
|
|
@ -427,17 +438,21 @@ func (s *Session) cleanup() {
|
|||
|
||||
func (s *Session) Resize(width, height int) error {
|
||||
if s.pty == nil {
|
||||
return fmt.Errorf("session not started")
|
||||
return NewSessionError("session not started", ErrSessionNotRunning, s.ID)
|
||||
}
|
||||
|
||||
// Check if session is still alive
|
||||
if s.info.Status == string(StatusExited) {
|
||||
return fmt.Errorf("cannot resize exited session")
|
||||
return NewSessionError("cannot resize exited session", ErrSessionNotRunning, s.ID)
|
||||
}
|
||||
|
||||
// Validate dimensions
|
||||
if width <= 0 || height <= 0 {
|
||||
return fmt.Errorf("invalid dimensions: width=%d, height=%d", width, height)
|
||||
return NewSessionError(
|
||||
fmt.Sprintf("invalid dimensions: width=%d, height=%d", width, height),
|
||||
ErrInvalidArgument,
|
||||
s.ID,
|
||||
)
|
||||
}
|
||||
|
||||
// Update session info
|
||||
|
|
|
|||
153
linux/pkg/session/stdin_watcher.go
Normal file
153
linux/pkg/session/stdin_watcher.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// StdinWatcher provides event-driven stdin handling like Node.js
|
||||
type StdinWatcher struct {
|
||||
stdinPath string
|
||||
ptyFile *os.File
|
||||
watcher *fsnotify.Watcher
|
||||
stdinFile *os.File
|
||||
buffer []byte
|
||||
mu sync.Mutex
|
||||
stopChan chan struct{}
|
||||
stoppedChan chan struct{}
|
||||
}
|
||||
|
||||
// NewStdinWatcher creates a new stdin watcher
|
||||
func NewStdinWatcher(stdinPath string, ptyFile *os.File) (*StdinWatcher, error) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create fsnotify watcher: %w", err)
|
||||
}
|
||||
|
||||
sw := &StdinWatcher{
|
||||
stdinPath: stdinPath,
|
||||
ptyFile: ptyFile,
|
||||
watcher: watcher,
|
||||
buffer: make([]byte, 4096),
|
||||
stopChan: make(chan struct{}),
|
||||
stoppedChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Open stdin pipe for reading
|
||||
stdinFile, err := os.OpenFile(stdinPath, os.O_RDONLY|syscall.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
watcher.Close()
|
||||
return nil, fmt.Errorf("failed to open stdin pipe: %w", err)
|
||||
}
|
||||
sw.stdinFile = stdinFile
|
||||
|
||||
// Add stdin path to watcher
|
||||
if err := watcher.Add(stdinPath); err != nil {
|
||||
stdinFile.Close()
|
||||
watcher.Close()
|
||||
return nil, fmt.Errorf("failed to watch stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
return sw, nil
|
||||
}
|
||||
|
||||
// Start begins watching for stdin input
|
||||
func (sw *StdinWatcher) Start() {
|
||||
go sw.watchLoop()
|
||||
}
|
||||
|
||||
// Stop stops the watcher
|
||||
func (sw *StdinWatcher) Stop() {
|
||||
close(sw.stopChan)
|
||||
<-sw.stoppedChan
|
||||
sw.cleanup()
|
||||
}
|
||||
|
||||
// watchLoop is the main event loop
|
||||
func (sw *StdinWatcher) watchLoop() {
|
||||
defer close(sw.stoppedChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sw.stopChan:
|
||||
debugLog("[DEBUG] StdinWatcher: Stopping watch loop")
|
||||
return
|
||||
|
||||
case event, ok := <-sw.watcher.Events:
|
||||
if !ok {
|
||||
debugLog("[DEBUG] StdinWatcher: Watcher events channel closed")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle write events (new data available)
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
sw.handleStdinData()
|
||||
}
|
||||
|
||||
case err, ok := <-sw.watcher.Errors:
|
||||
if !ok {
|
||||
debugLog("[DEBUG] StdinWatcher: Watcher errors channel closed")
|
||||
return
|
||||
}
|
||||
log.Printf("[ERROR] StdinWatcher: Watcher error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStdinData reads available data and forwards it to the PTY
|
||||
func (sw *StdinWatcher) handleStdinData() {
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
|
||||
for {
|
||||
n, err := sw.stdinFile.Read(sw.buffer)
|
||||
if n > 0 {
|
||||
// Forward data to PTY immediately
|
||||
if _, writeErr := sw.ptyFile.Write(sw.buffer[:n]); writeErr != nil {
|
||||
log.Printf("[ERROR] StdinWatcher: Failed to write to PTY: %v", writeErr)
|
||||
return
|
||||
}
|
||||
debugLog("[DEBUG] StdinWatcher: Forwarded %d bytes to PTY", n)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF || isEAGAIN(err) {
|
||||
// No more data available right now
|
||||
break
|
||||
}
|
||||
log.Printf("[ERROR] StdinWatcher: Failed to read from stdin: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If we read a full buffer, there might be more data
|
||||
if n == len(sw.buffer) {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup releases resources
|
||||
func (sw *StdinWatcher) cleanup() {
|
||||
if sw.watcher != nil {
|
||||
sw.watcher.Close()
|
||||
}
|
||||
if sw.stdinFile != nil {
|
||||
sw.stdinFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// isEAGAIN checks if the error is EAGAIN (resource temporarily unavailable)
|
||||
func isEAGAIN(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
// Check for EAGAIN in the error string
|
||||
return err.Error() == "resource temporarily unavailable"
|
||||
}
|
||||
133
linux/pkg/session/terminal.go
Normal file
133
linux/pkg/session/terminal.go
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// TerminalMode represents terminal mode settings
|
||||
type TerminalMode struct {
|
||||
Raw bool
|
||||
Echo bool
|
||||
LineMode bool
|
||||
FlowControl bool
|
||||
}
|
||||
|
||||
// configurePTYTerminal configures the PTY terminal attributes to match node-pty behavior
|
||||
// This ensures proper terminal behavior with flow control, signal handling, and line editing
|
||||
func configurePTYTerminal(ptyFile *os.File) error {
|
||||
fd := int(ptyFile.Fd())
|
||||
|
||||
// Get current terminal attributes
|
||||
termios, err := unix.IoctlGetTermios(fd, unix.TIOCGETA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get terminal attributes: %w", err)
|
||||
}
|
||||
|
||||
// Configure input flags (similar to node-pty's handleFlowControl)
|
||||
// IXON: Enable start/stop output control (Ctrl+S/Ctrl+Q)
|
||||
// IXOFF: Enable start/stop input control
|
||||
// IXANY: Any character will restart output after stop
|
||||
// ICRNL: Map CR to NL on input
|
||||
termios.Iflag |= unix.IXON | unix.IXOFF | unix.IXANY | unix.ICRNL
|
||||
|
||||
// Configure output flags
|
||||
// OPOST: Enable output processing
|
||||
// ONLCR: Map NL to CR-NL on output
|
||||
termios.Oflag |= unix.OPOST | unix.ONLCR
|
||||
|
||||
// Configure control flags
|
||||
// CS8: 8-bit characters
|
||||
// CREAD: Enable receiver
|
||||
// HUPCL: Hang up on last close
|
||||
termios.Cflag |= unix.CS8 | unix.CREAD | unix.HUPCL
|
||||
termios.Cflag &^= unix.PARENB // Disable parity
|
||||
|
||||
// Configure local flags
|
||||
// ISIG: Enable signal generation (SIGINT on Ctrl+C, etc)
|
||||
// ICANON: Enable canonical mode (line editing)
|
||||
// ECHO: Enable echo
|
||||
// ECHOE: Echo erase character as BS-SP-BS
|
||||
// ECHOK: Echo kill character
|
||||
// ECHONL: Echo NL even if ECHO is off
|
||||
// IEXTEN: Enable extended functions
|
||||
termios.Lflag |= unix.ISIG | unix.ICANON | unix.ECHO | unix.ECHOE | unix.ECHOK | unix.ECHONL | unix.IEXTEN
|
||||
|
||||
// Set control characters
|
||||
termios.Cc[unix.VEOF] = 4 // Ctrl+D
|
||||
termios.Cc[unix.VEOL] = 0 // Additional end-of-line
|
||||
termios.Cc[unix.VERASE] = 127 // DEL
|
||||
termios.Cc[unix.VINTR] = 3 // Ctrl+C
|
||||
termios.Cc[unix.VKILL] = 21 // Ctrl+U
|
||||
termios.Cc[unix.VMIN] = 1 // Minimum characters for read
|
||||
termios.Cc[unix.VQUIT] = 28 // Ctrl+\
|
||||
termios.Cc[unix.VSTART] = 17 // Ctrl+Q
|
||||
termios.Cc[unix.VSTOP] = 19 // Ctrl+S
|
||||
termios.Cc[unix.VSUSP] = 26 // Ctrl+Z
|
||||
termios.Cc[unix.VTIME] = 0 // Timeout for read
|
||||
|
||||
// Apply the terminal attributes
|
||||
if err := unix.IoctlSetTermios(fd, unix.TIOCSETA, termios); err != nil {
|
||||
return fmt.Errorf("failed to set terminal attributes: %w", err)
|
||||
}
|
||||
|
||||
debugLog("[DEBUG] PTY terminal configured with proper flow control and signal handling")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setPTYSize sets the window size of the PTY
|
||||
func setPTYSize(ptyFile *os.File, cols, rows uint16) error {
|
||||
fd := int(ptyFile.Fd())
|
||||
|
||||
ws := &unix.Winsize{
|
||||
Row: rows,
|
||||
Col: cols,
|
||||
Xpixel: 0,
|
||||
Ypixel: 0,
|
||||
}
|
||||
|
||||
if err := unix.IoctlSetWinsize(fd, unix.TIOCSWINSZ, ws); err != nil {
|
||||
return fmt.Errorf("failed to set PTY size: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPTYSize gets the current window size of the PTY
|
||||
func getPTYSize(ptyFile *os.File) (cols, rows uint16, err error) {
|
||||
fd := int(ptyFile.Fd())
|
||||
|
||||
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to get PTY size: %w", err)
|
||||
}
|
||||
|
||||
return ws.Col, ws.Row, nil
|
||||
}
|
||||
|
||||
// sendSignalToPTY sends a signal to the PTY process group
|
||||
func sendSignalToPTY(ptyFile *os.File, signal syscall.Signal) error {
|
||||
fd := int(ptyFile.Fd())
|
||||
|
||||
// Get the process group ID of the PTY
|
||||
pgid, err := unix.IoctlGetInt(fd, unix.TIOCGPGRP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get PTY process group: %w", err)
|
||||
}
|
||||
|
||||
// Send signal to the process group
|
||||
if err := syscall.Kill(-pgid, signal); err != nil {
|
||||
return fmt.Errorf("failed to send signal to PTY process group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTerminal checks if a file descriptor is a terminal
|
||||
func isTerminal(fd int) bool {
|
||||
_, err := unix.IoctlGetTermios(fd, unix.TIOCGETA)
|
||||
return err == nil
|
||||
}
|
||||
Loading…
Reference in a new issue