commit 2c51c6a10cef7b68d9acc7703014bea2e4bd3101
Author: Akshit Khurana <axitkhurana@gmail.com>
Date:   Thu Apr 4 18:09:06 2013 +0530
Make controller cache thread-safe

Reads and writes to the controller's request cache now take place under a
dedicated cache lock (see https://trac.torproject.org/8607).
---
 stem/control.py | 139 ++++++++++++++++++++++++++++++++++++++++---------------
 1 files changed, 102 insertions(+), 37 deletions(-)
diff --git a/stem/control.py b/stem/control.py index 2e83a64..304a808 100644 --- a/stem/control.py +++ b/stem/control.py @@ -659,6 +659,8 @@ class Controller(BaseController): self._is_caching_enabled = True self._request_cache = {}
+ self._cache_lock = threading.RLock() + # mapping of event types to their listeners
self._event_listeners = {} @@ -753,13 +755,17 @@ class Controller(BaseController): params = set(params)
# check for cached results - for param in list(params): - cache_key = "getinfo.%s" % param.lower() + from_cache = [param.lower() for param in params] + cached_results = self._get_cache(from_cache, "getinfo") + + reply = {} + for key in cached_results.keys(): + user_expected_key = _case_insensitive_lookup(params, key, key) + reply[user_expected_key] = cached_results[key] + params.remove(user_expected_key)
- if cache_key in self._request_cache: - reply[param] = self._request_cache[cache_key] - params.remove(param) - elif param.startswith('ip-to-country/') and self.is_geoip_unavailable(): + for param in params: + if param.startswith('ip-to-country/') and self.is_geoip_unavailable(): # the geoip database already looks to be unavailable - abort the request if default == UNDEFINED: raise stem.ProtocolError("Tor geoip database is unavailable") @@ -782,16 +788,19 @@ class Controller(BaseController): reply.update(response.entries)
if self.is_caching_enabled(): + to_cache = {} for key, value in response.entries.items(): key = key.lower() # make case insensitive
if key in CACHEABLE_GETINFO_PARAMS: - self._request_cache["getinfo.%s" % key] = value + to_cache[key] = value elif key.startswith('ip-to-country/'): # both cache-able and means that we should reset the geoip failure count - self._request_cache["getinfo.%s" % key] = value + to_cache[key] = value self._geoip_failure_count = -1
+ self._set_cache(to_cache, "getinfo") + log.debug("GETINFO %s (runtime: %0.4f)" % (" ".join(params), time.time() - start_time))
if is_multiple: @@ -839,11 +848,14 @@ class Controller(BaseController): try: if not self.is_caching_enabled(): return stem.version.Version(self.get_info("version")) - elif not "version" in self._request_cache: - version = stem.version.Version(self.get_info("version")) - self._request_cache["version"] = version + else: + version = self._get_cache(["version"]).get("version", None)
- return self._request_cache["version"] + if not version: + version = stem.version.Version(self.get_info("version")) + self._set_cache({"version": version}) + + return version except Exception as exc: if default == UNDEFINED: raise exc @@ -866,10 +878,11 @@ class Controller(BaseController):
An exception is only raised if we weren't provided a default response. """ - with self._msg_lock: try: - if not "exit_policy" in self._request_cache: + config_policy = self._get_cache(["exit_policy"]).get("exit_policy", + None) + if not config_policy: policy = []
if self.get_conf("ExitPolicyRejectPrivate") == "1": @@ -884,10 +897,10 @@ class Controller(BaseController): policy += policy_line.split(",")
policy += self.get_info("exit-policy/default").split(",") + config_policy = stem.exit_policy.get_config_policy(policy) + self._set_cache({"exit_policy": config_policy})
- self._request_cache["exit_policy"] = stem.exit_policy.get_config_policy(policy) - - return self._request_cache["exit_policy"] + return config_policy except Exception as exc: if default == UNDEFINED: raise exc @@ -1316,12 +1329,14 @@ class Controller(BaseController): lookup_params = set([MAPPED_CONFIG_KEYS.get(entry, entry) for entry in params])
# check for cached results - for param in list(lookup_params): - cache_key = "getconf.%s" % param.lower() + from_cache = [param.lower() for param in lookup_params] + cached_results = self._get_cache(from_cache, "getconf")
- if cache_key in self._request_cache: - reply[param] = self._request_cache[cache_key] - lookup_params.remove(param) + reply = {} + for key in cached_results.keys(): + user_expected_key = _case_insensitive_lookup(lookup_params, key, key) + reply[user_expected_key] = cached_results[key] + lookup_params.remove(user_expected_key)
# if everything was cached then short circuit making the query if not lookup_params: @@ -1334,9 +1349,15 @@ class Controller(BaseController): reply.update(response.entries)
if self.is_caching_enabled(): - for key, value in response.entries.items(): - self._request_cache["getconf.%s" % key.lower()] = value + to_cache = {} + for param, value in response.entries.items(): + param = param.lower()
+ if isinstance(value, (bytes, unicode)): + value = [value] + to_cache[param] = value + + self._set_cache(to_cache, "getconf") # Maps the entries back to the parameters that the user requested so the # capitalization matches (ie, if they request "exitpolicy" then that # should be the key rather than "ExitPolicy"). When the same @@ -1481,19 +1502,13 @@ class Controller(BaseController): log.debug("%s (runtime: %0.4f)" % (query, time.time() - start_time))
if self.is_caching_enabled(): + to_cache = {} for param, value in params: - cache_key = "getconf.%s" % param.lower() + if isinstance(value, (bytes, unicode)): + value = [value] + to_cache[param.lower()] = value
- if value is None: - if cache_key in self._request_cache: - del self._request_cache[cache_key] - elif isinstance(value, (bytes, unicode)): - self._request_cache[cache_key] = [value] - else: - self._request_cache[cache_key] = value - - if param.lower() == "exitpolicy" and "exit_policy" in self._request_cache: - del self._request_cache["exit_policy"] + self._set_cache(to_cache, "getconf") else: log.debug("%s (failed, code: %s, message: %s)" % (query, response.code, response.message))
@@ -1582,6 +1597,55 @@ class Controller(BaseController): if not response.is_ok(): raise stem.ProtocolError("SETEVENTS received unexpected response\n%s" % response)
+ def _get_cache(self, params, func=None): + """ + Queries multiple configuration options in cache atomically, returning a + mapping of those options to their values. + + :param list params: keys to be queried in cache + :param str func: function prefix to keys + + :returns: **dict** of 'param => cached value' pairs of keys present in cache + """ + + with self._cache_lock: + cached_values = {} + + for param in params: + if func: + cache_key = "%s.%s" % (func, param) + else: + cache_key = param + + if cache_key in self._request_cache: + cached_values[param] = self._request_cache[cache_key] + + return cached_values + + def _set_cache(self, params, func=None): + """ + Changes the value of tor configuration options in cache atomically. + + :param dict params: **dict** of 'cache_key => value' pairs to be cached + :param str func: function prefix to keys + """ + + with self._cache_lock: + for key, value in params.items(): + if func: + cache_key = "%s.%s" % (func, key) + else: + cache_key = key + + if value is None: + if cache_key in self._request_cache: + del self._request_cache[cache_key] + else: + self._request_cache[cache_key] = value + + if key.lower() == "exitpolicy" and "exit_policy" in self._request_cache: + del self._request_cache["exit_policy"] + def is_caching_enabled(self): """ **True** if caching has been enabled, **False** otherwise. @@ -1608,8 +1672,9 @@ class Controller(BaseController): Drops any cached results. """
- self._request_cache = {} - self._geoip_failure_count = 0 + with self._cache_lock: + self._request_cache = {} + self._geoip_failure_count = 0
def load_conf(self, configtext): """
tor-commits@lists.torproject.org