Pull models properly when benchmarking

This commit is contained in:
Sami Samhuri 2025-06-24 10:20:18 -04:00
parent 86a382a700
commit 7b6a1e5479
No known key found for this signature in database
2 changed files with 47 additions and 9 deletions

View file

@ -27,7 +27,7 @@ end
# Pull the benchmark model when `ollama list` does not already report it.
# NOTE(review): this span comes from a diff hunk. The two `unless` lines below
# look like the before/after versions of a single line (only one `end` follows);
# confirm which one is live in the actual file.
puts "Ensuring model is available..."
# Presumably the pre-change check: compares only the text before the first ':',
# so any tag of the same model name satisfies it.
unless `ollama list`.include?(test_model.split(':').first)
# Presumably the post-change check: substring match on the full test_model string.
unless `ollama list`.include?(test_model)
# Shells out to ollama with test_model interpolated into the command string.
system("ollama pull #{test_model}")
end

View file

@ -12,11 +12,18 @@ require 'concurrent'
class TagExtractor
OLLAMA_URL = 'http://localhost:11434/api/generate'
DEFAULT_MODELS = ['qwen2.5vl:3b', 'moondream:1.8b', 'llava:7b', 'llava:13b', 'llama3.2-vision:11b', 'llava-phi3:3.8b']
# Models benchmarked when options[:models] is not supplied, mapped to the
# thread-pool size used for each model unless options[:parallel] overrides it.
DEFAULT_MODELS = {
'qwen2.5vl:3b' => 2, # Benchmark showed a slight benefit at 2
'moondream:1.8b' => 8, # NOTE(review): this comment used to say the benchmark showed parallelism hurts this model, which does not match a value of 8 - confirm the intended setting
'llava:7b' => 2,
'llava:13b' => 2,
'llama3.2-vision:11b' => 2,
'llava-phi3:3.8b' => 2
}
VALID_EXTENSIONS = %w[.jpg .jpeg .png .gif .bmp .tiff .tif].freeze
def initialize(options = {})
@parallel = options[:parallel] || 8
@global_parallel = options[:parallel] # Global override if specified
@models = options[:models] || DEFAULT_MODELS
@timeout = options[:timeout] || 120
@verbose = options[:verbose] || false
@ -44,7 +51,6 @@ class TagExtractor
puts "#{images.length} images found"
puts "#{prompts.length} prompts loaded"
puts "#{@models.length} models to test"
puts "#{@parallel} parallel requests"
puts
total_tasks = images.length * prompts.length * @models.length
@ -56,9 +62,21 @@ class TagExtractor
master_csv << %w[model image_size prompt_name image_filename tags raw_output timestamp success]
# Process in batches by model to allow proper cleanup
@models.each_with_index do |model, model_index|
model_list = @models.is_a?(Hash) ? @models.keys : @models
model_list.each_with_index do |model, model_index|
# Determine parallelism for this model
parallel = if @global_parallel
@global_parallel # Use global override if specified
elsif @models.is_a?(Hash)
@models[model] || 2 # Use model-specific or default to 2
else
2 # Default parallelism
end
puts "\n" + "=" * 60
puts "📊 Model #{model_index + 1}/#{@models.length}: #{model}"
puts "📊 Model #{model_index + 1}/#{model_list.length}: #{model}"
puts " Parallelism: #{parallel}"
puts "=" * 60
# Check if model exists and pull if needed
@ -78,7 +96,7 @@ class TagExtractor
ensure_model_loaded(model)
# Create thread pool for this model
pool = Concurrent::FixedThreadPool.new(@parallel)
pool = Concurrent::FixedThreadPool.new(parallel)
prompts.each do |prompt_file, prompt_content|
prompt_name = File.basename(prompt_file, '.*')
@ -418,8 +436,28 @@ if __FILE__ == $0
options[:parallel] = n
end
opts.on("-m", "--models MODELS", "Comma-separated list of models") do |models|
options[:models] = models.split(',').map(&:strip)
opts.on("-m", "--models MODELS", "Comma-separated list of models or model:parallel pairs") do |models|
  # Entries look like "name:tag" or, with an explicit thread count,
  # "name:tag:parallel" (e.g. llava:7b:4). A third colon-separated field is
  # what marks an entry as carrying a parallelism suffix.
  entries = models.split(',').map(&:strip)
  parallel_suffix = ->(entry) { entry.split(':').length > 2 }

  options[:models] =
    if entries.any?(&parallel_suffix)
      # At least one entry has a suffix: build a model => parallelism hash.
      entries.each_with_object({}) do |entry, per_model|
        if parallel_suffix.call(entry)
          *name_parts, count = entry.split(':')
          requested = count.to_i
          # Non-numeric or non-positive counts fall back to 2.
          per_model[name_parts.join(':')] = requested > 0 ? requested : 2
        else
          # Plain model name in a mixed list gets the default of 2.
          per_model[entry] = 2
        end
      end
    else
      # No suffix anywhere: keep the plain array of model names.
      entries
    end
end
opts.on("-t", "--timeout SECONDS", Integer, "Request timeout in seconds (default: 120)") do |t|