Last active: February 14, 2024 18:15
-
-
Save smurfix/0130817fa5ba6d3bb4a0f00e4d93cf86 to your computer and use it in GitHub Desktop.
Trio: results-gathering nursery wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import trio | |
import outcome | |
from contextlib import asynccontextmanager | |
class StreamResultsNursery:
    """Trio nursery wrapper that streams child-task results to an async iterator.

    Use it like a nursery (``async with StreamResultsNursery() as N``), start
    tasks with ``start_soon`` / ``start``, then ``async for result in N`` to
    receive each task's return value in completion order.  Iteration ends once
    every task started before (or during) the iteration has finished.

    Completion is tracked with a counter, ``_waiting``.  It starts at 1: that
    token belongs to the not-yet-started iteration and is released by
    ``__aiter__``.  Each result-producing task adds one token while it is
    pending.  When the counter drops to zero, the send side of the result
    channel is closed and ``res_in`` is set to ``None``.  A consumer blocked in
    ``receive()`` then gets ``EndOfChannel`` (after the buffered results have
    been drained) instead of waiting forever.  Tasks started after that point
    still run, but their return values are discarded.

    NOTE(review): if the wrapper is never iterated, or the ``async for`` is
    left early, pending tasks block in ``send`` once ``max_buffer_size``
    results are buffered, and the nursery cannot exit.
    """

    def __init__(self, max_buffer_size=1):
        # Manager returned by trio.open_nursery(); the Nursery itself is
        # stored as self.nm by __aenter__.
        self.nursery = trio.open_nursery()
        self.max_buffer_size = max_buffer_size
        self.res_in, self.res_out = trio.open_memory_channel(self.max_buffer_size)
        # The initial token is held by the iteration; __aiter__ releases it.
        self._waiting = 1
        self._loop = False

    @property
    def cancel_scope(self):
        """Return the cancel scope of the underlying nursery.

        Only valid inside the ``async with`` block.
        """
        # The cancel scope lives on the Nursery returned by __aenter__, not on
        # the manager object returned by trio.open_nursery().
        return self.nm.cancel_scope

    async def __aenter__(self):
        self.nm = await self.nursery.__aenter__()
        return self

    async def __aexit__(self, *exc):
        return await self.nursery.__aexit__(*exc)

    def _task_done(self):
        """Release one counter token; close the result stream when none remain."""
        self._waiting -= 1
        if self._waiting == 0 and self.res_in is not None:
            # Closing the send side lets a consumer blocked in receive() finish
            # with EndOfChannel once the buffered results are drained.
            self.res_in.close()
            self.res_in = None

    async def _wrap(self, p, a):
        """Run ``p(*a)`` and forward its return value to the result stream."""
        try:
            # res_in cannot be closed yet: this task's own token is still held.
            await self.res_in.send(await p(*a))
        finally:
            self._task_done()

    async def _wrap_ts(self, p, a, task_status=trio.TASK_STATUS_IGNORED):
        """Like ``_wrap``, but pass ``task_status`` through for ``nursery.start``."""
        try:
            await self.res_in.send(await p(*a, task_status=task_status))
        finally:
            self._task_done()

    async def start(self, p, *a):
        """Start ``p(*a)`` via ``nursery.start``; its eventual result is streamed.

        Returns whatever the task passes to ``task_status.started()``.  If the
        result stream has already been closed, the task runs uncaptured.
        """
        if self.res_in is None:
            return await self.nm.start(p, *a)
        # Count before starting: the task begins running inside this await.
        self._waiting += 1
        return await self.nm.start(self._wrap_ts, p, a)

    def start_soon(self, p, *a):
        """Spawn ``p(*a)``; its return value is streamed to the iterator.

        If the result stream has already been closed, the task runs uncaptured.
        """
        if self.res_in is None:
            self.nm.start_soon(p, *a)
            return
        self.nm.start_soon(self._wrap, p, a)
        # Count only after spawning succeeded (a closed nursery raises above).
        # The new task cannot run before this synchronous method returns.
        self._waiting += 1

    def __aiter__(self):
        if not self._loop:
            self._loop = True
            # Release the initial token held on behalf of the iteration.
            self._task_done()
        return self

    async def __anext__(self):
        try:
            return await self.res_out.receive()
        except trio.EndOfChannel:
            # Every counted task has finished and all results were delivered.
            raise StopAsyncIteration from None
if __name__ == "__main__":
    # Demo: start n tasks that each sleep for a random time and return a random
    # number, then print the results in completion order.
    import random

    async def rand():
        await trio.sleep(random.random())
        return random.random()

    async def main(n):
        async with StreamResultsNursery() as N:
            # Honour the requested task count instead of a hard-coded 10.
            for _ in range(n):
                N.start_soon(rand)
            async for rn in N:
                print(rn)

    trio.run(main, 10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Thanks for the correction!