diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 9d45aa44..1e2d4c64 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -47,9 +47,11 @@ class AsyncProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: - check_running_loop() - - yield asyncio.run(cls.create_async(model, messages, **kwargs)) + loop = create_event_loop() + try: + yield loop.run_until_complete(cls.create_async(model, messages, **kwargs)) + finally: + loop.close() @staticmethod @abstractmethod @@ -70,10 +72,7 @@ class AsyncGeneratorProvider(AsyncProvider): stream: bool = True, **kwargs ) -> CreateResult: - check_running_loop() - - # Force use selector event loop on windows - loop = asyncio.SelectorEventLoop() + loop = get_new_event_loop() try: generator = cls.create_async_generator( model, @@ -108,12 +107,17 @@ class AsyncGeneratorProvider(AsyncProvider): ) -> AsyncGenerator: raise NotImplementedError() -# Don't create a new loop in a running loop -def check_running_loop(): + +def create_event_loop(): + # Don't create a new loop in a running loop if asyncio.events._get_running_loop() is not None: raise RuntimeError( 'Use "create_async" instead of "create" function in a async loop.') + # Force use selector event loop on windows + return asyncio.SelectorEventLoop() + + _cookies = {} def get_cookies(cookie_domain: str) -> dict: