From 5322bac5fc6adda75777886a2fe487d759589ed9 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 21 Jun 2023 18:20:17 -0700 Subject: [PATCH] Wait for all futures (#6554) - Expose method to wait for all futures - Wait for submissions in the run_on_dataset functions to ensure runs are fully submitted before cleaning up --- langchain/callbacks/tracers/langchain.py | 63 +++++++++++++++++++----- langchain/client/runner_utils.py | 14 +++++- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index ecb3dbe5..48d6164c 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,9 +3,9 @@ from __future__ import annotations import logging import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, wait from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union from uuid import UUID from langchainplus_sdk import LangChainPlusClient @@ -21,6 +21,7 @@ from langchain.schema import BaseMessage, messages_to_dict logger = logging.getLogger(__name__) _LOGGED = set() +_TRACERS: List[LangChainTracer] = [] def log_error_once(method: str, exception: Exception) -> None: @@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None: logger.error(exception) +def wait_for_all_tracers() -> None: + global _TRACERS + for tracer in _TRACERS: + tracer.wait_for_futures() + + class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" @@ -52,6 +59,9 @@ class LangChainTracer(BaseTracer): # set max_workers to 1 to process tasks in order self.executor = ThreadPoolExecutor(max_workers=1) self.client = client or LangChainPlusClient() + self._futures: Set[Future] = set() + global _TRACERS + _TRACERS.append(self) def on_chat_model_start( self, @@ -93,7 +103,7 @@ class LangChainTracer(BaseTracer): extra["runtime"] = get_runtime_environment() run_dict["extra"] = extra try: - run = self.client.create_run(**run_dict, session_name=self.session_name) + self.client.create_run(**run_dict, session_name=self.session_name) except Exception as e: # Errors are swallowed by the thread executor so we need to log them here log_error_once("post", e) @@ -110,40 +120,67 @@ class LangChainTracer(BaseTracer): def _on_llm_start(self, run: Run) -> None: """Persist an LLM run.""" - self.executor.submit(self._persist_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + ) def _on_chat_model_start(self, run: Run) -> None: """Persist an LLM run.""" - self.executor.submit(self._persist_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + ) def _on_llm_end(self, run: Run) -> None: """Process the LLM Run.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) def _on_llm_error(self, run: Run) -> None: """Process the LLM Run upon error.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) def _on_chain_start(self, run: Run) -> None: """Process the Chain Run upon start.""" - self.executor.submit(self._persist_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + ) def _on_chain_end(self, run: Run) -> None: """Process the Chain Run.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) def _on_chain_error(self, run: Run) -> None: """Process the Chain Run upon error.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) def _on_tool_start(self, run: Run) -> None: """Process the Tool Run upon start.""" - self.executor.submit(self._persist_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._persist_run_single, run.copy(deep=True)) + ) def _on_tool_end(self, run: Run) -> None: """Process the Tool Run.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) def _on_tool_error(self, run: Run) -> None: """Process the Tool Run upon error.""" - self.executor.submit(self._update_run_single, run.copy(deep=True)) + self._futures.add( + self.executor.submit(self._update_run_single, run.copy(deep=True)) + ) + + def wait_for_futures(self) -> None: + """Wait for the given futures to complete.""" + futures = list(self._futures) + wait(futures) + for future in futures: + self._futures.remove(future) diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py index 99066ea6..5a7d6d42 100644 --- a/langchain/client/runner_utils.py +++ b/langchain/client/runner_utils.py @@ -224,9 +224,17 @@ async def _gather_with_concurrency( tracer_queue.put_nowait(tracer) return result - return await asyncio.gather( + results = await asyncio.gather( *(run_coroutine_with_semaphore(function) for function in async_funcs) ) + while tracer_queue: + try: + tracer = tracer_queue.get_nowait() + except asyncio.QueueEmpty: + break + if tracer: + tracer.wait_for_futures() + return results async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]: @@ -411,7 +419,9 @@ def run_on_examples( ) if verbose: print(f"{i+1} processed", flush=True, end="\r") - results[str(example.id)] = result + results[str(example.id)] = result + if tracer: + tracer.wait_for_futures() return results