commit ac6d5d8ed5997770afb58eab64121d71fb7c3a95
Author: Illia Volochii <illia.volochii(a)gmail.com>
Date: Fri Apr 17 19:35:30 2020 +0300
Fix the `with_default` decorator for asynchronous generators
---
stem/control.py | 108 +++++++++++++++++++++++++++-----------------------------
1 file changed, 53 insertions(+), 55 deletions(-)
diff --git a/stem/control.py b/stem/control.py
index d8d3d5e4..0b716269 100644
--- a/stem/control.py
+++ b/stem/control.py
@@ -467,62 +467,60 @@ def with_default(yields: bool = False) -> Callable:
else:
return kwargs.get('default', UNDEFINED)
- if asyncio.iscoroutinefunction(func):
- if not yields:
- @functools.wraps(func)
- async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- return await func(self, *args, **kwargs)
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- return default
- else:
- @functools.wraps(func)
- async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- for val in await func(self, *args, **kwargs):
- yield val
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- if default is not None:
- for val in default:
- yield val
+ if asyncio.iscoroutinefunction(func) and not yields:
+ @functools.wraps(func)
+ async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return await func(self, *args, **kwargs)
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ return default
+ elif inspect.isasyncgenfunction(func) and yields:
+ @functools.wraps(func)
+ async def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ async for val in func(self, *args, **kwargs):
+ yield val
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ if default is not None:
+ for val in default:
+ yield val
+ elif not yields:
+ @functools.wraps(func)
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ return func(self, *args, **kwargs)
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ return default
else:
- if not yields:
- @functools.wraps(func)
- def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- return func(self, *args, **kwargs)
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- return default
- else:
- @functools.wraps(func)
- def wrapped(self, *args: Any, **kwargs: Any) -> Any:
- try:
- for val in func(self, *args, **kwargs):
- yield val
- except:
- default = get_default(func, args, kwargs)
-
- if default == UNDEFINED:
- raise
- else:
- if default is not None:
- for val in default:
- yield val
+ @functools.wraps(func)
+ def wrapped(self, *args: Any, **kwargs: Any) -> Any:
+ try:
+ for val in func(self, *args, **kwargs):
+ yield val
+ except:
+ default = get_default(func, args, kwargs)
+
+ if default == UNDEFINED:
+ raise
+ else:
+ if default is not None:
+ for val in default:
+ yield val
return wrapped