Implement arithmetic expansion

This commit is contained in:
Sami Samhuri 2026-02-03 03:44:32 -08:00
parent a73974586e
commit df480e3cc1
No known key found for this signature in database
2 changed files with 193 additions and 6 deletions

View file

@ -25,7 +25,7 @@ module Shell
expand_braces(expanded)
end
.flat_map do |word|
if word =~ /[*?\[]/
if /[*?\[]/.match?(word)
glob_words = expand_globs(word)
glob_words.empty? ? [word] : glob_words
else
@ -142,8 +142,13 @@ module Shell
output << run_command_substitution(cmd)
when "$"
if line[i + 1] == "("
cmd, i = read_dollar_paren(line, i + 2)
output << run_command_substitution(cmd)
if line[i + 2] == "("
expr, i = read_arithmetic(line, i + 3)
output << expand_arithmetic(expr)
else
cmd, i = read_dollar_paren(line, i + 2)
output << run_command_substitution(cmd)
end
else
output << c
i += 1
@ -200,8 +205,13 @@ module Shell
output << run_command_substitution(cmd)
when "$"
if line[i + 1] == "("
cmd, i = read_dollar_paren(line, i + 2)
output << run_command_substitution(cmd)
if line[i + 2] == "("
expr, i = read_arithmetic(line, i + 3)
output << expand_arithmetic(expr)
else
cmd, i = read_dollar_paren(line, i + 2)
output << run_command_substitution(cmd)
end
else
output << c
i += 1
@ -269,6 +279,35 @@ module Shell
raise ArgumentError, "Unmatched $(...)"
end
def read_arithmetic(line, start_index)
output = +""
i = start_index
depth = 1
while i < line.length
c = line[i]
if c == "("
depth += 1
output << c
elsif c == ")"
depth -= 1
if depth.zero?
if line[i + 1] == ")"
return [output, i + 2]
else
depth += 1
output << c
end
else
output << c
end
else
output << c
end
i += 1
end
raise ArgumentError, "Unmatched $((...))"
end
def run_command_substitution(command)
stdout, status = Open3.capture2("/bin/sh", "-c", command)
raise Errno::ENOENT, command unless status.success?
@ -276,6 +315,151 @@ module Shell
stdout.tr("\n", " ")
end
def expand_arithmetic(expr)
tokens = tokenize_arithmetic(expr)
rpn = arithmetic_to_rpn(tokens)
evaluate_rpn(rpn).to_s
end
def tokenize_arithmetic(expr)
tokens = []
i = 0
while i < expr.length
c = expr[i]
if c.match?(/\s/)
i += 1
next
end
if c.match?(/\d/)
j = i + 1
j += 1 while j < expr.length && expr[j].match?(/\d/)
tokens << [:number, expr[i...j].to_i]
i = j
next
end
if c.match?(/[A-Za-z_]/)
j = i + 1
j += 1 while j < expr.length && expr[j].match?(/[A-Za-z0-9_]/)
name = expr[i...j]
value = ENV[name]
value = (value.nil? || value.empty?) ? 0 : value.to_i
tokens << [:number, value]
i = j
next
end
if c.match?(%r{[+\-*/()%]})
tokens << [:op, c]
i += 1
next
end
raise ArgumentError, "Invalid arithmetic expression: #{expr}"
end
tokens
end
def arithmetic_to_rpn(tokens)
output = []
ops = []
prev_type = nil
tokens.each do |type, value|
if type == :number
output << [:number, value]
prev_type = :number
next
end
op = value
if op == "("
ops << op
prev_type = :lparen
next
end
if op == ")"
while (top = ops.pop)
break if top == "("
output << [:op, top]
end
raise ArgumentError, "Unmatched ) in arithmetic expression" if top != "("
prev_type = :rparen
next
end
if op == "-" && (prev_type.nil? || prev_type == :op || prev_type == :lparen)
op = "u-"
elsif op == "+" && (prev_type.nil? || prev_type == :op || prev_type == :lparen)
op = "u+"
end
while !ops.empty? && precedence(ops.last) >= precedence(op)
output << [:op, ops.pop]
end
ops << op
prev_type = :op
end
while (top = ops.pop)
raise ArgumentError, "Unmatched ( in arithmetic expression" if top == "("
output << [:op, top]
end
output
end
def precedence(op)
case op
when "u+", "u-"
3
when "*", "/", "%"
2
when "+", "-"
1
else
0
end
end
def evaluate_rpn(rpn)
stack = []
rpn.each do |type, value|
if type == :number
stack << value
next
end
case value
when "u+"
raise ArgumentError, "Invalid arithmetic expression" if stack.empty?
stack << stack.pop
when "u-"
raise ArgumentError, "Invalid arithmetic expression" if stack.empty?
stack << -stack.pop
else
b = stack.pop
a = stack.pop
raise ArgumentError, "Invalid arithmetic expression" if a.nil? || b.nil?
stack << apply_operator(a, b, value)
end
end
raise ArgumentError, "Invalid arithmetic expression" unless stack.length == 1
stack[0]
end
def apply_operator(a, b, op)
case op
when "+"
a + b
when "-"
a - b
when "*"
a * b
when "/"
(b == 0) ? 0 : a / b
when "%"
(b == 0) ? 0 : a % b
else
raise ArgumentError, "Invalid arithmetic expression"
end
end
def expand_braces(word)
# Simple, non-nested brace expansion: pre{a,b}post -> preapost, prebpost
match = word.match(/(.*?)\{([^{}]*)\}(.*)/)

View file

@ -75,10 +75,13 @@ class ShellTest < Minitest::Test
end
def test_expands_arithmetic
skip "arithmetic expansion not implemented"
assert_equal "3", `#{A1_PATH} -c 'echo $((1 + 2))'`.chomp
end
def test_expands_arithmetic_with_variables
assert_equal "3", `A1_NUM=2 #{A1_PATH} -c 'echo $((A1_NUM + 1))'`.chomp
end
def test_expands_tilde_user
user = Etc.getlogin
skip "no login user" unless user