Update RunOnDataset helper functions to accept evaluator callbacks (#6629)

Also improve docstrings and update the tracing datasets notebook to
focus on "debug, evaluate, monitor"
Zander Chase 1 year ago committed by GitHub
parent 7ac9b22886
commit 6ca383ecf6

@@ -0,0 +1,84 @@
"""A tracer that runs evaluators over completed runs."""
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Optional, Sequence, Set, Union
from uuid import UUID
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
class EvaluatorCallbackHandler(BaseTracer):
"""A tracer that runs a run evaluator whenever a run is persisted.
Parameters
----------
evaluators : Sequence[RunEvaluator]
The run evaluators to apply to all top level runs.
max_workers : int, optional
The maximum number of worker threads to use for running the evaluators.
If not specified, it will default to the number of evaluators.
client : LangChainPlusClient, optional
The LangChainPlusClient instance to use for evaluating the runs.
If not specified, a new instance will be created.
example_id : Union[UUID, str], optional
The example ID to be associated with the runs.
Attributes
----------
example_id : Union[UUID, None]
The example ID associated with the runs.
client : LangChainPlusClient
The LangChainPlusClient instance used for evaluating the runs.
evaluators : Sequence[RunEvaluator]
The sequence of run evaluators to be executed.
executor : ThreadPoolExecutor
The thread pool executor used for running the evaluators.
futures : Set[Future]
The set of futures representing the running evaluators.
"""
name = "evaluator_callback_handler"
def __init__(
self,
evaluators: Sequence[RunEvaluator],
max_workers: Optional[int] = None,
client: Optional[LangChainPlusClient] = None,
example_id: Optional[Union[UUID, str]] = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.client = client or LangChainPlusClient()
self.evaluators = evaluators
self.executor = ThreadPoolExecutor(
max_workers=max(max_workers or len(evaluators), 1)
)
self.futures: Set[Future] = set()
def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run.
Parameters
----------
run : Run
The run to be evaluated.
"""
run_ = run.copy()
run_.reference_example_id = self.example_id
for evaluator in self.evaluators:
self.futures.add(
self.executor.submit(self.client.evaluate_run, run_, evaluator)
)
def wait_for_futures(self) -> None:
"""Wait for all futures to complete."""
futures = list(self.futures)
wait(futures)
for future in futures:
self.futures.remove(future)
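To make the handler's flow concrete, here is a minimal usage sketch. The OutputLengthEvaluator is invented for illustration, and the langchainplus_sdk import paths are assumed rather than confirmed by this diff:

# Sketch only: a toy RunEvaluator that scores each completed run by the
# length of its serialized outputs. Interface assumed: evaluate_run(run, example).
from typing import Optional

from langchainplus_sdk import RunEvaluator
from langchainplus_sdk.evaluation import EvaluationResult  # path assumed
from langchainplus_sdk.schemas import Example, Run  # paths assumed


class OutputLengthEvaluator(RunEvaluator):
    def evaluate_run(
        self, run: Run, example: Optional[Example] = None
    ) -> EvaluationResult:
        # Score the run by the length of its serialized outputs.
        return EvaluationResult(key="output_length", score=len(str(run.outputs or {})))


handler = EvaluatorCallbackHandler(evaluators=[OutputLengthEvaluator()])
# Pass callbacks=[handler] to any chain or LLM call; each persisted
# top-level run is then submitted to the evaluators on worker threads.
handler.wait_for_futures()  # block until all pending evaluations finish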

@@ -1,20 +1,52 @@
"""A tracer that collects all nested runs in a list."""
from typing import Any, List, Optional, Union
from uuid import UUID
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
class RunCollectorCallbackHandler(BaseTracer):
"""A tracer that collects all nested runs in a list.
"""
A tracer that collects all nested runs in a list.
This tracer is useful for inspection and evaluation purposes.
Useful for inspection and for evaluation."""
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
name = "run-collector_callback_handler"
def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
) -> None:
"""
Initialize the RunCollectorCallbackHandler.
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.traced_runs: List[Run] = []
def _persist_run(self, run: Run) -> None:
"""
Persist a run by adding it to the traced_runs list.
Parameters
----------
run : Run
The run to be persisted.
"""
run_ = run.copy()
run_.reference_example_id = self.example_id
self.traced_runs.append(run_)
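A quick sketch of typical use (the llm.predict call is illustrative; any traced component accepts callbacks):

# Sketch: collect every run produced by a traced call for later inspection.
collector = RunCollectorCallbackHandler()
# llm.predict("Tell me a joke", callbacks=[collector])
for run in collector.traced_runs:
    print(run.id, run.run_type)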

@@ -1,4 +1,5 @@
"""Utilities for running LLMs/Chains over datasets."""
"""Utilities for running language models or Chains over datasets."""
from __future__ import annotations
import asyncio
@@ -13,15 +14,18 @@ from typing import (
Iterator,
List,
Optional,
Sequence,
Union,
)
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchainplus_sdk.schemas import Example
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
@@ -41,11 +45,21 @@ MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
class InputFormatError(Exception):
"""Raised when input format is invalid."""
"""Raised when the input format is invalid."""
def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
"""Get prompts from inputs."""
"""
Get prompts from inputs.
Args:
inputs: The input dictionary.
Returns:
A list of prompts.
Raises:
InputFormatError: If the input format is invalid.
"""
if not inputs:
raise InputFormatError("Inputs should not be empty.")
@@ -83,7 +97,17 @@ def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
"""Get Chat Messages from inputs."""
"""
Get Chat Messages from inputs.
Args:
inputs: The input dictionary.
Returns:
A list of chat messages.
Raises:
InputFormatError: If the input format is invalid.
"""
if not inputs:
raise InputFormatError("Inputs should not be empty.")
@@ -112,13 +136,25 @@ def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
async def _arun_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
*,
tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> Union[LLMResult, ChatResult]:
"""
Asynchronously run the language model.
Args:
llm: The language model to run.
inputs: The input dictionary.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
The LLMResult or ChatResult.
Raises:
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
@@ -152,18 +188,32 @@ async def _arun_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously."""
if langchain_tracer is not None:
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
"""
Asynchronously run the Chain or language model.
Args:
example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
A list of outputs.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else:
previous_example_ids = None
outputs = []
for _ in range(n_repetitions):
try:
@@ -171,8 +221,8 @@ async def _arun_llm_or_chain(
output: Any = await _arun_llm(
llm_or_chain_factory,
example.inputs,
tags=tags,
callbacks=callbacks,
)
else:
chain = llm_or_chain_factory()
@@ -183,15 +233,19 @@ async def _arun_llm_or_chain(
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
if callbacks and previous_example_ids:
for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
return outputs
async def _gather_with_concurrency(
n: int,
initializer: Callable[[], Coroutine[Any, Any, Any]],
*async_funcs: Callable[
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
],
) -> List[Any]:
"""
Run coroutines with a concurrency limit.
@@ -207,37 +261,42 @@ async def _gather_with_concurrency(
semaphore = asyncio.Semaphore(n)
job_state = {"num_processed": 0}
callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
for _ in range(n):
callback_queue.put_nowait(await initializer())
async def run_coroutine_with_semaphore(
async_func: Callable[
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
]
) -> Any:
async with semaphore:
callbacks = await callback_queue.get()
try:
result = await async_func(callbacks, job_state)
finally:
callback_queue.put_nowait(callbacks)
return result
results = await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs)
)
while callback_queue:
try:
callbacks = callback_queue.get_nowait()
except asyncio.QueueEmpty:
break
for callback in callbacks:
if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
callback.wait_for_futures()
return results
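The semaphore-plus-queue pattern above is easier to see in isolation. Below is a self-contained toy version of the same design: n pooled "resources" (plain dicts standing in for callback bundles) are lent out to at most n concurrent jobs. All names are invented for illustration:

import asyncio
from typing import Any, Callable, Coroutine, List


async def gather_limited(
    n: int,
    make_resource: Callable[[], Coroutine[Any, Any, Any]],
    *jobs: Callable[[Any], Coroutine[Any, Any, Any]],
) -> List[Any]:
    """Run jobs with at most n in flight, lending each a pooled resource."""
    semaphore = asyncio.Semaphore(n)
    pool: asyncio.Queue = asyncio.Queue()
    for _ in range(n):
        pool.put_nowait(await make_resource())

    async def run(job: Callable[[Any], Coroutine[Any, Any, Any]]) -> Any:
        async with semaphore:
            resource = await pool.get()
            try:
                return await job(resource)
            finally:
                pool.put_nowait(resource)  # return the resource for reuse

    return await asyncio.gather(*(run(j) for j in jobs))


async def demo() -> None:
    async def make_resource() -> dict:
        return {"calls": 0}

    async def job(resource: dict) -> int:
        resource["calls"] += 1  # each pooled resource is reused across jobs
        await asyncio.sleep(0.01)
        return resource["calls"]

    print(await gather_limited(2, make_resource, *[job] * 5))


asyncio.run(demo())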
async def _callbacks_initializer(
project_name: Optional[str],
client: LangChainPlusClient,
run_evaluators: Sequence[RunEvaluator],
) -> List[BaseTracer]:
"""
Initialize a tracer to share across tasks.
@@ -247,11 +306,19 @@ async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChain
Returns:
A list of callback handlers: a LangChainTracer bound to the project and, if run evaluators were supplied, an EvaluatorCallbackHandler.
"""
callbacks: List[BaseTracer] = []
if project_name:
callbacks.append(LangChainTracer(project_name=project_name))
if run_evaluators:
callbacks.append(
EvaluatorCallbackHandler(
client=client,
evaluators=run_evaluators,
# We already have concurrency, don't want to overload the machine
max_workers=1,
)
)
return callbacks
async def arun_on_examples(
@@ -262,13 +329,16 @@ async def arun_on_examples(
num_repetitions: int = 1,
project_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]:
"""
Asynchronously run the chain on examples and store traces
to the specified project name.
Args:
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
@@ -277,24 +347,35 @@ async def arun_on_examples(
This is useful when testing success rates or generating confidence
intervals.
project_name: Project name to use when tracing runs.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns:
A dictionary mapping example ids to the model outputs.
"""
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
results: Dict[str, List[Any]] = {}
evaluation_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
async def process_example(
example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
) -> None:
"""Process a single example."""
result = await _arun_llm_or_chain(
example,
llm_or_chain_factory,
num_repetitions,
tags=tags,
callbacks=callbacks,
)
results[str(example.id)] = result
job_state["num_processed"] += 1
@@ -307,9 +388,15 @@ async def arun_on_examples(
await _gather_with_concurrency(
concurrency_level,
functools.partial(
_callbacks_initializer,
project_name=project_name,
client=client_,
run_evaluators=run_evaluators or [],
),
*(functools.partial(process_example, e) for e in examples),
)
evaluation_handler.wait_for_futures()
return results
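A hypothetical async driver for this function, assuming environment credentials are configured, a dataset named "my-dataset" already exists, and the module path below is correct for this version:

import asyncio

from langchain.client.runner_utils import arun_on_examples  # path assumed
from langchain.llms import OpenAI
from langchainplus_sdk import LangChainPlusClient


async def main() -> None:
    client = LangChainPlusClient()
    dataset = client.read_dataset(dataset_name="my-dataset")  # assumed to exist
    examples = client.list_examples(dataset_id=str(dataset.id))
    results = await arun_on_examples(
        examples,
        OpenAI(temperature=0),
        concurrency_level=2,
        client=client,
        run_evaluators=[],  # plug RunEvaluator instances in here
        verbose=True,
    )
    print(f"Processed {len(results)} examples")


asyncio.run(main())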
@@ -320,7 +407,21 @@ def run_llm(
*,
tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
"""
Run the language model on the example.
Args:
llm: The language model to run.
inputs: The input dictionary.
callbacks: The callbacks to use during the run.
tags: Optional tags to add to the run.
Returns:
The LLMResult or ChatResult.
Raises:
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
@@ -350,18 +451,32 @@ def run_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously."""
if langchain_tracer is not None:
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
"""
Run the Chain or language model synchronously.
Args:
example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
A list of outputs.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else:
previous_example_ids = None
outputs = []
for _ in range(n_repetitions):
try:
@@ -376,8 +491,10 @@ def run_llm_or_chain(
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
if callbacks and previous_example_ids:
for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
return outputs
@@ -388,48 +505,74 @@ def run_on_examples(
num_repetitions: int = 1,
project_name: Optional[str] = None,
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]:
"""Run the chain on examples and store traces to the specified project name.
"""
Run the Chain or language model on examples and store
traces to the specified project name.
Args:
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
concurrency_level: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns:
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
tracer = LangChainTracer(project_name=project_name)
evaluation_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
callbacks: List[BaseCallbackHandler] = [tracer, evaluation_handler]
for i, example in enumerate(examples):
result = run_llm_or_chain(
example,
llm_or_chain_factory,
num_repetitions,
tags=tags,
callbacks=callbacks,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
tracer.wait_for_futures()
evaluation_handler.wait_for_futures()
return results
def _get_project_name(
project_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: Optional[str],
) -> str:
"""
Get the project name.
Args:
project_name: The project name if manually specified.
llm_or_chain_factory: The Chain or language model constructor.
dataset_name: The dataset name.
Returns:
The project name.
"""
if project_name is not None:
return project_name
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
@@ -437,7 +580,8 @@ def _get_project_name(
model_name = llm_or_chain_factory.__class__.__name__
else:
model_name = llm_or_chain_factory().__class__.__name__
return f"{dataset_name}-{model_name}-{current_time}"
dataset_prefix = f"{dataset_name}-" if dataset_name else ""
return f"{dataset_prefix}{model_name}-{current_time}"
async def arun_on_dataset(
@@ -450,12 +594,13 @@ async def arun_on_dataset(
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]:
"""
Asynchronously run the Chain or language model on a dataset
and store traces to the specified project name.
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
@@ -469,7 +614,8 @@ async def arun_on_dataset(
verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns:
A dictionary containing the run's project name and the resulting model outputs.
@@ -478,7 +624,6 @@ async def arun_on_dataset(
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
results = await arun_on_examples(
examples,
llm_or_chain_factory,
@@ -486,7 +631,9 @@ async def arun_on_dataset(
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,
client=client_,
tags=tags,
run_evaluators=run_evaluators,
)
return {
"project_name": project_name,
@@ -503,8 +650,11 @@ def run_on_dataset(
verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified project name.
"""
Run the Chain or language model on a dataset and store traces
to the specified project name.
Args:
dataset_name: Name of the dataset to run the chain on.
@@ -520,7 +670,8 @@ def run_on_dataset(
verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns:
A dictionary containing the run's project name and the resulting model outputs.
@@ -536,6 +687,8 @@ def run_on_dataset(
project_name=project_name,
verbose=verbose,
tags=tags,
run_evaluators=run_evaluators,
client=client_,
)
return {
"project_name": project_name,

@@ -117,10 +117,12 @@ def get_qa_evaluator(
choices_map={"CORRECT": 1, "INCORRECT": 0},
),
)
tags = kwargs.pop("tags", [])
return RunEvaluatorChain(
eval_chain=eval_chain,
input_mapper=input_mapper,
output_parser=output_parser,
tags=tags + [evaluation_name],
**kwargs,
)
@@ -174,6 +176,7 @@ def get_criteria_evaluator(
choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name
),
)
tags = kwargs.pop("tags", [])
eval_chain = CriteriaEvalChain.from_llm(
llm=llm, criteria=criteria_, prompt=prompt, **kwargs
)
@@ -181,6 +184,7 @@ def get_criteria_evaluator(
eval_chain=eval_chain,
input_mapper=input_mapper,
output_parser=parser,
tags=tags + [evaluation_name],
**kwargs,
)
@@ -303,9 +307,11 @@ def get_trajectory_evaluator(
TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
)
eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
tags = kwargs.pop("tags", [])
return RunEvaluatorChain(
eval_chain=eval_chain,
input_mapper=input_mapper,
output_parser=parser,
tags=tags + [evaluation_name],
**kwargs,
)
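With this plumbing, each evaluator's feedback runs carry the evaluation name plus any caller-supplied tags. A hypothetical call (factory import path assumed):

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.run_evaluators import get_qa_evaluator  # path assumed

# The resulting evaluator chain is tagged with both "nightly-eval" and
# the evaluation name tag added by the diff above.
evaluator = get_qa_evaluator(ChatOpenAI(temperature=0), tags=["nightly-eval"])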

File diff suppressed because it is too large.

@@ -169,8 +169,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tags: Optional[List[str]] = None,
callbacks: Optional[Any] = None,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
