[tor-commits] [stem/master] Move controller connecting to aenter

atagar at torproject.org atagar at torproject.org
Wed Jul 29 23:58:18 UTC 2020


commit a77ff0225a8d92bd62a5c9ce7887f3abddff2114
Author: Damian Johnson <atagar at torproject.org>
Date:   Wed Jul 29 16:31:54 2020 -0700

    Move controller connecting to aenter
    
    Static methods such as from_port() and from_socket_file() cannot invoke
    asynchronous methods. Fundamentally this is the same problem as our ainit -
    when a loop is transitively running us we cannot join any futures we create.
    
    Luckily in this case we can simply sidestep the headache. from_port() and
    from_socket_file() are designed for 'with' statements so we can simply move the
    act of connecting into our context management (which is already asynchronous).
    
    I encountered this problem when I ran the following...
    
      import asyncio
    
      from stem.control import Controller
    
      async def print_version_async():
        async with Controller.from_port() as controller:
          await controller.authenticate()
          print('[with asyncio] tor is version %s' % await controller.get_version())
    
      def print_version_sync():
        with Controller.from_port() as controller:
          controller.authenticate()
          print('[without asyncio] tor is version %s' % controller.get_version())
    
      print_version_sync()
      asyncio.run(print_version_async())
    
    Before:
    
      % python3.7 demo.py
      [without asyncio] tor is version 0.4.5.0-alpha-dev (git-9d922b8eaae54242)
      /home/atagar/Desktop/stem/stem/control.py:1059: RuntimeWarning: coroutine 'BaseController.connect' was never awaited
        controller.connect()
      [with asyncio] tor is version 0.4.5.0-alpha-dev (git-9d922b8eaae54242)
    
    After:
    
      % python3.7 demo.py
      [without asyncio] tor is version 0.4.5.0-alpha-dev (git-9d922b8eaae54242)
      [with asyncio] tor is version 0.4.5.0-alpha-dev (git-9d922b8eaae54242)
---
 stem/control.py                  | 25 +++++++++----------------
 test/integ/control/controller.py |  8 ++++++--
 test/integ/process.py            |  1 +
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/stem/control.py b/stem/control.py
index 578a471d..0b5721e8 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -843,6 +843,13 @@ class BaseController(Synchronous):
       return is_changed
 
   async def __aenter__(self) -> 'stem.control.BaseController':
+    if not self.is_alive():
+      try:
+        await self.connect()
+      except:
+        self.stop()
+        raise
+
     return self
 
   async def __aexit__(self, exit_type: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]) -> None:
@@ -1053,14 +1060,7 @@ class Controller(BaseController):
     else:
       control_port = stem.socket.ControlPort(address, int(port))
 
-    controller = Controller(control_port)
-
-    try:
-      controller.connect()
-      return controller
-    except:
-      controller.stop()
-      raise
+    return Controller(control_port)
 
   @staticmethod
   def from_socket_file(path: str = '/var/run/tor/control') -> 'stem.control.Controller':
@@ -1075,14 +1075,7 @@ class Controller(BaseController):
     """
 
     control_socket = stem.socket.ControlSocketFile(path)
-    controller = Controller(control_socket)
-
-    try:
-      controller.connect()
-      return controller
-    except:
-      controller.stop()
-      raise
+    return Controller(control_socket)
 
   def __init__(self, control_socket: stem.socket.ControlSocket, is_authenticated: bool = False) -> None:
     self._is_caching_enabled = True
diff --git a/test/integ/control/controller.py b/test/integ/control/controller.py
index 40e036a7..19d4ba85 100644
--- a/test/integ/control/controller.py
+++ b/test/integ/control/controller.py
@@ -69,7 +69,9 @@ class TestController(unittest.TestCase):
       with stem.control.Controller.from_port(port = test.runner.CONTROL_PORT) as controller:
         self.assertTrue(isinstance(controller, stem.control.Controller))
     else:
-      self.assertRaises(stem.SocketError, stem.control.Controller.from_port, '127.0.0.1', test.runner.CONTROL_PORT)
+      with self.assertRaises(stem.SocketError):
+        with stem.control.Controller.from_port(port = test.runner.CONTROL_PORT) as controller:
+          pass
 
   def test_from_socket_file(self):
     """
@@ -80,7 +82,9 @@ class TestController(unittest.TestCase):
       with stem.control.Controller.from_socket_file(path = test.runner.CONTROL_SOCKET_PATH) as controller:
         self.assertTrue(isinstance(controller, stem.control.Controller))
     else:
-      self.assertRaises(stem.SocketError, stem.control.Controller.from_socket_file, test.runner.CONTROL_SOCKET_PATH)
+      with self.assertRaises(stem.SocketError):
+        with stem.control.Controller.from_socket_file(path = test.runner.CONTROL_SOCKET_PATH) as controller:
+          pass
 
   @test.require.controller
   @async_test
diff --git a/test/integ/process.py b/test/integ/process.py
index 30cb0430..fd47c493 100644
--- a/test/integ/process.py
+++ b/test/integ/process.py
@@ -610,6 +610,7 @@ class TestProcess(unittest.TestCase):
       # We're the controlling process. Just need to connect then disconnect.
 
       controller = stem.control.Controller.from_port(port = int(control_port))
+      controller.connect()
       controller.authenticate()
       controller.close()
 



More information about the tor-commits mailing list