[tor-commits] [stem/master] Make the authentication process asynchronous

atagar at torproject.org atagar at torproject.org
Thu Jul 16 01:28:58 UTC 2020


commit 7d9a8d057a83eea3b12bfe1a25c45d31b061c72f
Author: Illia Volochii <illia.volochii at gmail.com>
Date:   Tue Apr 14 23:42:59 2020 +0300

    Make the authentication process asynchronous
---
 stem/connection.py | 62 +++++++++++++++++++++++++++---------------------------
 stem/control.py    | 54 +++++++++++++++++++++++------------------------
 2 files changed, 58 insertions(+), 58 deletions(-)

diff --git a/stem/connection.py b/stem/connection.py
index 0ce2a153..eac38e2f 100644
--- a/stem/connection.py
+++ b/stem/connection.py
@@ -366,7 +366,7 @@ def _connect_auth(control_socket: stem.socket.ControlSocket, password: str, pass
     return None
 
 
-def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
+async def authenticate(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: Optional[str] = None, chroot_path: Optional[str] = None, protocolinfo_response: Optional[stem.response.protocolinfo.ProtocolInfoResponse] = None) -> None:
   """
   Authenticates to a control socket using the information provided by a
   PROTOCOLINFO response. In practice this will often be all we need to
@@ -477,7 +477,7 @@ def authenticate(controller: Union[stem.control.BaseController, stem.socket.Cont
 
   if not protocolinfo_response:
     try:
-      protocolinfo_response = get_protocolinfo(controller)
+      protocolinfo_response = await get_protocolinfo(controller)
     except stem.ProtocolError:
       raise IncorrectSocketType('unable to use the control socket')
     except stem.SocketError as exc:
@@ -524,9 +524,9 @@ def authenticate(controller: Union[stem.control.BaseController, stem.socket.Cont
 
     try:
       if auth_type == AuthMethod.NONE:
-        authenticate_none(controller, False)
+        await authenticate_none(controller, False)
       elif auth_type == AuthMethod.PASSWORD:
-        authenticate_password(controller, password, False)
+        await authenticate_password(controller, password, False)
       elif auth_type in (AuthMethod.COOKIE, AuthMethod.SAFECOOKIE):
         cookie_path = protocolinfo_response.cookie_path
 
@@ -534,12 +534,12 @@ def authenticate(controller: Union[stem.control.BaseController, stem.socket.Cont
           cookie_path = os.path.join(chroot_path, cookie_path.lstrip(os.path.sep))
 
         if auth_type == AuthMethod.SAFECOOKIE:
-          authenticate_safecookie(controller, cookie_path, False)
+          await authenticate_safecookie(controller, cookie_path, False)
         else:
-          authenticate_cookie(controller, cookie_path, False)
+          await authenticate_cookie(controller, cookie_path, False)
 
       if isinstance(controller, stem.control.BaseController):
-        controller._post_authentication()
+        await controller._post_authentication()
 
       return  # success!
     except OpenAuthRejected as exc:
@@ -580,7 +580,7 @@ def authenticate(controller: Union[stem.control.BaseController, stem.socket.Cont
   raise AssertionError('BUG: Authentication failed without providing a recognized exception: %s' % str(auth_exceptions))
 
 
-def authenticate_none(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], suppress_ctl_errors: bool = True) -> None:
+async def authenticate_none(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], suppress_ctl_errors: bool = True) -> None:
   """
   Authenticates to an open control socket. All control connections need to
   authenticate before they can be used, even if tor hasn't been configured to
@@ -605,19 +605,19 @@ def authenticate_none(controller: Union[stem.control.BaseController, stem.socket
   """
 
   try:
-    auth_response = _msg(controller, 'AUTHENTICATE')
+    auth_response = await _msg(controller, 'AUTHENTICATE')
 
     # if we got anything but an OK response then error
     if str(auth_response) != 'OK':
       try:
-        controller.connect()
+        await controller.connect()
       except:
         pass
 
       raise OpenAuthRejected(str(auth_response), auth_response)
   except stem.ControllerError as exc:
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -627,7 +627,7 @@ def authenticate_none(controller: Union[stem.control.BaseController, stem.socket
       raise OpenAuthRejected('Socket failed (%s)' % exc)
 
 
-def authenticate_password(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: str, suppress_ctl_errors: bool = True) -> None:
+async def authenticate_password(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], password: str, suppress_ctl_errors: bool = True) -> None:
   """
   Authenticates to a control socket that uses a password (via the
   HashedControlPassword torrc option). Quotes in the password are escaped.
@@ -668,12 +668,12 @@ def authenticate_password(controller: Union[stem.control.BaseController, stem.so
   password = password.replace('"', '\\"')
 
   try:
-    auth_response = _msg(controller, 'AUTHENTICATE "%s"' % password)
+    auth_response = await _msg(controller, 'AUTHENTICATE "%s"' % password)
 
     # if we got anything but an OK response then error
     if str(auth_response) != 'OK':
       try:
-        controller.connect()
+        await controller.connect()
       except:
         pass
 
@@ -687,7 +687,7 @@ def authenticate_password(controller: Union[stem.control.BaseController, stem.so
         raise PasswordAuthRejected(str(auth_response), auth_response)
   except stem.ControllerError as exc:
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -697,7 +697,7 @@ def authenticate_password(controller: Union[stem.control.BaseController, stem.so
       raise PasswordAuthRejected('Socket failed (%s)' % exc)
 
 
-def authenticate_cookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
+async def authenticate_cookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
   """
   Authenticates to a control socket that uses the contents of an authentication
   cookie (generated via the CookieAuthentication torrc option). This does basic
@@ -757,12 +757,12 @@ def authenticate_cookie(controller: Union[stem.control.BaseController, stem.sock
 
     auth_token_hex = binascii.b2a_hex(stem.util.str_tools._to_bytes(cookie_data))
     msg = 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(auth_token_hex)
-    auth_response = _msg(controller, msg)
+    auth_response = await _msg(controller, msg)
 
     # if we got anything but an OK response then error
     if str(auth_response) != 'OK':
       try:
-        controller.connect()
+        await controller.connect()
       except:
         pass
 
@@ -777,7 +777,7 @@ def authenticate_cookie(controller: Union[stem.control.BaseController, stem.sock
         raise CookieAuthRejected(str(auth_response), cookie_path, False, auth_response)
   except stem.ControllerError as exc:
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -787,7 +787,7 @@ def authenticate_cookie(controller: Union[stem.control.BaseController, stem.sock
       raise CookieAuthRejected('Socket failed (%s)' % exc, cookie_path, False)
 
 
-def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
+async def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], cookie_path: str, suppress_ctl_errors: bool = True) -> None:
   """
   Authenticates to a control socket using the safe cookie method, which is
   enabled by setting the CookieAuthentication torrc option on Tor client's which
@@ -853,11 +853,11 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
 
   try:
     client_nonce_hex = stem.util.str_tools._to_unicode(binascii.b2a_hex(client_nonce))
-    authchallenge_response = _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex)  # type: ignore
+    authchallenge_response = await _msg(controller, 'AUTHCHALLENGE SAFECOOKIE %s' % client_nonce_hex)  # type: ignore
 
     if not authchallenge_response.is_ok():
       try:
-        controller.connect()
+        await controller.connect()
       except:
         pass
 
@@ -880,7 +880,7 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
         raise AuthChallengeFailed(authchallenge_response_str, cookie_path)
   except stem.ControllerError as exc:
     try:
-      controller.connect()
+      await controller.connect()
     except:
       pass
 
@@ -912,7 +912,7 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
       CLIENT_HASH_CONSTANT,
       cookie_data + client_nonce + authchallenge_response.server_nonce)
 
-    auth_response = _msg(controller, 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(binascii.b2a_hex(client_hash)))
+    auth_response = await _msg(controller, 'AUTHENTICATE %s' % stem.util.str_tools._to_unicode(binascii.b2a_hex(client_hash)))
   except stem.ControllerError as exc:
     try:
       controller.connect()
@@ -942,7 +942,7 @@ def authenticate_safecookie(controller: Union[stem.control.BaseController, stem.
       raise CookieAuthRejected(str(auth_response), cookie_path, True, auth_response)
 
 
-def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.ControlSocket]) -> stem.response.protocolinfo.ProtocolInfoResponse:
+async def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.ControlSocket]) -> stem.response.protocolinfo.ProtocolInfoResponse:
   """
   Issues a PROTOCOLINFO query to a control socket, getting information about
   the tor process running on it. If the socket is already closed then it is
@@ -963,7 +963,7 @@ def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.
   """
 
   try:
-    protocolinfo_response = _msg(controller, 'PROTOCOLINFO 1')
+    protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
   except:
     protocolinfo_response = None
 
@@ -974,7 +974,7 @@ def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.
     controller.connect()
 
     try:
-      protocolinfo_response = _msg(controller, 'PROTOCOLINFO 1')
+      protocolinfo_response = await _msg(controller, 'PROTOCOLINFO 1')
     except stem.SocketClosed as exc:
       raise stem.SocketError(exc)
 
@@ -982,17 +982,17 @@ def get_protocolinfo(controller: Union[stem.control.BaseController, stem.socket.
   return protocolinfo_response  # type: ignore
 
 
-def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage:
+async def _msg(controller: Union[stem.control.BaseController, stem.socket.ControlSocket], message: str) -> stem.response.ControlMessage:
   """
   Sends and receives a message with either a
   :class:`~stem.socket.ControlSocket` or :class:`~stem.control.BaseController`.
   """
 
   if isinstance(controller, stem.socket.ControlSocket):
-    controller.send(message)
-    return controller.recv()
+    await controller.send(message)
+    return await controller.recv()
   else:
-    return controller.msg(message)
+    return await controller.msg(message)
 
 
 def _connection_for_default_port(address: str) -> stem.socket.ControlPort:
diff --git a/stem/control.py b/stem/control.py
index 0559ee39..435fb741 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -576,7 +576,7 @@ class BaseController(object):
       self._create_loop_tasks()
 
     if is_authenticated:
-      self._post_authentication()
+      self._asyncio_loop.create_task(self._post_authentication())
 
   async def msg(self, message: str) -> stem.response.ControlMessage:
     """
@@ -843,7 +843,7 @@ class BaseController(object):
 
     await self._socket_close()
 
-  def _post_authentication(self) -> None:
+  async def _post_authentication(self) -> None:
     # actions to be taken after we have a newly authenticated connection
 
     self._is_authenticated = True
@@ -1042,7 +1042,7 @@ class Controller(BaseController):
         self.clear_cache()
         self._notify_status_listeners(State.RESET)
 
-    self.add_event_listener(_sighup_listener, EventType.SIGNAL)  # type: ignore
+    self._asyncio_loop.create_task(self.add_event_listener(_sighup_listener, EventType.SIGNAL))
 
     def _confchanged_listener(event: stem.response.events.ConfChangedEvent) -> None:
       if self.is_caching_enabled():
@@ -1057,7 +1057,7 @@ class Controller(BaseController):
 
         self._confchanged_cache_invalidation(to_cache)
 
-    self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED)  # type: ignore
+    self._asyncio_loop.create_task(self.add_event_listener(_confchanged_listener, EventType.CONF_CHANGED))
 
     def _address_changed_listener(event: stem.response.events.StatusEvent) -> None:
       if event.action in ('EXTERNAL_ADDRESS', 'DNS_USELESS'):
@@ -1065,20 +1065,20 @@ class Controller(BaseController):
         self._set_cache({'address': None}, 'getinfo')
         self._last_address_exc = None
 
-    self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER)  # type: ignore
+    self._asyncio_loop.create_task(self.add_event_listener(_address_changed_listener, EventType.STATUS_SERVER))
 
-  def close(self) -> None:
+  async def close(self) -> None:
     self.clear_cache()
-    super(Controller, self).close()
+    await super(Controller, self).close()
 
-  def authenticate(self, *args: Any, **kwargs: Any) -> None:
+  async def authenticate(self, *args: Any, **kwargs: Any) -> None:
     """
     A convenience method to authenticate the controller. This is just a
     pass-through to :func:`stem.connection.authenticate`.
     """
 
     import stem.connection
-    stem.connection.authenticate(self, *args, **kwargs)
+    await stem.connection.authenticate(self, *args, **kwargs)
 
   def reconnect(self, *args: Any, **kwargs: Any) -> None:
     """
@@ -2073,7 +2073,7 @@ class Controller(BaseController):
       if hs_desc_content_listener:
         self.remove_event_listener(hs_desc_content_listener)
 
-  def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
+  async def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]:
     """
     get_conf(param, default = UNDEFINED, multiple = False)
 
@@ -2119,18 +2119,18 @@ class Controller(BaseController):
     if not param:
       return default if default != UNDEFINED else None
 
-    entries = self.get_conf_map(param, default, multiple)
+    entries = await self.get_conf_map(param, default, multiple)
     return _case_insensitive_lookup(entries, param, default)
 
   # TODO: temporary aliases until we have better type support in our API
 
-  def _get_conf_single(self, param: str, default: Any = UNDEFINED) -> str:
-    return self.get_conf(param, default)  # type: ignore
+  async def _get_conf_single(self, param: str, default: Any = UNDEFINED) -> str:
+    return await self.get_conf(param, default)  # type: ignore
 
-  def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]:
-    return self.get_conf(param, default, multiple = True)  # type: ignore
+  async def _get_conf_multiple(self, param: str, default: Any = UNDEFINED) -> List[str]:
+    return await self.get_conf(param, default, multiple = True)  # type: ignore
 
-  def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
+  async def get_conf_map(self, params: Union[str, Sequence[str]], default: Any = UNDEFINED, multiple: bool = True) -> Dict[str, Union[str, Sequence[str]]]:
     """
     get_conf_map(params, default = UNDEFINED, multiple = True)
 
@@ -2211,7 +2211,7 @@ class Controller(BaseController):
       return self._get_conf_dict_to_response(reply, default, multiple)
 
     try:
-      response = stem.response._convert_to_getconf(self.msg('GETCONF %s' % ' '.join(lookup_params)))
+      response = stem.response._convert_to_getconf(await self.msg('GETCONF %s' % ' '.join(lookup_params)))
       reply.update(response.entries)
 
       if self.is_caching_enabled():
@@ -3001,7 +3001,7 @@ class Controller(BaseController):
     else:
       raise stem.ProtocolError('DEL_ONION returned unexpected response code: %s' % response.code)
 
-  def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None:
+  async def add_event_listener(self, listener: Callable[[stem.response.events.Event], None], *events: 'stem.control.EventType') -> None:
     """
     Directs further tor controller events to a given function. The function is
     expected to take a single argument, which is a
@@ -3050,7 +3050,7 @@ class Controller(BaseController):
       for event_type in events:
         self._event_listeners.setdefault(event_type, []).append(listener)
 
-      failed_events = self._attach_listeners()[1]
+      failed_events = (await self._attach_listeners())[1]
 
       # restricted the failures to just things we requested
 
@@ -3732,14 +3732,14 @@ class Controller(BaseController):
 
     self.msg('DROPGUARDS')
 
-  def _post_authentication(self) -> None:
-    super(Controller, self)._post_authentication()
+  async def _post_authentication(self) -> None:
+    await super(Controller, self)._post_authentication()
 
     # try to re-attach event listeners to the new instance
 
     with self._event_listeners_lock:
       try:
-        failed_events = self._attach_listeners()[1]
+        failed_events = (await self._attach_listeners())[1]
 
         if failed_events:
           # remove our listeners for these so we don't keep failing
@@ -3753,10 +3753,10 @@ class Controller(BaseController):
 
     # issue TAKEOWNERSHIP if we're the owning process for this tor instance
 
-    owning_pid = self.get_conf('__OwningControllerProcess', None)
+    owning_pid = await self.get_conf('__OwningControllerProcess', None)
 
     if owning_pid == str(os.getpid()) and self.is_localhost():
-      response = stem.response._convert_to_single_line(self.msg('TAKEOWNERSHIP'))
+      response = stem.response._convert_to_single_line(await self.msg('TAKEOWNERSHIP'))
 
       if response.is_ok():
         # Now that tor is tracking our ownership of the process via the control
@@ -3793,7 +3793,7 @@ class Controller(BaseController):
             except Exception as exc:
               log.warn('Event listener raised an uncaught exception (%s): %s' % (exc, event))
 
-  def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]:
+  async def _attach_listeners(self) -> Tuple[Sequence[str], Sequence[str]]:
     """
     Attempts to subscribe to the self._event_listeners events from tor. This is
     a no-op if we're not currently authenticated.
@@ -3808,7 +3808,7 @@ class Controller(BaseController):
     with self._event_listeners_lock:
       if self.is_authenticated():
         # try to set them all
-        response = self.msg('SETEVENTS %s' % ' '.join(self._event_listeners.keys()))
+        response = await self.msg('SETEVENTS %s' % ' '.join(self._event_listeners.keys()))
 
         if response.is_ok():
           set_events = list(self._event_listeners.keys())
@@ -3827,7 +3827,7 @@ class Controller(BaseController):
           # See if we can set some subset of our events.
 
           for event in list(self._event_listeners.keys()):
-            response = self.msg('SETEVENTS %s' % ' '.join(set_events + [event]))
+            response = await self.msg('SETEVENTS %s' % ' '.join(set_events + [event]))
 
             if response.is_ok():
               set_events.append(event)





More information about the tor-commits mailing list