mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Update to RunOnDataset helper functions to accept evaluator callbacks (#6629)
Also improve docstrings and update the tracing datasets notebook to focus on "debug, evaluate, monitor"
This commit is contained in:
parent
7ac9b22886
commit
6ca383ecf6
84
langchain/callbacks/tracers/evaluation.py
Normal file
84
langchain/callbacks/tracers/evaluation.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""A tracer that runs evaluators over completed runs."""
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||||
|
from typing import Any, Optional, Sequence, Set, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorCallbackHandler(BaseTracer):
    """A tracer that runs run evaluators whenever a run is persisted.

    Each persisted run is copied, tagged with the configured example ID and
    submitted to a thread pool once per evaluator.

    Parameters
    ----------
    evaluators : Sequence[RunEvaluator]
        The run evaluators to apply to each persisted run.
    max_workers : int, optional
        The maximum number of worker threads to use for running the evaluators.
        If not specified (or 0), it will default to the number of evaluators.
        The pool always has at least one worker, even with no evaluators.
    client : LangChainPlusClient, optional
        The LangChainPlusClient instance to use for evaluating the runs.
        If not specified, a new instance will be created.
    example_id : Union[UUID, str], optional
        The example ID to be associated with the runs. A string is parsed
        into a UUID.

    Attributes
    ----------
    example_id : Union[UUID, None]
        The example ID associated with the runs.
    client : LangChainPlusClient
        The LangChainPlusClient instance used for evaluating the runs.
    evaluators : Sequence[RunEvaluator]
        The sequence of run evaluators to be executed.
    executor : ThreadPoolExecutor
        The thread pool executor used for running the evaluators.
    futures : Set[Future]
        The set of futures representing the submitted evaluations that have
        not yet been collected by ``wait_for_futures``.
    """

    name = "evaluator_callback_handler"

    def __init__(
        self,
        evaluators: Sequence[RunEvaluator],
        max_workers: Optional[int] = None,
        client: Optional[LangChainPlusClient] = None,
        example_id: Optional[Union[UUID, str]] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.client = client or LangChainPlusClient()
        self.evaluators = evaluators
        # ThreadPoolExecutor rejects max_workers <= 0, so clamp to one worker
        # when neither max_workers nor any evaluators are supplied.
        self.executor = ThreadPoolExecutor(
            max_workers=max(max_workers or len(evaluators), 1)
        )
        self.futures: Set[Future] = set()

    def _persist_run(self, run: Run) -> None:
        """Submit one evaluation per evaluator for the given run.

        Parameters
        ----------
        run : Run
            The run to be evaluated.

        """
        # Work on a copy so the caller's run object is not given a
        # reference_example_id as a side effect.
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        for evaluator in self.evaluators:
            self.futures.add(
                self.executor.submit(self.client.evaluate_run, run_, evaluator)
            )

    def wait_for_futures(self) -> None:
        """Wait for all submitted evaluations to complete.

        Blocks until every future that was pending when the call started has
        finished, then drops those futures from ``self.futures``. Evaluations
        that raised are reported through the ``logging`` module instead of
        being discarded silently; the exception is not re-raised.
        """
        import logging  # local import: this block adds no file-level dependency

        # Snapshot first: futures added while we block are left for the
        # next call rather than being removed unfinished.
        pending = list(self.futures)
        wait(pending)
        for future in pending:
            # discard (not remove) tolerates another caller having already
            # collected the same future.
            self.futures.discard(future)
            if future.cancelled():
                continue
            error = future.exception()
            if error is not None:
                logging.getLogger(__name__).warning(
                    "Run evaluator failed: %r", error
                )
|
@ -1,20 +1,52 @@
|
|||||||
"""A tracer that collects all nested runs in a list."""
|
"""A tracer that collects all nested runs in a list."""
|
||||||
from typing import Any, List
|
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
|
||||||
|
|
||||||
class RunCollectorCallbackHandler(BaseTracer):
|
class RunCollectorCallbackHandler(BaseTracer):
|
||||||
"""A tracer that collects all nested runs in a list.
|
"""
|
||||||
|
A tracer that collects all nested runs in a list.
|
||||||
|
|
||||||
Useful for inspection and for evaluation."""
|
This tracer is useful for inspection and evaluation purposes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
example_id : Optional[Union[UUID, str]], default=None
|
||||||
|
The ID of the example being traced. It can be either a UUID or a string.
|
||||||
|
"""
|
||||||
|
|
||||||
name = "run-collector_callback_handler"
|
name = "run-collector_callback_handler"
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the RunCollectorCallbackHandler.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
example_id : Optional[Union[UUID, str]], default=None
|
||||||
|
The ID of the example being traced. It can be either a UUID or a string.
|
||||||
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.example_id = (
|
||||||
|
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||||
|
)
|
||||||
self.traced_runs: List[Run] = []
|
self.traced_runs: List[Run] = []
|
||||||
|
|
||||||
def _persist_run(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
self.traced_runs.append(run)
|
"""
|
||||||
|
Persist a run by adding it to the traced_runs list.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
run : Run
|
||||||
|
The run to be persisted.
|
||||||
|
"""
|
||||||
|
run_ = run.copy()
|
||||||
|
run_.reference_example_id = self.example_id
|
||||||
|
self.traced_runs.append(run_)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Utilities for running LLMs/Chains over datasets."""
|
"""Utilities for running language models or Chains over datasets."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -13,15 +14,18 @@ from typing import (
|
|||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchainplus_sdk import LangChainPlusClient
|
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
|
||||||
from langchainplus_sdk.schemas import Example
|
from langchainplus_sdk.schemas import Example
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
@ -41,11 +45,21 @@ MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
|
|||||||
|
|
||||||
|
|
||||||
class InputFormatError(Exception):
|
class InputFormatError(Exception):
|
||||||
"""Raised when input format is invalid."""
|
"""Raised when the input format is invalid."""
|
||||||
|
|
||||||
|
|
||||||
def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
|
def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
|
||||||
"""Get prompts from inputs."""
|
"""
|
||||||
|
Get prompts from inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: The input dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of prompts.
|
||||||
|
Raises:
|
||||||
|
InputFormatError: If the input format is invalid.
|
||||||
|
"""
|
||||||
if not inputs:
|
if not inputs:
|
||||||
raise InputFormatError("Inputs should not be empty.")
|
raise InputFormatError("Inputs should not be empty.")
|
||||||
|
|
||||||
@ -83,7 +97,17 @@ def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
|
def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
|
||||||
"""Get Chat Messages from inputs."""
|
"""
|
||||||
|
Get Chat Messages from inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: The input dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of chat messages.
|
||||||
|
Raises:
|
||||||
|
InputFormatError: If the input format is invalid.
|
||||||
|
"""
|
||||||
if not inputs:
|
if not inputs:
|
||||||
raise InputFormatError("Inputs should not be empty.")
|
raise InputFormatError("Inputs should not be empty.")
|
||||||
|
|
||||||
@ -112,13 +136,25 @@ def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
|
|||||||
async def _arun_llm(
|
async def _arun_llm(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
langchain_tracer: Optional[LangChainTracer],
|
|
||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
) -> Union[LLMResult, ChatResult]:
|
) -> Union[LLMResult, ChatResult]:
|
||||||
callbacks: Optional[List[BaseCallbackHandler]] = (
|
"""
|
||||||
[langchain_tracer] if langchain_tracer else None
|
Asynchronously run the language model.
|
||||||
)
|
|
||||||
|
Args:
|
||||||
|
llm: The language model to run.
|
||||||
|
inputs: The input dictionary.
|
||||||
|
tags: Optional tags to add to the run.
|
||||||
|
callbacks: Optional callbacks to use during the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The LLMResult or ChatResult.
|
||||||
|
Raises:
|
||||||
|
ValueError: If the LLM type is unsupported.
|
||||||
|
InputFormatError: If the input format is invalid.
|
||||||
|
"""
|
||||||
if isinstance(llm, BaseLLM):
|
if isinstance(llm, BaseLLM):
|
||||||
try:
|
try:
|
||||||
llm_prompts = _get_prompts(inputs)
|
llm_prompts = _get_prompts(inputs)
|
||||||
@ -152,18 +188,32 @@ async def _arun_llm_or_chain(
|
|||||||
example: Example,
|
example: Example,
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
n_repetitions: int,
|
n_repetitions: int,
|
||||||
langchain_tracer: Optional[LangChainTracer],
|
|
||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||||
"""Run the chain asynchronously."""
|
"""
|
||||||
if langchain_tracer is not None:
|
Asynchronously run the Chain or language model.
|
||||||
previous_example_id = langchain_tracer.example_id
|
|
||||||
langchain_tracer.example_id = example.id
|
Args:
|
||||||
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
|
example: The example to run.
|
||||||
|
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||||
|
n_repetitions: The number of times to run the model on each example.
|
||||||
|
tags: Optional tags to add to the run.
|
||||||
|
callbacks: Optional callbacks to use during the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of outputs.
|
||||||
|
"""
|
||||||
|
if callbacks:
|
||||||
|
previous_example_ids = [
|
||||||
|
getattr(tracer, "example_id", None) for tracer in callbacks
|
||||||
|
]
|
||||||
|
for tracer in callbacks:
|
||||||
|
if hasattr(tracer, "example_id"):
|
||||||
|
tracer.example_id = example.id
|
||||||
else:
|
else:
|
||||||
previous_example_id = None
|
previous_example_ids = None
|
||||||
callbacks = None
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for _ in range(n_repetitions):
|
for _ in range(n_repetitions):
|
||||||
try:
|
try:
|
||||||
@ -171,8 +221,8 @@ async def _arun_llm_or_chain(
|
|||||||
output: Any = await _arun_llm(
|
output: Any = await _arun_llm(
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
example.inputs,
|
example.inputs,
|
||||||
langchain_tracer,
|
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chain = llm_or_chain_factory()
|
chain = llm_or_chain_factory()
|
||||||
@ -183,15 +233,19 @@ async def _arun_llm_or_chain(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||||
outputs.append({"Error": str(e)})
|
outputs.append({"Error": str(e)})
|
||||||
if langchain_tracer is not None:
|
if callbacks and previous_example_ids:
|
||||||
langchain_tracer.example_id = previous_example_id
|
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||||
|
if hasattr(tracer, "example_id"):
|
||||||
|
tracer.example_id = example_id
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
async def _gather_with_concurrency(
|
async def _gather_with_concurrency(
|
||||||
n: int,
|
n: int,
|
||||||
initializer: Callable[[], Coroutine[Any, Any, Optional[LangChainTracer]]],
|
initializer: Callable[[], Coroutine[Any, Any, Any]],
|
||||||
*async_funcs: Callable[[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]],
|
*async_funcs: Callable[
|
||||||
|
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
|
||||||
|
],
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Run coroutines with a concurrency limit.
|
Run coroutines with a concurrency limit.
|
||||||
@ -207,37 +261,42 @@ async def _gather_with_concurrency(
|
|||||||
semaphore = asyncio.Semaphore(n)
|
semaphore = asyncio.Semaphore(n)
|
||||||
job_state = {"num_processed": 0}
|
job_state = {"num_processed": 0}
|
||||||
|
|
||||||
tracer_queue: asyncio.Queue[Optional[LangChainTracer]] = asyncio.Queue()
|
callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
tracer_queue.put_nowait(await initializer())
|
callback_queue.put_nowait(await initializer())
|
||||||
|
|
||||||
async def run_coroutine_with_semaphore(
|
async def run_coroutine_with_semaphore(
|
||||||
async_func: Callable[
|
async_func: Callable[
|
||||||
[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]
|
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
|
||||||
]
|
]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
tracer = await tracer_queue.get()
|
callbacks = await callback_queue.get()
|
||||||
try:
|
try:
|
||||||
result = await async_func(tracer, job_state)
|
result = await async_func(callbacks, job_state)
|
||||||
finally:
|
finally:
|
||||||
tracer_queue.put_nowait(tracer)
|
callback_queue.put_nowait(callbacks)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
||||||
)
|
)
|
||||||
while tracer_queue:
|
while callback_queue:
|
||||||
try:
|
try:
|
||||||
tracer = tracer_queue.get_nowait()
|
callbacks = callback_queue.get_nowait()
|
||||||
except asyncio.QueueEmpty:
|
except asyncio.QueueEmpty:
|
||||||
break
|
break
|
||||||
if tracer:
|
for callback in callbacks:
|
||||||
tracer.wait_for_futures()
|
if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
|
||||||
|
callback.wait_for_futures()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChainTracer]:
|
async def _callbacks_initializer(
|
||||||
|
project_name: Optional[str],
|
||||||
|
client: LangChainPlusClient,
|
||||||
|
run_evaluators: Sequence[RunEvaluator],
|
||||||
|
) -> List[BaseTracer]:
|
||||||
"""
|
"""
|
||||||
Initialize a tracer to share across tasks.
|
Initialize a tracer to share across tasks.
|
||||||
|
|
||||||
@ -247,11 +306,19 @@ async def _tracer_initializer(project_name: Optional[str]) -> Optional[LangChain
|
|||||||
Returns:
|
Returns:
|
||||||
A LangChainTracer instance with an active project.
|
A LangChainTracer instance with an active project.
|
||||||
"""
|
"""
|
||||||
|
callbacks: List[BaseTracer] = []
|
||||||
if project_name:
|
if project_name:
|
||||||
tracer = LangChainTracer(project_name=project_name)
|
callbacks.append(LangChainTracer(project_name=project_name))
|
||||||
return tracer
|
if run_evaluators:
|
||||||
else:
|
callbacks.append(
|
||||||
return None
|
EvaluatorCallbackHandler(
|
||||||
|
client=client,
|
||||||
|
evaluators=run_evaluators,
|
||||||
|
# We already have concurrency, don't want to overload the machine
|
||||||
|
max_workers=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
async def arun_on_examples(
|
async def arun_on_examples(
|
||||||
@ -262,13 +329,16 @@ async def arun_on_examples(
|
|||||||
num_repetitions: int = 1,
|
num_repetitions: int = 1,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
client: Optional[LangChainPlusClient] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run the chain on examples and store traces to the specified project name.
|
Asynchronously run the chain on examples and store traces
|
||||||
|
to the specified project name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
examples: Examples to run the model or chain over
|
examples: Examples to run the model or chain over.
|
||||||
llm_or_chain_factory: Language model or Chain constructor to run
|
llm_or_chain_factory: Language model or Chain constructor to run
|
||||||
over the dataset. The Chain constructor is used to permit
|
over the dataset. The Chain constructor is used to permit
|
||||||
independent calls on each example without carrying over state.
|
independent calls on each example without carrying over state.
|
||||||
@ -277,24 +347,35 @@ async def arun_on_examples(
|
|||||||
This is useful when testing success rates or generating confidence
|
This is useful when testing success rates or generating confidence
|
||||||
intervals.
|
intervals.
|
||||||
project_name: Project name to use when tracing runs.
|
project_name: Project name to use when tracing runs.
|
||||||
|
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||||
verbose: Whether to print progress.
|
verbose: Whether to print progress.
|
||||||
tags: Tags to add to the traces.
|
client: Client to use to read the dataset. If not provided, a new
|
||||||
|
client will be created using the credentials in the environment.
|
||||||
|
tags: Tags to add to each run in the project.
|
||||||
|
run_evaluators: Evaluators to run on the results of the chain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping example ids to the model outputs.
|
A dictionary mapping example ids to the model outputs.
|
||||||
"""
|
"""
|
||||||
|
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
|
||||||
|
client_ = client or LangChainPlusClient()
|
||||||
|
client_.create_project(project_name, mode="eval")
|
||||||
|
|
||||||
results: Dict[str, List[Any]] = {}
|
results: Dict[str, List[Any]] = {}
|
||||||
|
evaluation_handler = EvaluatorCallbackHandler(
|
||||||
|
evaluators=run_evaluators or [], client=client_
|
||||||
|
)
|
||||||
|
|
||||||
async def process_example(
|
async def process_example(
|
||||||
example: Example, tracer: Optional[LangChainTracer], job_state: dict
|
example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a single example."""
|
"""Process a single example."""
|
||||||
result = await _arun_llm_or_chain(
|
result = await _arun_llm_or_chain(
|
||||||
example,
|
example,
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
num_repetitions,
|
num_repetitions,
|
||||||
tracer,
|
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
results[str(example.id)] = result
|
results[str(example.id)] = result
|
||||||
job_state["num_processed"] += 1
|
job_state["num_processed"] += 1
|
||||||
@ -307,9 +388,15 @@ async def arun_on_examples(
|
|||||||
|
|
||||||
await _gather_with_concurrency(
|
await _gather_with_concurrency(
|
||||||
concurrency_level,
|
concurrency_level,
|
||||||
functools.partial(_tracer_initializer, project_name),
|
functools.partial(
|
||||||
|
_callbacks_initializer,
|
||||||
|
project_name=project_name,
|
||||||
|
client=client_,
|
||||||
|
run_evaluators=run_evaluators or [],
|
||||||
|
),
|
||||||
*(functools.partial(process_example, e) for e in examples),
|
*(functools.partial(process_example, e) for e in examples),
|
||||||
)
|
)
|
||||||
|
evaluation_handler.wait_for_futures()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -320,7 +407,21 @@ def run_llm(
|
|||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
) -> Union[LLMResult, ChatResult]:
|
) -> Union[LLMResult, ChatResult]:
|
||||||
"""Run the language model on the example."""
|
"""
|
||||||
|
Run the language model on the example.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: The language model to run.
|
||||||
|
inputs: The input dictionary.
|
||||||
|
callbacks: The callbacks to use during the run.
|
||||||
|
tags: Optional tags to add to the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The LLMResult or ChatResult.
|
||||||
|
Raises:
|
||||||
|
ValueError: If the LLM type is unsupported.
|
||||||
|
InputFormatError: If the input format is invalid.
|
||||||
|
"""
|
||||||
if isinstance(llm, BaseLLM):
|
if isinstance(llm, BaseLLM):
|
||||||
try:
|
try:
|
||||||
llm_prompts = _get_prompts(inputs)
|
llm_prompts = _get_prompts(inputs)
|
||||||
@ -350,18 +451,32 @@ def run_llm_or_chain(
|
|||||||
example: Example,
|
example: Example,
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
n_repetitions: int,
|
n_repetitions: int,
|
||||||
langchain_tracer: Optional[LangChainTracer] = None,
|
|
||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
callbacks: Optional[List[BaseCallbackHandler]] = None,
|
||||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||||
"""Run the chain synchronously."""
|
"""
|
||||||
if langchain_tracer is not None:
|
Run the Chain or language model synchronously.
|
||||||
previous_example_id = langchain_tracer.example_id
|
|
||||||
langchain_tracer.example_id = example.id
|
Args:
|
||||||
callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
|
example: The example to run.
|
||||||
|
llm_or_chain_factory: The Chain or language model constructor to run.
|
||||||
|
n_repetitions: The number of times to run the model on each example.
|
||||||
|
tags: Optional tags to add to the run.
|
||||||
|
callbacks: Optional callbacks to use during the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of outputs.
|
||||||
|
"""
|
||||||
|
if callbacks:
|
||||||
|
previous_example_ids = [
|
||||||
|
getattr(tracer, "example_id", None) for tracer in callbacks
|
||||||
|
]
|
||||||
|
for tracer in callbacks:
|
||||||
|
if hasattr(tracer, "example_id"):
|
||||||
|
tracer.example_id = example.id
|
||||||
else:
|
else:
|
||||||
previous_example_id = None
|
previous_example_ids = None
|
||||||
callbacks = None
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for _ in range(n_repetitions):
|
for _ in range(n_repetitions):
|
||||||
try:
|
try:
|
||||||
@ -376,8 +491,10 @@ def run_llm_or_chain(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||||
outputs.append({"Error": str(e)})
|
outputs.append({"Error": str(e)})
|
||||||
if langchain_tracer is not None:
|
if callbacks and previous_example_ids:
|
||||||
langchain_tracer.example_id = previous_example_id
|
for example_id, tracer in zip(previous_example_ids, callbacks):
|
||||||
|
if hasattr(tracer, "example_id"):
|
||||||
|
tracer.example_id = example_id
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -388,48 +505,74 @@ def run_on_examples(
|
|||||||
num_repetitions: int = 1,
|
num_repetitions: int = 1,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
client: Optional[LangChainPlusClient] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the chain on examples and store traces to the specified project name.
|
"""
|
||||||
|
Run the Chain or language model on examples and store
|
||||||
|
traces to the specified project name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
examples: Examples to run model or chain over.
|
examples: Examples to run the model or chain over.
|
||||||
llm_or_chain_factory: Language model or Chain constructor to run
|
llm_or_chain_factory: Language model or Chain constructor to run
|
||||||
over the dataset. The Chain constructor is used to permit
|
over the dataset. The Chain constructor is used to permit
|
||||||
independent calls on each example without carrying over state.
|
independent calls on each example without carrying over state.
|
||||||
concurrency_level: Number of async workers to run in parallel.
|
|
||||||
num_repetitions: Number of times to run the model on each example.
|
num_repetitions: Number of times to run the model on each example.
|
||||||
This is useful when testing success rates or generating confidence
|
This is useful when testing success rates or generating confidence
|
||||||
intervals.
|
intervals.
|
||||||
project_name: Project name to use when tracing runs.
|
project_name: Name of the project to store the traces in.
|
||||||
|
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||||
verbose: Whether to print progress.
|
verbose: Whether to print progress.
|
||||||
tags: Tags to add to the run traces.
|
client: Client to use to access the dataset. If None, a new client
|
||||||
|
will be created using the credentials in the environment.
|
||||||
|
tags: Tags to add to each run in the project.
|
||||||
|
run_evaluators: Evaluators to run on the results of the chain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping example ids to the model outputs.
|
A dictionary mapping example ids to the model outputs.
|
||||||
"""
|
"""
|
||||||
results: Dict[str, Any] = {}
|
results: Dict[str, Any] = {}
|
||||||
tracer = LangChainTracer(project_name=project_name) if project_name else None
|
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
|
||||||
|
client_ = client or LangChainPlusClient()
|
||||||
|
client_.create_project(project_name, mode="eval")
|
||||||
|
tracer = LangChainTracer(project_name=project_name)
|
||||||
|
evalution_handler = EvaluatorCallbackHandler(
|
||||||
|
evaluators=run_evaluators or [], client=client_
|
||||||
|
)
|
||||||
|
callbacks: List[BaseCallbackHandler] = [tracer, evalution_handler]
|
||||||
for i, example in enumerate(examples):
|
for i, example in enumerate(examples):
|
||||||
result = run_llm_or_chain(
|
result = run_llm_or_chain(
|
||||||
example,
|
example,
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
num_repetitions,
|
num_repetitions,
|
||||||
langchain_tracer=tracer,
|
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{i+1} processed", flush=True, end="\r")
|
print(f"{i+1} processed", flush=True, end="\r")
|
||||||
results[str(example.id)] = result
|
results[str(example.id)] = result
|
||||||
if tracer:
|
tracer.wait_for_futures()
|
||||||
tracer.wait_for_futures()
|
evalution_handler.wait_for_futures()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _get_project_name(
|
def _get_project_name(
|
||||||
project_name: Optional[str],
|
project_name: Optional[str],
|
||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
dataset_name: str,
|
dataset_name: Optional[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the project name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_name: The project name if manually specified.
|
||||||
|
llm_or_chain_factory: The Chain or language model constructor.
|
||||||
|
dataset_name: The dataset name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The project name.
|
||||||
|
"""
|
||||||
if project_name is not None:
|
if project_name is not None:
|
||||||
return project_name
|
return project_name
|
||||||
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
@ -437,7 +580,8 @@ def _get_project_name(
|
|||||||
model_name = llm_or_chain_factory.__class__.__name__
|
model_name = llm_or_chain_factory.__class__.__name__
|
||||||
else:
|
else:
|
||||||
model_name = llm_or_chain_factory().__class__.__name__
|
model_name = llm_or_chain_factory().__class__.__name__
|
||||||
return f"{dataset_name}-{model_name}-{current_time}"
|
dataset_prefix = f"{dataset_name}-" if dataset_name else ""
|
||||||
|
return f"{dataset_prefix}{model_name}-{current_time}"
|
||||||
|
|
||||||
|
|
||||||
async def arun_on_dataset(
|
async def arun_on_dataset(
|
||||||
@ -450,12 +594,13 @@ async def arun_on_dataset(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
client: Optional[LangChainPlusClient] = None,
|
client: Optional[LangChainPlusClient] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run the chain on a dataset and store traces to the specified project name.
|
Asynchronously run the Chain or language model on a dataset
|
||||||
|
and store traces to the specified project name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: Client to use to read the dataset.
|
|
||||||
dataset_name: Name of the dataset to run the chain on.
|
dataset_name: Name of the dataset to run the chain on.
|
||||||
llm_or_chain_factory: Language model or Chain constructor to run
|
llm_or_chain_factory: Language model or Chain constructor to run
|
||||||
over the dataset. The Chain constructor is used to permit
|
over the dataset. The Chain constructor is used to permit
|
||||||
@ -469,7 +614,8 @@ async def arun_on_dataset(
|
|||||||
verbose: Whether to print progress.
|
verbose: Whether to print progress.
|
||||||
client: Client to use to read the dataset. If not provided, a new
|
client: Client to use to read the dataset. If not provided, a new
|
||||||
client will be created using the credentials in the environment.
|
client will be created using the credentials in the environment.
|
||||||
tags: Tags to add to each run in the sesssion.
|
tags: Tags to add to each run in the project.
|
||||||
|
run_evaluators: Evaluators to run on the results of the chain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the run's project name and the resulting model outputs.
|
A dictionary containing the run's project name and the resulting model outputs.
|
||||||
@ -478,7 +624,6 @@ async def arun_on_dataset(
|
|||||||
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
|
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
|
||||||
dataset = client_.read_dataset(dataset_name=dataset_name)
|
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||||
examples = client_.list_examples(dataset_id=str(dataset.id))
|
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||||
|
|
||||||
results = await arun_on_examples(
|
results = await arun_on_examples(
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
@ -486,7 +631,9 @@ async def arun_on_dataset(
|
|||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
|
client=client_,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
run_evaluators=run_evaluators,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"project_name": project_name,
|
"project_name": project_name,
|
||||||
@ -503,8 +650,11 @@ def run_on_dataset(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
client: Optional[LangChainPlusClient] = None,
|
client: Optional[LangChainPlusClient] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the chain on a dataset and store traces to the specified project name.
|
"""
|
||||||
|
Run the Chain or language model on a dataset and store traces
|
||||||
|
to the specified project name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_name: Name of the dataset to run the chain on.
|
dataset_name: Name of the dataset to run the chain on.
|
||||||
@ -520,7 +670,8 @@ def run_on_dataset(
|
|||||||
verbose: Whether to print progress.
|
verbose: Whether to print progress.
|
||||||
client: Client to use to access the dataset. If None, a new client
|
client: Client to use to access the dataset. If None, a new client
|
||||||
will be created using the credentials in the environment.
|
will be created using the credentials in the environment.
|
||||||
tags: Tags to add to each run in the sesssion.
|
tags: Tags to add to each run in the project.
|
||||||
|
run_evaluators: Evaluators to run on the results of the chain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the run's project name and the resulting model outputs.
|
A dictionary containing the run's project name and the resulting model outputs.
|
||||||
@ -536,6 +687,8 @@ def run_on_dataset(
|
|||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
|
run_evaluators=run_evaluators,
|
||||||
|
client=client_,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"project_name": project_name,
|
"project_name": project_name,
|
||||||
|
@ -117,10 +117,12 @@ def get_qa_evaluator(
|
|||||||
choices_map={"CORRECT": 1, "INCORRECT": 0},
|
choices_map={"CORRECT": 1, "INCORRECT": 0},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
tags = kwargs.pop("tags", [])
|
||||||
return RunEvaluatorChain(
|
return RunEvaluatorChain(
|
||||||
eval_chain=eval_chain,
|
eval_chain=eval_chain,
|
||||||
input_mapper=input_mapper,
|
input_mapper=input_mapper,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
|
tags=tags + [evaluation_name],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,6 +176,7 @@ def get_criteria_evaluator(
|
|||||||
choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name
|
choices_map={"Y": 1, "N": 0}, evaluation_name=evaluation_name
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
tags = kwargs.pop("tags", [])
|
||||||
eval_chain = CriteriaEvalChain.from_llm(
|
eval_chain = CriteriaEvalChain.from_llm(
|
||||||
llm=llm, criteria=criteria_, prompt=prompt, **kwargs
|
llm=llm, criteria=criteria_, prompt=prompt, **kwargs
|
||||||
)
|
)
|
||||||
@ -181,6 +184,7 @@ def get_criteria_evaluator(
|
|||||||
eval_chain=eval_chain,
|
eval_chain=eval_chain,
|
||||||
input_mapper=input_mapper,
|
input_mapper=input_mapper,
|
||||||
output_parser=parser,
|
output_parser=parser,
|
||||||
|
tags=tags + [evaluation_name],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -303,9 +307,11 @@ def get_trajectory_evaluator(
|
|||||||
TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
|
TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
|
||||||
)
|
)
|
||||||
eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
|
eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
|
||||||
|
tags = kwargs.pop("tags", [])
|
||||||
return RunEvaluatorChain(
|
return RunEvaluatorChain(
|
||||||
eval_chain=eval_chain,
|
eval_chain=eval_chain,
|
||||||
input_mapper=input_mapper,
|
input_mapper=input_mapper,
|
||||||
output_parser=parser,
|
output_parser=parser,
|
||||||
|
tags=tags + [evaluation_name],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -169,8 +169,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
example: Example,
|
example: Example,
|
||||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||||
n_repetitions: int,
|
n_repetitions: int,
|
||||||
tracer: Any,
|
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
|
callbacks: Optional[Any] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
return [
|
return [
|
||||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||||
|
Loading…
Reference in New Issue
Block a user