mirror of
https://github.com/1SecondEveryday/image-analysis-eval.git
synced 2026-03-25 09:05:49 +00:00
Pull models properly when benchmarking
This commit is contained in:
parent
86a382a700
commit
7b6a1e5479
2 changed files with 47 additions and 9 deletions
|
|
@ -27,7 +27,7 @@ end
|
|||
|
||||
# Pull model if needed
|
||||
puts "Ensuring model is available..."
|
||||
unless `ollama list`.include?(test_model.split(':').first)
|
||||
unless `ollama list`.include?(test_model)
|
||||
system("ollama pull #{test_model}")
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,18 @@ require 'concurrent'
|
|||
|
||||
class TagExtractor
|
||||
OLLAMA_URL = 'http://localhost:11434/api/generate'
|
||||
DEFAULT_MODELS = ['qwen2.5vl:3b', 'moondream:1.8b', 'llava:7b', 'llava:13b', 'llama3.2-vision:11b', 'llava-phi3:3.8b']
|
||||
DEFAULT_MODELS = {
|
||||
'qwen2.5vl:3b' => 2, # Benchmark showed a slight benefit at parallelism 2
|
||||
'moondream:1.8b' => 8, # NOTE(review): benchmark reportedly showed parallelism hurts this model, yet this is set to 8 — confirm the intended value
|
||||
'llava:7b' => 2,
|
||||
'llava:13b' => 2,
|
||||
'llama3.2-vision:11b' => 2,
|
||||
'llava-phi3:3.8b' => 2
|
||||
}
|
||||
VALID_EXTENSIONS = %w[.jpg .jpeg .png .gif .bmp .tiff .tif].freeze
|
||||
|
||||
def initialize(options = {})
|
||||
@parallel = options[:parallel] || 8
|
||||
@global_parallel = options[:parallel] # Global override if specified
|
||||
@models = options[:models] || DEFAULT_MODELS
|
||||
@timeout = options[:timeout] || 120
|
||||
@verbose = options[:verbose] || false
|
||||
|
|
@ -44,7 +51,6 @@ class TagExtractor
|
|||
puts " • #{images.length} images found"
|
||||
puts " • #{prompts.length} prompts loaded"
|
||||
puts " • #{@models.length} models to test"
|
||||
puts " • #{@parallel} parallel requests"
|
||||
puts
|
||||
|
||||
total_tasks = images.length * prompts.length * @models.length
|
||||
|
|
@ -56,9 +62,21 @@ class TagExtractor
|
|||
master_csv << %w[model image_size prompt_name image_filename tags raw_output timestamp success]
|
||||
|
||||
# Process in batches by model to allow proper cleanup
|
||||
@models.each_with_index do |model, model_index|
|
||||
model_list = @models.is_a?(Hash) ? @models.keys : @models
|
||||
|
||||
model_list.each_with_index do |model, model_index|
|
||||
# Determine parallelism for this model
|
||||
parallel = if @global_parallel
|
||||
@global_parallel # Use global override if specified
|
||||
elsif @models.is_a?(Hash)
|
||||
@models[model] || 2 # Use model-specific or default to 2
|
||||
else
|
||||
2 # Default parallelism
|
||||
end
|
||||
|
||||
puts "\n" + "=" * 60
|
||||
puts "📊 Model #{model_index + 1}/#{@models.length}: #{model}"
|
||||
puts "📊 Model #{model_index + 1}/#{model_list.length}: #{model}"
|
||||
puts " Parallelism: #{parallel}"
|
||||
puts "=" * 60
|
||||
|
||||
# Check if model exists and pull if needed
|
||||
|
|
@ -78,7 +96,7 @@ class TagExtractor
|
|||
ensure_model_loaded(model)
|
||||
|
||||
# Create thread pool for this model
|
||||
pool = Concurrent::FixedThreadPool.new(@parallel)
|
||||
pool = Concurrent::FixedThreadPool.new(parallel)
|
||||
|
||||
prompts.each do |prompt_file, prompt_content|
|
||||
prompt_name = File.basename(prompt_file, '.*')
|
||||
|
|
@ -418,8 +436,28 @@ if __FILE__ == $0
|
|||
options[:parallel] = n
|
||||
end
|
||||
|
||||
opts.on("-m", "--models MODELS", "Comma-separated list of models") do |models|
|
||||
options[:models] = models.split(',').map(&:strip)
|
||||
opts.on("-m", "--models MODELS", "Comma-separated list of models or model:parallel pairs") do |models|
|
||||
model_list = models.split(',').map(&:strip)
|
||||
|
||||
# Check if any model has parallelism specified
|
||||
if model_list.any? { |m| m.include?(':') && m.split(':').length > 2 }
|
||||
# Parse model:parallel format
|
||||
model_hash = {}
|
||||
model_list.each do |entry|
|
||||
parts = entry.split(':')
|
||||
if parts.length > 2 # Has parallelism (e.g., llava:7b:4)
|
||||
model_name = parts[0..-2].join(':')
|
||||
parallel = parts.last.to_i
|
||||
model_hash[model_name] = parallel > 0 ? parallel : 2
|
||||
else # Just model name
|
||||
model_hash[entry] = 2
|
||||
end
|
||||
end
|
||||
options[:models] = model_hash
|
||||
else
|
||||
# Simple list of models
|
||||
options[:models] = model_list
|
||||
end
|
||||
end
|
||||
|
||||
opts.on("-t", "--timeout SECONDS", Integer, "Request timeout in seconds (default: 120)") do |t|
|
||||
|
|
|
|||
Loading…
Reference in a new issue