Wfh/eval max concurrency (#11368)

William FH 10 months ago committed by GitHub
parent 1165767df2
commit 06f39be1c2

libs/langchain/langchain/callbacks/tracers/evaluation.py

@@ -2,13 +2,13 @@
 from __future__ import annotations

 import logging
 import threading
 import weakref
-from concurrent.futures import Future, wait
-from typing import Any, Dict, List, Optional, Sequence, Union
+from concurrent.futures import Future, ThreadPoolExecutor, wait
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 from uuid import UUID

 import langsmith
 from langsmith import schemas as langsmith_schemas
 from langsmith.evaluation.evaluator import EvaluationResult
 from langchain.callbacks import manager
@@ -73,6 +73,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         example_id: Optional[Union[UUID, str]] = None,
         skip_unfinished: bool = True,
         project_name: Optional[str] = "evaluators",
+        max_concurrency: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
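
For orientation, a hedged usage sketch of the new knob. Only max_concurrency is new here; ExactMatch is a stand-in evaluator written for this note, and actually posting feedback would need LangSmith credentials in the environment. The three modes are implemented in the constructor hunk below.

from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator

from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler


class ExactMatch(RunEvaluator):
    """Stand-in evaluator for illustration only."""

    def evaluate_run(self, run, example=None) -> EvaluationResult:
        return EvaluationResult(key="exact_match", score=1)


# max_concurrency=None -> shared global executor (the previous behavior)
# max_concurrency=N>0  -> dedicated executor capped at N worker threads
# max_concurrency<=0   -> no executor; evaluators run inline, synchronously
handler = EvaluatorCallbackHandler(evaluators=[ExactMatch()], max_concurrency=2)
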
@@ -81,12 +82,21 @@ class EvaluatorCallbackHandler(BaseTracer):
         )
         self.client = client or langchain_tracer.get_client()
         self.evaluators = evaluators
-        self.executor = _get_executor()
+        if max_concurrency is None:
+            self.executor: Optional[ThreadPoolExecutor] = _get_executor()
+        elif max_concurrency > 0:
+            self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
+            weakref.finalize(
+                self,
+                lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
+            )
+        else:
+            self.executor = None
         self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
         self.skip_unfinished = skip_unfinished
         self.project_name = project_name
         self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
-        self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
+        self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
         self.lock = threading.Lock()
         global _TRACERS
         _TRACERS.add(self)
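
The dedicated executor is torn down via weakref.finalize, which runs its callback when the tracked object is collected or at interpreter exit, whichever comes first. One subtlety: the callback must not hold a strong reference to the tracked object, and the lambda above closes over self, so this particular finalizer will in practice only fire at exit. A minimal standalone sketch of the reference-safe form of the idiom (Owner is an illustrative name, not from the patch):

import weakref
from concurrent.futures import ThreadPoolExecutor


class Owner:
    def __init__(self, max_workers: int) -> None:
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Pass the bound shutdown method directly: the finalizer then holds
        # a reference to the executor only, not to `self`, so the Owner can
        # still be garbage collected, which is what triggers the callback.
        weakref.finalize(self, self.executor.shutdown, wait=True)


owner = Owner(max_workers=2)
print(owner.executor.submit(sum, [1, 2, 3]).result())  # 6
del owner  # on CPython the refcount hits zero here and shutdown(wait=True) runs
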
@@ -111,12 +121,15 @@ class EvaluatorCallbackHandler(BaseTracer):
         except Exception as e:
             logger.error(
                 f"Error evaluating run {run.id} with "
-                f"{evaluator.__class__.__name__}: {e}",
+                f"{evaluator.__class__.__name__}: {repr(e)}",
                 exc_info=True,
             )
             raise e
         example_id = str(run.reference_example_id)
-        self.logged_eval_results.setdefault(example_id, []).append(eval_result)
+        with self.lock:
+            self.logged_eval_results.setdefault((str(run.id), example_id), []).append(
+                eval_result
+            )

     def _persist_run(self, run: Run) -> None:
         """Run the evaluator on the run.
@@ -133,9 +146,12 @@
         run_ = run.copy()
         run_.reference_example_id = self.example_id
         for evaluator in self.evaluators:
-            self.futures.add(
-                self.executor.submit(self._evaluate_in_project, run_, evaluator)
-            )
+            if self.executor is None:
+                self._evaluate_in_project(run_, evaluator)
+            else:
+                self.futures.add(
+                    self.executor.submit(self._evaluate_in_project, run_, evaluator)
+                )

     def wait_for_futures(self) -> None:
         """Wait for all futures to complete."""
libs/langchain/langchain/smith/evaluation/runner_utils.py

@@ -948,7 +948,10 @@ def _collect_test_results(
     for c in configs:
         for callback in cast(list, c["callbacks"]):
             if isinstance(callback, EvaluatorCallbackHandler):
-                all_eval_results.update(callback.logged_eval_results)
+                eval_results = callback.logged_eval_results
+                all_eval_results.update(
+                    {example_id: v for (_, example_id), v in eval_results.items()}
+                )
     results = {}
     for example, output in zip(examples, batch_results):
         feedback = all_eval_results.get(str(example.id), [])