Add max_concurrency option to EvaluatorCallbackHandler (wfh/eval max concurrency) (#11368)

pull/11372/head
William FH 11 months ago committed by GitHub
parent 1165767df2
commit 06f39be1c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,13 +2,13 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import threading
import weakref import weakref
from concurrent.futures import Future, wait from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID from uuid import UUID
import langsmith import langsmith
from langsmith import schemas as langsmith_schemas
from langsmith.evaluation.evaluator import EvaluationResult from langsmith.evaluation.evaluator import EvaluationResult
from langchain.callbacks import manager from langchain.callbacks import manager
@ -73,6 +73,7 @@ class EvaluatorCallbackHandler(BaseTracer):
example_id: Optional[Union[UUID, str]] = None, example_id: Optional[Union[UUID, str]] = None,
skip_unfinished: bool = True, skip_unfinished: bool = True,
project_name: Optional[str] = "evaluators", project_name: Optional[str] = "evaluators",
max_concurrency: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -81,12 +82,21 @@ class EvaluatorCallbackHandler(BaseTracer):
) )
self.client = client or langchain_tracer.get_client() self.client = client or langchain_tracer.get_client()
self.evaluators = evaluators self.evaluators = evaluators
self.executor = _get_executor() if max_concurrency is None:
self.executor: Optional[ThreadPoolExecutor] = _get_executor()
elif max_concurrency > 0:
self.executor = ThreadPoolExecutor(max_workers=max_concurrency)
weakref.finalize(
self,
lambda: cast(ThreadPoolExecutor, self.executor).shutdown(wait=True),
)
else:
self.executor = None
self.futures: weakref.WeakSet[Future] = weakref.WeakSet() self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished self.skip_unfinished = skip_unfinished
self.project_name = project_name self.project_name = project_name
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {} self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {} self.lock = threading.Lock()
global _TRACERS global _TRACERS
_TRACERS.add(self) _TRACERS.add(self)
@ -111,12 +121,15 @@ class EvaluatorCallbackHandler(BaseTracer):
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error evaluating run {run.id} with " f"Error evaluating run {run.id} with "
f"{evaluator.__class__.__name__}: {e}", f"{evaluator.__class__.__name__}: {repr(e)}",
exc_info=True, exc_info=True,
) )
raise e raise e
example_id = str(run.reference_example_id) example_id = str(run.reference_example_id)
self.logged_eval_results.setdefault(example_id, []).append(eval_result) with self.lock:
self.logged_eval_results.setdefault((str(run.id), example_id), []).append(
eval_result
)
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run. """Run the evaluator on the run.
@ -133,9 +146,12 @@ class EvaluatorCallbackHandler(BaseTracer):
run_ = run.copy() run_ = run.copy()
run_.reference_example_id = self.example_id run_.reference_example_id = self.example_id
for evaluator in self.evaluators: for evaluator in self.evaluators:
self.futures.add( if self.executor is None:
self.executor.submit(self._evaluate_in_project, run_, evaluator) self._evaluate_in_project(run_, evaluator)
) else:
self.futures.add(
self.executor.submit(self._evaluate_in_project, run_, evaluator)
)
def wait_for_futures(self) -> None: def wait_for_futures(self) -> None:
"""Wait for all futures to complete.""" """Wait for all futures to complete."""

@ -948,7 +948,10 @@ def _collect_test_results(
for c in configs: for c in configs:
for callback in cast(list, c["callbacks"]): for callback in cast(list, c["callbacks"]):
if isinstance(callback, EvaluatorCallbackHandler): if isinstance(callback, EvaluatorCallbackHandler):
all_eval_results.update(callback.logged_eval_results) eval_results = callback.logged_eval_results
all_eval_results.update(
{example_id: v for (_, example_id), v in eval_results.items()}
)
results = {} results = {}
for example, output in zip(examples, batch_results): for example, output in zip(examples, batch_results):
feedback = all_eval_results.get(str(example.id), []) feedback = all_eval_results.get(str(example.id), [])

Loading…
Cancel
Save