diff --git a/lib/rack/attack/store_proxy.rb b/lib/rack/attack/store_proxy.rb index d88ad67..4d69853 100644 --- a/lib/rack/attack/store_proxy.rb +++ b/lib/rack/attack/store_proxy.rb @@ -3,26 +3,30 @@ module Rack module StoreProxy PROXIES = [DalliProxy, MemCacheProxy, RedisStoreProxy] - def self.build(store) - # RedisStore#increment needs different behavior, so detect that - # (method has an arity of 2; must call #expire separately - if (defined?(::ActiveSupport::Cache::RedisStore) && store.is_a?(::ActiveSupport::Cache::RedisStore)) || - (defined?(::ActiveSupport::Cache::MemCacheStore) && store.is_a?(::ActiveSupport::Cache::MemCacheStore)) + ACTIVE_SUPPORT_WRAPPER_CLASSES = Set.new(['ActiveSupport::Cache::MemCacheStore', 'ActiveSupport::Cache::RedisStore']).freeze + ACTIVE_SUPPORT_CLIENTS = Set.new(['Redis::Store', 'Dalli::Client', 'MemCache']).freeze - # ActiveSupport::Cache::RedisStore doesn't expose any way to set an expiry, - # so use the raw Redis::Store instead. - # We also want to use the underlying Dalli client instead of ::ActiveSupport::Cache::MemCacheStore, - # and the MemCache client if using Rails 3.x - client = store.instance_variable_get(:@data) - if (defined?(::Redis::Store) && client.is_a?(Redis::Store)) || - (defined?(Dalli::Client) && client.is_a?(Dalli::Client)) || (defined?(MemCache) && client.is_a?(MemCache)) - store = store.instance_variable_get(:@data) - end - end - klass = PROXIES.find { |proxy| proxy.handle?(store) } - klass ? klass.new(store) : store + def self.build(store) + client = unwrap_active_support_stores(store) + klass = PROXIES.find { |proxy| proxy.handle?(client) } + klass ? klass.new(client) : client end + + private + def self.unwrap_active_support_stores(store) + # ActiveSupport::Cache::RedisStore doesn't expose any way to set an expiry, + # so use the raw Redis::Store instead. + # We also want to use the underlying Dalli client instead of ::ActiveSupport::Cache::MemCacheStore, + # and the MemCache client if using Rails 3.x + + client = store.instance_variable_get(:@data) + if ACTIVE_SUPPORT_WRAPPER_CLASSES.include?(store.class.to_s) && ACTIVE_SUPPORT_CLIENTS.include?(client.class.to_s) + client + else + store + end + end end end end