diff --git a/libs/langchain/langchain/callbacks/tracers/evaluation.py b/libs/langchain/langchain/callbacks/tracers/evaluation.py index 0d333f9f04..9d524665fc 100644 --- a/libs/langchain/langchain/callbacks/tracers/evaluation.py +++ b/libs/langchain/langchain/callbacks/tracers/evaluation.py @@ -2,20 +2,33 @@ from __future__ import annotations import logging -from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Sequence, Set, Union +import weakref +from concurrent.futures import Future, wait +from typing import Any, Dict, List, Optional, Sequence, Union from uuid import UUID import langsmith +from langsmith import schemas as langsmith_schemas from langsmith.evaluation.evaluator import EvaluationResult from langchain.callbacks import manager from langchain.callbacks.tracers import langchain as langchain_tracer from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.langchain import _get_executor from langchain.callbacks.tracers.schemas import Run logger = logging.getLogger(__name__) +_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() + + +def wait_for_all_evaluators() -> None: + """Wait for all tracers to finish.""" + global _TRACERS + for tracer in list(_TRACERS): + if tracer is not None: + tracer.wait_for_futures() + class EvaluatorCallbackHandler(BaseTracer): """A tracer that runs a run evaluator whenever a run is persisted. @@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer): ---------- evaluators : Sequence[RunEvaluator] The run evaluators to apply to all top level runs. - max_workers : int, optional - The maximum number of worker threads to use for running the evaluators. - If not specified, it will default to the number of evaluators. client : LangSmith Client, optional The LangSmith client instance to use for evaluating the runs. If not specified, a new instance will be created. @@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer): def __init__( self, evaluators: Sequence[langsmith.RunEvaluator], - max_workers: Optional[int] = None, client: Optional[langsmith.Client] = None, example_id: Optional[Union[UUID, str]] = None, skip_unfinished: bool = True, @@ -72,11 +81,14 @@ class EvaluatorCallbackHandler(BaseTracer): ) self.client = client or langchain_tracer.get_client() self.evaluators = evaluators - self.max_workers = max_workers or len(evaluators) - self.futures: Set[Future] = set() + self.executor = _get_executor() + self.futures: weakref.WeakSet[Future] = weakref.WeakSet() self.skip_unfinished = skip_unfinished self.project_name = project_name + self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {} self.logged_eval_results: Dict[str, List[EvaluationResult]] = {} + global _TRACERS + _TRACERS.add(self) def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None: """Evaluate the run in the project. @@ -120,15 +132,11 @@ class EvaluatorCallbackHandler(BaseTracer): return run_ = run.copy() run_.reference_example_id = self.example_id - if self.max_workers > 0: - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - list( - executor.map( - self._evaluate_in_project, - [run_ for _ in range(len(self.evaluators))], - self.evaluators, - ) - ) - else: - for evaluator in self.evaluators: - self._evaluate_in_project(run_, evaluator) + for evaluator in self.evaluators: + self.futures.add( + self.executor.submit(self._evaluate_in_project, run_, evaluator) + ) + + def wait_for_futures(self) -> None: + """Wait for all futures to complete.""" + wait(self.futures) diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py index 07cde9e568..fd91fa3842 100644 --- a/libs/langchain/langchain/callbacks/tracers/langchain.py +++ b/libs/langchain/langchain/callbacks/tracers/langchain.py @@ -6,7 +6,7 @@ import os import weakref from concurrent.futures import Future, ThreadPoolExecutor, wait from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Union from uuid import UUID from langsmith import Client @@ -21,8 +21,7 @@ logger = logging.getLogger(__name__) _LOGGED = set() _TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet() _CLIENT: Optional[Client] = None -_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented -_EXECUTORS: List[ThreadPoolExecutor] = [] +_EXECUTOR: Optional[ThreadPoolExecutor] = None def log_error_once(method: str, exception: Exception) -> None: @@ -50,6 +49,14 @@ def get_client() -> Client: return _CLIENT +def _get_executor() -> ThreadPoolExecutor: + """Get the executor.""" + global _EXECUTOR + if _EXECUTOR is None: + _EXECUTOR = ThreadPoolExecutor() + return _EXECUTOR + + class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" @@ -71,21 +78,10 @@ class LangChainTracer(BaseTracer): self.project_name = project_name or os.getenv( "LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default") ) - if use_threading: - global _MAX_EXECUTORS - if len(_EXECUTORS) < _MAX_EXECUTORS: - self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor( - max_workers=1 - ) - _EXECUTORS.append(self.executor) - else: - self.executor = _EXECUTORS.pop(0) - _EXECUTORS.append(self.executor) - else: - self.executor = None self.client = client or get_client() - self._futures: Set[Future] = set() + self._futures: weakref.WeakSet[Future] = weakref.WeakSet() self.tags = tags or [] + self.executor = _get_executor() if use_threading else None global _TRACERS _TRACERS.add(self) @@ -229,7 +225,4 @@ class LangChainTracer(BaseTracer): def wait_for_futures(self) -> None: """Wait for the given futures to complete.""" - futures = list(self._futures) - wait(futures) - for future in futures: - self._futures.remove(future) + wait(self._futures) diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 7f8e2e6592..8119f81a46 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -24,8 +24,11 @@ from langsmith import Client, RunEvaluator from langsmith.schemas import Dataset, DataType, Example from langchain.callbacks.manager import Callbacks -from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler -from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers +from langchain.callbacks.tracers.evaluation import ( + EvaluatorCallbackHandler, + wait_for_all_evaluators, +) +from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.chains.base import Chain from langchain.evaluation.loading import load_evaluator from langchain.evaluation.schema import ( @@ -915,7 +918,6 @@ def _prepare_run_on_dataset( EvaluatorCallbackHandler( evaluators=run_evaluators or [], client=client, - max_workers=0, example_id=example.id, ), progress_bar, @@ -934,7 +936,7 @@ def _collect_test_results( configs: List[RunnableConfig], project_name: str, ) -> TestResult: - wait_for_all_tracers() + wait_for_all_evaluators() all_eval_results = {} for c in configs: for callback in cast(list, c["callbacks"]):