Shared Executor (#11028)

pull/11196/head^2
William FH 11 months ago committed by GitHub
parent 926e4b6bad
commit e9b51513e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,20 +2,33 @@
from __future__ import annotations
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Sequence, Set, Union
import weakref
from concurrent.futures import Future, wait
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
import langsmith
from langsmith import schemas as langsmith_schemas
from langsmith.evaluation.evaluator import EvaluationResult
from langchain.callbacks import manager
from langchain.callbacks.tracers import langchain as langchain_tracer
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.langchain import _get_executor
from langchain.callbacks.tracers.schemas import Run
logger = logging.getLogger(__name__)
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
def wait_for_all_evaluators() -> None:
    """Block until every registered evaluator handler has drained its futures."""
    # Snapshot the WeakSet so handlers garbage-collected mid-iteration
    # cannot invalidate the loop.
    handlers = list(_TRACERS)
    for handler in handlers:
        if handler is not None:
            handler.wait_for_futures()
class EvaluatorCallbackHandler(BaseTracer):
"""A tracer that runs a run evaluator whenever a run is persisted.
@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer):
----------
evaluators : Sequence[RunEvaluator]
The run evaluators to apply to all top level runs.
max_workers : int, optional
The maximum number of worker threads to use for running the evaluators.
If not specified, it will default to the number of evaluators.
client : LangSmith Client, optional
The LangSmith client instance to use for evaluating the runs.
If not specified, a new instance will be created.
@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer):
def __init__(
self,
evaluators: Sequence[langsmith.RunEvaluator],
max_workers: Optional[int] = None,
client: Optional[langsmith.Client] = None,
example_id: Optional[Union[UUID, str]] = None,
skip_unfinished: bool = True,
@ -72,11 +81,14 @@ class EvaluatorCallbackHandler(BaseTracer):
)
self.client = client or langchain_tracer.get_client()
self.evaluators = evaluators
self.max_workers = max_workers or len(evaluators)
self.futures: Set[Future] = set()
self.executor = _get_executor()
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
global _TRACERS
_TRACERS.add(self)
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
"""Evaluate the run in the project.
@ -120,15 +132,11 @@ class EvaluatorCallbackHandler(BaseTracer):
return
run_ = run.copy()
run_.reference_example_id = self.example_id
if self.max_workers > 0:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
list(
executor.map(
self._evaluate_in_project,
[run_ for _ in range(len(self.evaluators))],
self.evaluators,
)
)
else:
for evaluator in self.evaluators:
self._evaluate_in_project(run_, evaluator)
for evaluator in self.evaluators:
self.futures.add(
self.executor.submit(self._evaluate_in_project, run_, evaluator)
)
def wait_for_futures(self) -> None:
    """Block until every evaluation future submitted so far has finished."""
    pending = list(self.futures)
    wait(pending)

@ -6,7 +6,7 @@ import os
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Union
from uuid import UUID
from langsmith import Client
@ -21,8 +21,7 @@ logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
_CLIENT: Optional[Client] = None
_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented
_EXECUTORS: List[ThreadPoolExecutor] = []
_EXECUTOR: Optional[ThreadPoolExecutor] = None
def log_error_once(method: str, exception: Exception) -> None:
@ -50,6 +49,14 @@ def get_client() -> Client:
return _CLIENT
def _get_executor() -> ThreadPoolExecutor:
    """Return the module-wide shared thread pool, creating it on first use."""
    global _EXECUTOR
    if _EXECUTOR is not None:
        return _EXECUTOR
    # Lazily constructed so importing the module never spins up threads.
    _EXECUTOR = ThreadPoolExecutor()
    return _EXECUTOR
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@ -71,21 +78,10 @@ class LangChainTracer(BaseTracer):
self.project_name = project_name or os.getenv(
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
)
if use_threading:
global _MAX_EXECUTORS
if len(_EXECUTORS) < _MAX_EXECUTORS:
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
max_workers=1
)
_EXECUTORS.append(self.executor)
else:
self.executor = _EXECUTORS.pop(0)
_EXECUTORS.append(self.executor)
else:
self.executor = None
self.client = client or get_client()
self._futures: Set[Future] = set()
self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.tags = tags or []
self.executor = _get_executor() if use_threading else None
global _TRACERS
_TRACERS.add(self)
@ -229,7 +225,4 @@ class LangChainTracer(BaseTracer):
def wait_for_futures(self) -> None:
    """Block until every pending future in ``self._futures`` has completed.

    ``self._futures`` is a ``weakref.WeakSet[Future]`` (see ``__init__``), so
    completed futures that nothing else references are dropped from the set by
    garbage collection — the old manual ``wait``/``remove`` bookkeeping that
    this scrape merged into the body is redundant and has been removed, leaving
    the post-commit implementation: a single ``wait`` over the live set.
    """
    wait(self._futures)

@ -24,8 +24,11 @@ from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
from langchain.callbacks.tracers.evaluation import (
EvaluatorCallbackHandler,
wait_for_all_evaluators,
)
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain
from langchain.evaluation.loading import load_evaluator
from langchain.evaluation.schema import (
@ -915,7 +918,6 @@ def _prepare_run_on_dataset(
EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
max_workers=0,
example_id=example.id,
),
progress_bar,
@ -934,7 +936,7 @@ def _collect_test_results(
configs: List[RunnableConfig],
project_name: str,
) -> TestResult:
wait_for_all_tracers()
wait_for_all_evaluators()
all_eval_results = {}
for c in configs:
for callback in cast(list, c["callbacks"]):

Loading…
Cancel
Save