commit b3f59acd750f10b11af76c52575b56208aad80dd Author: Illia Volochii illia.volochii@gmail.com Date: Fri Apr 17 18:41:12 2020 +0300
Fix the `with_default` decorator for synchronous functions --- stem/control.py | 95 +++++++++++++++++++++++++++++++++------------------------ 1 file changed, 55 insertions(+), 40 deletions(-)
diff --git a/stem/control.py b/stem/control.py index 0e03edfb..cc9ef964 100644 --- a/stem/control.py +++ b/stem/control.py @@ -458,12 +458,6 @@ def with_default(yields: bool = False) -> Callable: """
def decorator(func: Callable) -> Callable: - is_coroutine_func = asyncio.iscoroutinefunction(func) - def coroutine_if_needed(func: Callable) -> Callable: - if is_coroutine_func: - return asyncio.coroutine(func) - return func - def get_default(func: Callable, args: Any, kwargs: Any) -> Any: arg_names = inspect.getfullargspec(func).args[1:] # drop 'self' default_position = arg_names.index('default') if 'default' in arg_names else None @@ -473,41 +467,62 @@ def with_default(yields: bool = False) -> Callable: else: return kwargs.get('default', UNDEFINED)
- if not yields: - @functools.wraps(func) - @coroutine_if_needed - def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - result = func(self, *args, **kwargs) - if is_coroutine_func: - result = yield from result - return result - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - return default + if asyncio.iscoroutinefunction(func): + if not yields: + @functools.wraps(func) + async def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + return await func(self, *args, **kwargs) + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + return default + else: + @functools.wraps(func) + async def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + for val in await func(self, *args, **kwargs): + yield val + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + if default is not None: + for val in default: + yield val else: - @functools.wraps(func) - @coroutine_if_needed - def wrapped(self, *args: Any, **kwargs: Any) -> Any: - try: - result = func(self, *args, **kwargs) - if is_coroutine_func: - result = yield from result - for val in result: - yield val - except: - default = get_default(func, args, kwargs) - - if default == UNDEFINED: - raise - else: - if default is not None: - for val in default: - yield val + if not yields: + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + return func(self, *args, **kwargs) + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + return default + else: + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> Any: + try: + for val in func(self, *args, **kwargs): + yield val + except: + default = get_default(func, args, kwargs) + + if default == UNDEFINED: + raise + else: + if default is not None: + for val in default: + yield val
return wrapped