commit a729b61c61f4c2e136f13a5a921bebfddf1cd4ed Author: Illia Volochii <illia.volochii@gmail.com> Date: Fri Apr 17 23:07:47 2020 +0300
Make `Controller.get_hidden_service_descriptor` asynchronous --- stem/control.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-)
diff --git a/stem/control.py b/stem/control.py index 84efbf81..fd38871a 100644 --- a/stem/control.py +++ b/stem/control.py @@ -2007,7 +2007,7 @@ class Controller(BaseController): yield desc # type: ignore
@with_default() - def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: + async def get_hidden_service_descriptor(self, address: str, default: Any = UNDEFINED, servers: Optional[Sequence[str]] = None, await_result: bool = True, timeout: Optional[float] = None) -> stem.descriptor.hidden_service.HiddenServiceDescriptorV2: """ get_hidden_service_descriptor(address, default = UNDEFINED, servers = None, await_result = True)
@@ -2050,23 +2050,25 @@ class Controller(BaseController): if not stem.util.tor_tools.is_valid_hidden_service_address(address): raise ValueError("'%s.onion' isn't a valid hidden service address" % address)
- hs_desc_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.events.Event] hs_desc_listener = None
- hs_desc_content_queue = queue.Queue() # type: queue.Queue[stem.response.events.Event] + hs_desc_content_queue = asyncio.Queue() # type: asyncio.Queue[stem.response.events.Event] hs_desc_content_listener = None
start_time = time.time()
if await_result: - def hs_desc_listener(event: stem.response.events.Event) -> None: - hs_desc_queue.put(event) + async def hs_desc_listener(event: stem.response.events.Event) -> None: + await hs_desc_queue.put(event)
- def hs_desc_content_listener(event: stem.response.events.Event) -> None: - hs_desc_content_queue.put(event) + async def hs_desc_content_listener(event: stem.response.events.Event) -> None: + await hs_desc_content_queue.put(event)
- self.add_event_listener(hs_desc_listener, EventType.HS_DESC) - self.add_event_listener(hs_desc_content_listener, EventType.HS_DESC_CONTENT) + await asyncio.gather( + self.add_event_listener(hs_desc_listener, EventType.HS_DESC), + self.add_event_listener(hs_desc_content_listener, EventType.HS_DESC_CONTENT), + )
try: request = 'HSFETCH %s' % address @@ -2074,7 +2076,7 @@ class Controller(BaseController): if servers: request += ' ' + ' '.join(['SERVER=%s' % s for s in servers])
- response = stem.response._convert_to_single_line(self.msg(request)) + response = stem.response._convert_to_single_line(await self.msg(request))
if not response.is_ok(): raise stem.ProtocolError('HSFETCH returned unexpected response code: %s' % response.code) @@ -2083,7 +2085,7 @@ class Controller(BaseController): return None # not waiting, so nothing to provide back else: while True: - event = _get_with_timeout(hs_desc_content_queue, timeout, start_time) + event = await _get_with_timeout(hs_desc_content_queue, timeout)
if event.address == address: if event.descriptor: @@ -2092,7 +2094,7 @@ class Controller(BaseController): # no descriptor, looking through HS_DESC to figure out why
while True: - event = _get_with_timeout(hs_desc_queue, timeout, start_time) + event = await _get_with_timeout(hs_desc_queue, timeout)
if event.address == address and event.action == stem.HSDescAction.FAILED: if event.reason == stem.HSDescReason.NOT_FOUND: @@ -2100,11 +2102,15 @@ class Controller(BaseController): else: raise stem.DescriptorUnavailable('Unable to retrieve the descriptor for %s.onion (retrieved from %s): %s' % (address, event.directory_fingerprint, event.reason)) finally: + awaitable_removals = [] + if hs_desc_listener: - self.remove_event_listener(hs_desc_listener) + awaitable_removals.append(self.remove_event_listener(hs_desc_listener))
if hs_desc_content_listener: - self.remove_event_listener(hs_desc_content_listener) + awaitable_removals.append(self.remove_event_listener(hs_desc_content_listener)) + + await asyncio.gather(*awaitable_removals)
async def get_conf(self, param: str, default: Any = UNDEFINED, multiple: bool = False) -> Union[str, Sequence[str]]: """