Update to RunOnDataset helper functions to accept evaluator callbacks (#6629)

Also improve docstrings and update the tracing datasets notebook to
focus on "debug, evaluate, monitor"
This commit is contained in:
Zander Chase 2023-06-26 23:58:13 -07:00 committed by GitHub
parent 7ac9b22886
commit 6ca383ecf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1054 additions and 751 deletions

View File

@ -0,0 +1,84 @@
"""A tracer that runs evaluators over completed runs."""
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Optional, Sequence, Set, Union
from uuid import UUID
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
class EvaluatorCallbackHandler(BaseTracer):
    """A tracer that runs a run evaluator whenever a run is persisted.

    Each persisted run is copied, tagged with the configured example ID and
    handed to every evaluator on a background thread pool. Call
    ``wait_for_futures`` to block until the submitted evaluations finish;
    evaluations that raised are reported there as log warnings.

    Parameters
    ----------
    evaluators : Sequence[RunEvaluator]
        The run evaluators to apply to all top level runs.
    max_workers : int, optional
        The maximum number of worker threads to use for running the evaluators.
        If not specified, it will default to the number of evaluators
        (with a minimum of one worker).
    client : LangChainPlusClient, optional
        The LangChainPlusClient instance to use for evaluating the runs.
        If not specified, a new instance will be created.
    example_id : Union[UUID, str], optional
        The example ID to be associated with the runs. A string is
        converted to a UUID.

    Attributes
    ----------
    example_id : Union[UUID, None]
        The example ID associated with the runs.
    client : LangChainPlusClient
        The LangChainPlusClient instance used for evaluating the runs.
    evaluators : Sequence[RunEvaluator]
        The sequence of run evaluators to be executed.
    executor : ThreadPoolExecutor
        The thread pool executor used for running the evaluators.
    futures : Set[Future]
        The set of futures representing the running evaluators.
    """

    name = "evaluator_callback_handler"

    def __init__(
        self,
        evaluators: Sequence[RunEvaluator],
        max_workers: Optional[int] = None,
        client: Optional[LangChainPlusClient] = None,
        example_id: Optional[Union[UUID, str]] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.client = client or LangChainPlusClient()
        self.evaluators = evaluators
        # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least
        # one worker even when no evaluators were supplied.
        self.executor = ThreadPoolExecutor(
            max_workers=max(max_workers or len(evaluators), 1)
        )
        self.futures: Set[Future] = set()

    def _persist_run(self, run: Run) -> None:
        """Run the evaluator on the run.

        Parameters
        ----------
        run : Run
            The run to be evaluated.
        """
        # Work on a copy so the example ID is not written onto the caller's
        # run object.
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        for evaluator in self.evaluators:
            self.futures.add(
                self.executor.submit(self.client.evaluate_run, run_, evaluator)
            )

    def wait_for_futures(self) -> None:
        """Wait for all futures to complete.

        Only the futures pending at the time of the call are awaited; they
        are then dropped from ``self.futures``. An evaluation that raised is
        logged as a warning rather than re-raised, so one failing evaluator
        does not hide the outcome of the others.
        """
        import logging

        logger = logging.getLogger(__name__)
        # Snapshot the set so futures added while waiting are left for the
        # next call.
        pending = list(self.futures)
        wait(pending)
        for future in pending:
            # discard() tolerates a future already removed by another call.
            self.futures.discard(future)
            if future.cancelled():
                # exception() would raise CancelledError for these.
                continue
            error = future.exception()
            if error is not None:
                logger.warning("Error evaluating run: %r", error)

View File

@ -1,20 +1,52 @@
"""A tracer that collects all nested runs in a list.""" """A tracer that collects all nested runs in a list."""
from typing import Any, List
from typing import Any, List, Optional, Union
from uuid import UUID
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
class RunCollectorCallbackHandler(BaseTracer): class RunCollectorCallbackHandler(BaseTracer):
"""A tracer that collects all nested runs in a list. """
A tracer that collects all nested runs in a list.
Useful for inspection and for evaluation.""" This tracer is useful for inspection and evaluation purposes.
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
name = "run-collector_callback_handler" name = "run-collector_callback_handler"
def __init__(self, **kwargs: Any) -> None: def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
) -> None:
"""
Initialize the RunCollectorCallbackHandler.
Parameters
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.traced_runs: List[Run] = [] self.traced_runs: List[Run] = []
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
self.traced_runs.append(run) """
Persist a run by adding it to the traced_runs list.
Parameters
----------
run : Run
The run to be persisted.
"""
run_ = run.copy()
run_.reference_example_id = self.example_id
self.traced_runs.append(run_)

View File

@ -1,4 +1,5 @@
"""Utilities for running LLMs/Chains over datasets.""" """Utilities for running language models or Chains over datasets."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -13,15 +14,18 @@ from typing import (
Iterator, Iterator,
List, List,
Optional, Optional,
Sequence,
Union, Union,
) )
from langchainplus_sdk import LangChainPlusClient from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchainplus_sdk.schemas import Example from langchainplus_sdk.schemas import Example
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
@ -41,11 +45,21 @@ MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
class InputFormatError(Exception): class InputFormatError(Exception):
"""Raised when input format is invalid.""" """Raised when the input format is invalid."""
def _get_prompts(inputs: Dict[str, Any]) -> List[str]: def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
"""Get prompts from inputs.""" """
Get prompts from inputs.
Args:
inputs: The input dictionary.
Returns:
A list of prompts.
Raises:
InputFormatError: If the input format is invalid.
"""
if not inputs: if not inputs:
raise InputFormatError("Inputs should not be empty.") raise InputFormatError("Inputs should not be empty.")
@ -83,7 +97,17 @@ def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]: def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
"""Get Chat Messages from inputs.""" """
Get Chat Messages from inputs.
Args:
inputs: The input dictionary.
Returns:
A list of chat messages.
Raises:
InputFormatError: If the input format is invalid.
"""
if not inputs: if not inputs:
raise InputFormatError("Inputs should not be empty.") raise InputFormatError("Inputs should not be empty.")
@ -112,13 +136,25 @@ def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
async def _arun_llm( async def _arun_llm(
llm: BaseLanguageModel, llm: BaseLanguageModel,
inputs: Dict[str, Any], inputs: Dict[str, Any],
langchain_tracer: Optional[LangChainTracer],
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> Union[LLMResult, ChatResult]: ) -> Union[LLMResult, ChatResult]:
callbacks: Optional[List[BaseCallbackHandler]] = ( """
[langchain_tracer] if langchain_tracer else None Asynchronously run the language model.
)
Args:
llm: The language model to run.
inputs: The input dictionary.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
The LLMResult or ChatResult.
Raises:
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM): if isinstance(llm, BaseLLM):
try: try:
llm_prompts = _get_prompts(inputs) llm_prompts = _get_prompts(inputs)
@ -152,18 +188,32 @@ async def _arun_llm_or_chain(
example: Example, example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int, n_repetitions: int,
langchain_tracer: Optional[LangChainTracer],
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously.""" """
if langchain_tracer is not None: Asynchronously run the Chain or language model.
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id Args:
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer] example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
A list of outputs.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else: else:
previous_example_id = None previous_example_ids = None
callbacks = None
outputs = [] outputs = []
for _ in range(n_repetitions): for _ in range(n_repetitions):
try: try:
@ -171,8 +221,8 @@ async def _arun_llm_or_chain(
output: Any = await _arun_llm( output: Any = await _arun_llm(
llm_or_chain_factory, llm_or_chain_factory,
example.inputs, example.inputs,
langchain_tracer,
tags=tags, tags=tags,
callbacks=callbacks,
) )
else: else:
chain = llm_or_chain_factory() chain = llm_or_chain_factory()
@ -183,15 +233,19 @@ async def _arun_llm_or_chain(
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)}) outputs.append({"Error": str(e)})
if langchain_tracer is not None: if callbacks and previous_example_ids:
langchain_tracer.example_id = previous_example_id for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
return outputs return outputs
async def _gather_with_concurrency( async def _gather_with_concurrency(
n: int, n: int,
initializer: Callable[[], Coroutine[Any, Any, Optional[LangChainTracer]]], initializer: Callable[[], Coroutine[Any, Any, Any]],
*async_funcs: Callable[[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]], *async_funcs: Callable[
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
],
) -> List[Any]: ) -> List[Any]:
""" """
Run coroutines with a concurrency limit. Run coroutines with a concurrency limit.
@ -207,37 +261,42 @@ async def _gather_with_concurrency(
semaphore = asyncio.Semaphore(n) semaphore = asyncio.Semaphore(n)
job_state = {"num_processed": 0} job_state = {"num_processed": 0}
tracer_queue: asyncio.Queue[Optional[LangChainTracer]] = asyncio.Queue() callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
for _ in range(n): for _ in range(n):
tracer_queue.put_nowait(await initializer()) callback_queue.put_nowait(await initializer())
async def run_coroutine_with_semaphore( async def run_coroutine_with_semaphore(
async_func: Callable[ async_func: Callable[
[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any] [Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
] ]
) -> Any: ) -> Any:
async with semaphore: async with semaphore:
tracer = await tracer_queue.get() callbacks = await callback_queue.get()
try: try:
result = await async_func(tracer, job_state) result = await async_func(callbacks, job_state)
finally: finally:
tracer_queue.put_nowait(tracer) callback_queue.put_nowait(callbacks)
return result return result
results = await asyncio.gather( results = await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs) *(run_coroutine_with_semaphore(function) for function in async_funcs)
) )
while tracer_queue: while callback_queue:
try: try:
tracer = tracer_queue.get_nowait() callbacks = callback_queue.get_nowait()
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
break break
if tracer: for callback in callbacks:
tracer.wait_for_futures() if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
callback.wait_for_futures()
return results return results
async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChainTracer]: async def _callbacks_initializer(
project_name: Optional[str],
client: LangChainPlusClient,
run_evaluators: Sequence[RunEvaluator],
) -> List[BaseTracer]:
""" """
Initialize a tracer to share across tasks. Initialize a tracer to share across tasks.
@ -247,11 +306,19 @@ async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChain
Returns: Returns:
A LangChainTracer instance with an active project. A LangChainTracer instance with an active project.
""" """
callbacks: List[BaseTracer] = []
if project_name: if project_name:
tracer = LangChainTracer(project_name=project_name) callbacks.append(LangChainTracer(project_name=project_name))
return tracer if run_evaluators:
else: callbacks.append(
return None EvaluatorCallbackHandler(
client=client,
evaluators=run_evaluators,
# We already have concurrency, don't want to overload the machine
max_workers=1,
)
)
return callbacks
async def arun_on_examples( async def arun_on_examples(
@ -262,13 +329,16 @@ async def arun_on_examples(
num_repetitions: int = 1, num_repetitions: int = 1,
project_name: Optional[str] = None, project_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run the chain on examples and store traces to the specified project name. Asynchronously run the chain on examples and store traces
to the specified project name.
Args: Args:
examples: Examples to run the model or chain over examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state. independent calls on each example without carrying over state.
@ -277,24 +347,35 @@ async def arun_on_examples(
This is useful when testing success rates or generating confidence This is useful when testing success rates or generating confidence
intervals. intervals.
project_name: Project name to use when tracing runs. project_name: Project name to use when tracing runs.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress. verbose: Whether to print progress.
tags: Tags to add to the traces. client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary mapping example ids to the model outputs.
""" """
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
results: Dict[str, List[Any]] = {} results: Dict[str, List[Any]] = {}
evaluation_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
async def process_example( async def process_example(
example: Example, tracer: Optional[LangChainTracer], job_state: dict example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
) -> None: ) -> None:
"""Process a single example.""" """Process a single example."""
result = await _arun_llm_or_chain( result = await _arun_llm_or_chain(
example, example,
llm_or_chain_factory, llm_or_chain_factory,
num_repetitions, num_repetitions,
tracer,
tags=tags, tags=tags,
callbacks=callbacks,
) )
results[str(example.id)] = result results[str(example.id)] = result
job_state["num_processed"] += 1 job_state["num_processed"] += 1
@ -307,9 +388,15 @@ async def arun_on_examples(
await _gather_with_concurrency( await _gather_with_concurrency(
concurrency_level, concurrency_level,
functools.partial(_tracer_initializer, project_name), functools.partial(
_callbacks_initializer,
project_name=project_name,
client=client_,
run_evaluators=run_evaluators or [],
),
*(functools.partial(process_example, e) for e in examples), *(functools.partial(process_example, e) for e in examples),
) )
evaluation_handler.wait_for_futures()
return results return results
@ -320,7 +407,21 @@ def run_llm(
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
) -> Union[LLMResult, ChatResult]: ) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example.""" """
Run the language model on the example.
Args:
llm: The language model to run.
inputs: The input dictionary.
callbacks: The callbacks to use during the run.
tags: Optional tags to add to the run.
Returns:
The LLMResult or ChatResult.
Raises:
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM): if isinstance(llm, BaseLLM):
try: try:
llm_prompts = _get_prompts(inputs) llm_prompts = _get_prompts(inputs)
@ -350,18 +451,32 @@ def run_llm_or_chain(
example: Example, example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int, n_repetitions: int,
langchain_tracer: Optional[LangChainTracer] = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously.""" """
if langchain_tracer is not None: Run the Chain or language model synchronously.
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id Args:
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer] example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
Returns:
A list of outputs.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else: else:
previous_example_id = None previous_example_ids = None
callbacks = None
outputs = [] outputs = []
for _ in range(n_repetitions): for _ in range(n_repetitions):
try: try:
@ -376,8 +491,10 @@ def run_llm_or_chain(
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)}) outputs.append({"Error": str(e)})
if langchain_tracer is not None: if callbacks and previous_example_ids:
langchain_tracer.example_id = previous_example_id for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
return outputs return outputs
@ -388,48 +505,74 @@ def run_on_examples(
num_repetitions: int = 1, num_repetitions: int = 1,
project_name: Optional[str] = None, project_name: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the chain on examples and store traces to the specified project name. """
Run the Chain or language model on examples and store
traces to the specified project name.
Args: Args:
examples: Examples to run model or chain over. examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state. independent calls on each example without carrying over state.
concurrency_level: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example. num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence This is useful when testing success rates or generating confidence
intervals. intervals.
project_name: Project name to use when tracing runs. project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress. verbose: Whether to print progress.
tags: Tags to add to the run traces. client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary mapping example ids to the model outputs.
""" """
results: Dict[str, Any] = {} results: Dict[str, Any] = {}
tracer = LangChainTracer(project_name=project_name) if project_name else None project_name = _get_project_name(project_name, llm_or_chain_factory, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
tracer = LangChainTracer(project_name=project_name)
evalution_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
callbacks: List[BaseCallbackHandler] = [tracer, evalution_handler]
for i, example in enumerate(examples): for i, example in enumerate(examples):
result = run_llm_or_chain( result = run_llm_or_chain(
example, example,
llm_or_chain_factory, llm_or_chain_factory,
num_repetitions, num_repetitions,
langchain_tracer=tracer,
tags=tags, tags=tags,
callbacks=callbacks,
) )
if verbose: if verbose:
print(f"{i+1} processed", flush=True, end="\r") print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result results[str(example.id)] = result
if tracer: tracer.wait_for_futures()
tracer.wait_for_futures() evalution_handler.wait_for_futures()
return results return results
def _get_project_name( def _get_project_name(
project_name: Optional[str], project_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str, dataset_name: Optional[str],
) -> str: ) -> str:
"""
Get the project name.
Args:
project_name: The project name if manually specified.
llm_or_chain_factory: The Chain or language model constructor.
dataset_name: The dataset name.
Returns:
The project name.
"""
if project_name is not None: if project_name is not None:
return project_name return project_name
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
@ -437,7 +580,8 @@ def _get_project_name(
model_name = llm_or_chain_factory.__class__.__name__ model_name = llm_or_chain_factory.__class__.__name__
else: else:
model_name = llm_or_chain_factory().__class__.__name__ model_name = llm_or_chain_factory().__class__.__name__
return f"{dataset_name}-{model_name}-{current_time}" dataset_prefix = f"{dataset_name}-" if dataset_name else ""
return f"{dataset_prefix}{model_name}-{current_time}"
async def arun_on_dataset( async def arun_on_dataset(
@ -450,12 +594,13 @@ async def arun_on_dataset(
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None, client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run the chain on a dataset and store traces to the specified project name. Asynchronously run the Chain or language model on a dataset
and store traces to the specified project name.
Args: Args:
client: Client to use to read the dataset.
dataset_name: Name of the dataset to run the chain on. dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit over the dataset. The Chain constructor is used to permit
@ -469,7 +614,8 @@ async def arun_on_dataset(
verbose: Whether to print progress. verbose: Whether to print progress.
client: Client to use to read the dataset. If not provided, a new client: Client to use to read the dataset. If not provided, a new
client will be created using the credentials in the environment. client will be created using the credentials in the environment.
tags: Tags to add to each run in the sesssion. tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns: Returns:
A dictionary containing the run's project name and the resulting model outputs. A dictionary containing the run's project name and the resulting model outputs.
@ -478,7 +624,6 @@ async def arun_on_dataset(
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name) project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name) dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id)) examples = client_.list_examples(dataset_id=str(dataset.id))
results = await arun_on_examples( results = await arun_on_examples(
examples, examples,
llm_or_chain_factory, llm_or_chain_factory,
@ -486,7 +631,9 @@ async def arun_on_dataset(
num_repetitions=num_repetitions, num_repetitions=num_repetitions,
project_name=project_name, project_name=project_name,
verbose=verbose, verbose=verbose,
client=client_,
tags=tags, tags=tags,
run_evaluators=run_evaluators,
) )
return { return {
"project_name": project_name, "project_name": project_name,
@ -503,8 +650,11 @@ def run_on_dataset(
verbose: bool = False, verbose: bool = False,
client: Optional[LangChainPlusClient] = None, client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified project name. """
Run the Chain or language model on a dataset and store traces
to the specified project name.
Args: Args:
dataset_name: Name of the dataset to run the chain on. dataset_name: Name of the dataset to run the chain on.
@ -520,7 +670,8 @@ def run_on_dataset(
verbose: Whether to print progress. verbose: Whether to print progress.
client: Client to use to access the dataset. If None, a new client client: Client to use to access the dataset. If None, a new client
will be created using the credentials in the environment. will be created using the credentials in the environment.
tags: Tags to add to each run in the sesssion. tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
Returns: Returns:
A dictionary containing the run's project name and the resulting model outputs. A dictionary containing the run's project name and the resulting model outputs.
@ -536,6 +687,8 @@ def run_on_dataset(
project_name=project_name, project_name=project_name,
verbose=verbose, verbose=verbose,
tags=tags, tags=tags,
run_evaluators=run_evaluators,
client=client_,
) )
return { return {
"project_name": project_name, "project_name": project_name,

View File

@ -117,10 +117,12 @@ def get_qa_evaluator(
choices_map={"CORRECT": 1, "INCORRECT": 0}, choices_map={"CORRECT": 1, "INCORRECT": 0},
), ),
) )
tags = kwargs.pop("tags", [])
return RunEvaluatorChain( return RunEvaluatorChain(
eval_chain=eval_chain, eval_chain=eval_chain,
input_mapper=input_mapper, input_mapper=input_mapper,
output_parser=output_parser, output_parser=output_parser,
tags=tags + [evaluation_name],
**kwargs, **kwargs,
) )
@ -174,6 +176,7 @@ def get_criteria_evaluator(
choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name
), ),
) )
tags = kwargs.pop("tags", [])
eval_chain = CriteriaEvalChain.from_llm( eval_chain = CriteriaEvalChain.from_llm(
llm=llm, criteria=criteria_, prompt=prompt, **kwargs llm=llm, criteria=criteria_, prompt=prompt, **kwargs
) )
@ -181,6 +184,7 @@ def get_criteria_evaluator(
eval_chain=eval_chain, eval_chain=eval_chain,
input_mapper=input_mapper, input_mapper=input_mapper,
output_parser=parser, output_parser=parser,
tags=tags + [evaluation_name],
**kwargs, **kwargs,
) )
@ -303,9 +307,11 @@ def get_trajectory_evaluator(
TrajectoryEvalOutputParser(evaluation_name=evaluation_name), TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
) )
eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs) eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
tags = kwargs.pop("tags", [])
return RunEvaluatorChain( return RunEvaluatorChain(
eval_chain=eval_chain, eval_chain=eval_chain,
input_mapper=input_mapper, input_mapper=input_mapper,
output_parser=parser, output_parser=parser,
tags=tags + [evaluation_name],
**kwargs, **kwargs,
) )

File diff suppressed because it is too large Load Diff

View File

@ -169,8 +169,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
example: Example, example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain], llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int, n_repetitions: int,
tracer: Any,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
callbacks: Optional[Any] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
return [ return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions) {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)