Wait for all futures (#6554)

- Expose method to wait for all futures
- Wait for submissions in the run_on_dataset functions to ensure runs
are fully submitted before cleaning up
multi_strategy_parser
Zander Chase 11 months ago committed by GitHub
parent e0605b464b
commit 5322bac5fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -3,9 +3,9 @@ from __future__ import annotations
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union
from uuid import UUID
from langchainplus_sdk import LangChainPlusClient
@@ -21,6 +21,7 @@ from langchain.schema import BaseMessage, messages_to_dict
logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: List[LangChainTracer] = []
def log_error_once(method: str, exception: Exception) -> None:
@@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None:
logger.error(exception)
def wait_for_all_tracers() -> None:
    """Block until every registered tracer has flushed its pending run submissions.

    Iterates the module-level ``_TRACERS`` registry and waits on each
    tracer's outstanding executor futures, so callers can be sure all runs
    have been POSTed before tearing down (e.g. at the end of run_on_dataset).
    """
    # ``global`` is unnecessary here: the function only reads the module
    # global, it never rebinds the name.
    for tracer in _TRACERS:
        tracer.wait_for_futures()
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@@ -52,6 +59,9 @@ class LangChainTracer(BaseTracer):
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient()
self._futures: Set[Future] = set()
global _TRACERS
_TRACERS.append(self)
def on_chat_model_start(
self,
@@ -93,7 +103,7 @@
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
try:
run = self.client.create_run(**run_dict, session_name=self.session_name)
self.client.create_run(**run_dict, session_name=self.session_name)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("post", e)
@@ -110,40 +120,67 @@
def _on_llm_start(self, run: Run) -> None:
    """Persist an LLM run at start time.

    Submits the create (POST) on the single-worker executor and tracks
    the future in ``self._futures`` so ``wait_for_futures`` can block on it.
    A deep copy of the run is taken so later mutations don't race the worker.
    """
    # Diff residue fixed: the old untracked ``self.executor.submit(...)``
    # line was left in alongside its replacement, which would have
    # submitted (and POSTed) the run twice.
    self._futures.add(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )
def _on_chat_model_start(self, run: Run) -> None:
    """Persist a chat-model run at start time.

    Submits the create (POST) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )
def _on_llm_end(self, run: Run) -> None:
    """Process the LLM run on successful completion.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def _on_llm_error(self, run: Run) -> None:
    """Process the LLM run upon error.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def _on_chain_start(self, run: Run) -> None:
    """Persist a Chain run at start time.

    Submits the create (POST) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )
def _on_chain_end(self, run: Run) -> None:
    """Process the Chain run on successful completion.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def _on_chain_error(self, run: Run) -> None:
    """Process the Chain run upon error.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def _on_tool_start(self, run: Run) -> None:
    """Persist a Tool run at start time.

    Submits the create (POST) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )
def _on_tool_end(self, run: Run) -> None:
    """Process the Tool run on successful completion.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def _on_tool_error(self, run: Run) -> None:
    """Process the Tool run upon error.

    Submits the update (PATCH) on the executor and tracks the future so
    ``wait_for_futures`` can block on it.
    """
    # Diff residue fixed: drop the stale untracked submit line that would
    # have caused a duplicate submission.
    self._futures.add(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )
def wait_for_futures(self) -> None:
    """Block until every currently-tracked run submission has completed."""
    # Snapshot the set first: futures enqueued while we block stay
    # tracked, and we only discard the ones we actually waited on.
    pending = list(self._futures)
    wait(pending)
    self._futures.difference_update(pending)

@@ -224,9 +224,17 @@ async def _gather_with_concurrency(
tracer_queue.put_nowait(tracer)
return result
return await asyncio.gather(
results = await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs)
)
while tracer_queue:
try:
tracer = tracer_queue.get_nowait()
except asyncio.QueueEmpty:
break
if tracer:
tracer.wait_for_futures()
return results
async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
@@ -411,7 +419,9 @@ def run_on_examples(
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
results[str(example.id)] = result
if tracer:
tracer.wait_for_futures()
return results

Loading…
Cancel
Save