commit ac6d5d8ed5997770afb58eab64121d71fb7c3a95 Author: Illia Volochii illia.volochii@gmail.com Date: Fri Apr 17 19:35:30 2020 +0300
Fix the `with_default` decorator for asynchronous generators --- stem/control.py | 108 +++++++++++++++++++++++++++----------------------------- 1 file changed, 53 insertions(+), 55 deletions(-)
diff --git a/stem/control.py b/stem/control.py index d8d3d5e4..0b716269 100644 --- a/stem/control.py +++ b/stem/control.py @@ -467,62 +467,60 @@ def with_default(yields: bool = False) -> Callable: else: return kwargs.get('default', UNDEFINED)
- if asyncio.iscoroutinefunction(func): - if not yields: - @functools.wraps(func) - async def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - return await func(self, *args, **kwargs) - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - return default - else: - @functools.wraps(func) - async def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - for val in await func(self, *args, **kwargs): - yield val - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - if default is not None: - for val in default: - yield val + if asyncio.iscoroutinefunction(func) and not yields: + @functools.wraps(func) + async def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + return await func(self, *args, **kwargs) + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + return default + elif inspect.isasyncgenfunction(func) and yields: + @functools.wraps(func) + async def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + async for val in func(self, *args, **kwargs): + yield val + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + if default is not None: + for val in default: + yield val + elif not yields: + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + return func(self, *args, **kwargs) + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + return default else: - if not yields: - @functools.wraps(func) - def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - return func(self, *args, **kwargs) - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - return default - else: - @functools.wraps(func) - def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - for val in func(self, *args, **kwargs): - yield val - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - if default is not None: - for val in default: - yield val + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + for val in func(self, *args, **kwargs): + yield val + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + if default is not None: + for val in default: + yield val
return wrapped