diff --git a/libs/langchain/langchain/callbacks/tracers/evaluation.py b/libs/langchain/langchain/callbacks/tracers/evaluation.py
index 9d524665fc..f5bc5f25e3 100644
--- a/libs/langchain/langchain/callbacks/tracers/evaluation.py
+++ b/libs/langchain/langchain/callbacks/tracers/evaluation.py
@@ -2,13 +2,13 @@
 from __future__ import annotations
 
 import logging
+import threading
 import weakref
-from concurrent.futures import Future, wait
-from typing import Any, Dict, List, Optional, Sequence, Union
+from concurrent.futures import Future, ThreadPoolExecutor, wait
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 from uuid import UUID
 
 import langsmith
-from langsmith import schemas as langsmith_schemas
 from langsmith.evaluation.evaluator import EvaluationResult
 
 from langchain.callbacks import manager
@@ -73,6 +73,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         example_id: Optional[Union[UUID, str]] = None,
         skip_unfinished: bool = True,
         project_name: Optional[str] = "evaluators",
+        max_concurrency: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -81,12 +82,21 @@ class EvaluatorCallbackHandler(BaseTracer):
         )
         self.client = client or langchain_tracer.get_client()
         self.evaluators = evaluators
-        self.executor = _get_executor()
+        if max_concurrency is None:
+            self.executor: Optional[ThreadPoolExecutor] = _get_executor()
+        elif max_concurrency > 0:
+            self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
+            weakref.finalize(
+                self,
+                lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
+            )
+        else:
+            self.executor = None
         self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
         self.skip_unfinished = skip_unfinished
         self.project_name = project_name
-        self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
-        self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
+        self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
+        self.lock = threading.Lock()
         global _TRACERS
         _TRACERS.add(self)
 
@@ -111,12 +121,15 @@ class EvaluatorCallbackHandler(BaseTracer):
         except Exception as e:
             logger.error(
                 f"Error evaluating run {run.id} with "
-                f"{evaluator.__class__.__name__}: {e}",
+                f"{evaluator.__class__.__name__}: {repr(e)}",
                 exc_info=True,
             )
             raise e
         example_id = str(run.reference_example_id)
-        self.logged_eval_results.setdefault(example_id, []).append(eval_result)
+        with self.lock:
+            self.logged_eval_results.setdefault((str(run.id), example_id), []).append(
+                eval_result
+            )
 
     def _persist_run(self, run: Run) -> None:
         """Run the evaluator on the run.
@@ -133,9 +146,12 @@ class EvaluatorCallbackHandler(BaseTracer):
         run_ = run.copy()
         run_.reference_example_id = self.example_id
         for evaluator in self.evaluators:
-            self.futures.add(
-                self.executor.submit(self._evaluate_in_project, run_, evaluator)
-            )
+            if self.executor is None:
+                self._evaluate_in_project(run_, evaluator)
+            else:
+                self.futures.add(
+                    self.executor.submit(self._evaluate_in_project, run_, evaluator)
+                )
 
     def wait_for_futures(self) -> None:
         """Wait for all futures to complete."""
diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index ba3e76c69e..43b62a7104 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -948,7 +948,10 @@ def _collect_test_results(
     for c in configs:
         for callback in cast(list, c["callbacks"]):
             if isinstance(callback, EvaluatorCallbackHandler):
-                all_eval_results.update(callback.logged_eval_results)
+                eval_results = callback.logged_eval_results
+                all_eval_results.update(
+                    {example_id: v for (_, example_id), v in eval_results.items()}
+                )
     results = {}
     for example, output in zip(examples, batch_results):
         feedback = all_eval_results.get(str(example.id), [])
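
The executor selection in `__init__` gives `max_concurrency` three regimes: `None` keeps the previous behavior (the shared pool from `_get_executor()`), a positive value creates a dedicated `ThreadPoolExecutor` capped at that many workers and registers a `weakref.finalize` callback to shut it down, and zero or a negative value disables the pool so `_persist_run` evaluates inline in the caller's thread. A minimal, self-contained sketch of the same pattern, where `BoundedHandler` and its `submit` method are hypothetical stand-ins rather than code from this PR:

```python
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


class BoundedHandler:
    """Hypothetical stand-in mirroring the PR's executor selection."""

    def __init__(self, max_concurrency: Optional[int] = None) -> None:
        if max_concurrency is None:
            # None preserves the old behavior; the PR reuses the shared
            # module-level pool from _get_executor() here.
            self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor()
        elif max_concurrency > 0:
            # Dedicated pool capped at max_concurrency workers.
            self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
            # Shut the dedicated pool down once this handler is collected.
            weakref.finalize(self, self.executor.shutdown, wait=True)
        else:
            # 0 or negative: no pool; work runs in the caller's thread.
            self.executor = None

    def submit(self, fn: Callable[..., Any], *args: Any) -> Optional[Future]:
        if self.executor is None:
            fn(*args)  # synchronous path, as in the new _persist_run branch
            return None
        return self.executor.submit(fn, *args)
```

A note for anyone reusing the idiom: the `weakref.finalize` docs warn that the callback must not hold a reference to the tracked object, or that object can never be collected; passing the executor's bound `shutdown` method, rather than a closure over `self`, keeps the finalizer reference-free.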
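
The bookkeeping change is the other half: `logged_eval_results` is now keyed by `(run_id, example_id)` and guarded by a lock, so concurrent evaluator threads, and repeated runs over the same example, append to separate lists without racing, and `_collect_test_results` flattens the keys back down to plain `example_id`. A standalone sketch of that accumulate-then-flatten flow, where `record` and the string-valued `EvalResult` alias are illustrative stand-ins, not part of the PR:

```python
import threading
from typing import Dict, List, Tuple

EvalResult = str  # illustrative placeholder for langsmith's EvaluationResult

logged: Dict[Tuple[str, str], List[EvalResult]] = {}
lock = threading.Lock()


def record(run_id: str, example_id: str, result: EvalResult) -> None:
    # Mirrors _evaluate_in_project: setdefault + append under the lock,
    # keyed by (run_id, example_id) so distinct runs stay separate.
    with lock:
        logged.setdefault((run_id, example_id), []).append(result)


record("run-1", "ex-1", "correct")
record("run-2", "ex-1", "incorrect")

# Mirrors _collect_test_results: the run id is dropped when aggregating,
# so when one example has several runs, the last-iterated run's list wins.
by_example = {example_id: v for (_, example_id), v in logged.items()}
print(by_example)  # {'ex-1': ['incorrect']}
```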