diff --git a/benchmark_parallel.rb b/benchmark_parallel.rb index c15b5b1..24df97f 100755 --- a/benchmark_parallel.rb +++ b/benchmark_parallel.rb @@ -27,7 +27,7 @@ end # Pull model if needed puts "Ensuring model is available..." -unless `ollama list`.include?(test_model.split(':').first) +unless `ollama list`.include?(test_model) system("ollama pull #{test_model}") end diff --git a/extract_tags.rb b/extract_tags.rb index 3305e9b..a615f3e 100755 --- a/extract_tags.rb +++ b/extract_tags.rb @@ -12,11 +12,18 @@ require 'concurrent' class TagExtractor OLLAMA_URL = 'http://localhost:11434/api/generate' - DEFAULT_MODELS = ['qwen2.5vl:3b', 'moondream:1.8b', 'llava:7b', 'llava:13b', 'llama3.2-vision:11b', 'llava-phi3:3.8b'] + DEFAULT_MODELS = { + 'qwen2.5vl:3b' => 2, # Your benchmark showed slight benefit at 2 + 'moondream:1.8b' => 8, # Your benchmark showed parallelism hurts this model + 'llava:7b' => 2, + 'llava:13b' => 2, + 'llama3.2-vision:11b' => 2, + 'llava-phi3:3.8b' => 2 + } VALID_EXTENSIONS = %w[.jpg .jpeg .png .gif .bmp .tiff .tif].freeze def initialize(options = {}) - @parallel = options[:parallel] || 8 + @global_parallel = options[:parallel] # Global override if specified @models = options[:models] || DEFAULT_MODELS @timeout = options[:timeout] || 120 @verbose = options[:verbose] || false @@ -44,7 +51,6 @@ class TagExtractor puts " • #{images.length} images found" puts " • #{prompts.length} prompts loaded" puts " • #{@models.length} models to test" - puts " • #{@parallel} parallel requests" puts total_tasks = images.length * prompts.length * @models.length @@ -56,9 +62,21 @@ class TagExtractor master_csv << %w[model image_size prompt_name image_filename tags raw_output timestamp success] # Process in batches by model to allow proper cleanup - @models.each_with_index do |model, model_index| + model_list = @models.is_a?(Hash) ? @models.keys : @models + + model_list.each_with_index do |model, model_index| + # Determine parallelism for this model + parallel = if @global_parallel + @global_parallel # Use global override if specified + elsif @models.is_a?(Hash) + @models[model] || 2 # Use model-specific or default to 2 + else + 2 # Default parallelism + end + puts "\n" + "=" * 60 - puts "📊 Model #{model_index + 1}/#{@models.length}: #{model}" + puts "📊 Model #{model_index + 1}/#{model_list.length}: #{model}" + puts " Parallelism: #{parallel}" puts "=" * 60 # Check if model exists and pull if needed @@ -78,7 +96,7 @@ class TagExtractor ensure_model_loaded(model) # Create thread pool for this model - pool = Concurrent::FixedThreadPool.new(@parallel) + pool = Concurrent::FixedThreadPool.new(parallel) prompts.each do |prompt_file, prompt_content| prompt_name = File.basename(prompt_file, '.*') @@ -418,8 +436,28 @@ if __FILE__ == $0 options[:parallel] = n end - opts.on("-m", "--models MODELS", "Comma-separated list of models") do |models| - options[:models] = models.split(',').map(&:strip) + opts.on("-m", "--models MODELS", "Comma-separated list of models or model:parallel pairs") do |models| + model_list = models.split(',').map(&:strip) + + # Check if any model has parallelism specified + if model_list.any? { |m| m.include?(':') && m.split(':').length > 2 } + # Parse model:parallel format + model_hash = {} + model_list.each do |entry| + parts = entry.split(':') + if parts.length > 2 # Has parallelism (e.g., llava:7b:4) + model_name = parts[0..-2].join(':') + parallel = parts.last.to_i + model_hash[model_name] = parallel > 0 ? parallel : 2 + else # Just model name + model_hash[entry] = 2 + end + end + options[:models] = model_hash + else + # Simple list of models + options[:models] = model_list + end end opts.on("-t", "--timeout SECONDS", Integer, "Request timeout in seconds (default: 120)") do |t|