@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import dataclasses
 import functools
 import inspect
 import logging
 import uuid
+from datetime import datetime
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
@@ -32,11 +34,12 @@ from langchain_core.tracers.evaluation import (
 )
 from langchain_core.tracers.langchain import LangChainTracer
 from langsmith.client import Client
-from langsmith.evaluation import RunEvaluator
+from langsmith.evaluation import EvaluationResult, RunEvaluator
 from langsmith.run_helpers import as_runnable, is_traceable_function
-from langsmith.schemas import Dataset, DataType, Example
+from langsmith.schemas import Dataset, DataType, Example, TracerSession
 from langsmith.utils import LangSmithError
 from requests import HTTPError
+from typing_extensions import TypedDict
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.base import Chain
@@ -919,9 +922,12 @@ def _prepare_eval_run(
     project_name: str,
     project_metadata: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
-) -> Tuple[MCF, str, Dataset, List[Example]]:
+) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
     wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
     dataset = client.read_dataset(dataset_name=dataset_name)
-    examples = list(client.list_examples(dataset_id=dataset.id))
-    if not examples:
-        raise ValueError(f"Dataset {dataset_name} has no example rows.")
+
     try:
         project_extra: dict = {"metadata": project_metadata} if project_metadata else {}
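Between this hunk and the return-statement change in the next one, `_prepare_eval_run` stops echoing the project name back to the caller and instead returns the `TracerSession` created for the test project, so downstream code can read both `project.id` and `project.name`. A minimal sketch of the new calling convention; the client, dataset, and factory names here are illustrative, not part of the patch:

```python
# Sketch only: assumes a configured LangSmith client and an existing dataset.
from langsmith.client import Client

client = Client()
wrapped_model, project, dataset, examples = _prepare_eval_run(
    client,
    dataset_name="my-dataset",            # hypothetical dataset name
    llm_or_chain_factory=lambda: chain,   # hypothetical chain factory
    project_name="nightly-eval",
    project_metadata={"git_sha": "abc123"},
    tags=["ci"],
)
print(project.id, project.name)  # a full TracerSession, not just the name string
```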
@@ -953,111 +959,159 @@ run_on_dataset(
         f"View all tests for Dataset {dataset_name} at:\n{dataset.url}",
         flush=True,
     )
+    examples = list(client.list_examples(dataset_id=dataset.id))
+    if not examples:
+        raise ValueError(f"Dataset {dataset_name} has no example rows.")
-    return wrapped_model, project_name, dataset, examples
+    return wrapped_model, project, dataset, examples
 
 
-def _prepare_run_on_dataset(
-    client: Client,
-    dataset_name: str,
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-    project_name: Optional[str],
-    evaluation: Optional[smith_eval.RunEvalConfig] = None,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
-    concurrency_level: int = 5,
-    project_metadata: Optional[Dict[str, Any]] = None,
-) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
-    project_name = project_name or name_generation.random_name()
-    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
-        client,
-        dataset_name,
-        llm_or_chain_factory,
-        project_name,
-        project_metadata=project_metadata,
-        tags=tags,
-    )
-    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
-    run_evaluators = _setup_evaluation(
-        wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
-    )
-    _validate_example_inputs(examples[0], wrapped_model, input_mapper)
-    progress_bar = progress.ProgressBarCallback(len(examples))
-    configs = [
-        RunnableConfig(
-            callbacks=[
-                LangChainTracer(
-                    project_name=project_name,
-                    client=client,
-                    use_threading=False,
-                    example_id=example.id,
-                ),
-                EvaluatorCallbackHandler(
-                    evaluators=run_evaluators or [],
-                    client=client,
-                    example_id=example.id,
-                    max_concurrency=0,
-                ),
-                progress_bar,
-            ],
-            tags=tags or [],
-            max_concurrency=concurrency_level,
-        )
-        for example in examples
-    ]
-    return wrapped_model, project_name, examples, configs
 
 
+class _RowResult(TypedDict, total=False):
+    """A dictionary of the results for a single example row."""
+
+    feedback: Optional[List[EvaluationResult]]
+    execution_time: Optional[float]
+    run_id: Optional[str]
 
 
-def _collect_test_results(
-    examples: List[Example],
-    batch_results: List[Union[dict, str, LLMResult, ChatResult]],
-    configs: List[RunnableConfig],
-    project_name: str,
-) -> TestResult:
-    wait_for_all_evaluators()
-    all_eval_results = {}
-    all_execution_time = {}
-    all_run_ids = {}
-    for c in configs:
-        for callback in cast(list, c["callbacks"]):
-            if isinstance(callback, EvaluatorCallbackHandler):
-                eval_results = callback.logged_eval_results
-                all_eval_results.update(
-                    {example_id: v for (_, example_id), v in eval_results.items()}
-                )
-            elif isinstance(callback, LangChainTracer):
-                run = callback.latest_run
-                example_id = callback.example_id
-                run_id = str(run.id) if run else None
-                execution_time = (
-                    (run.end_time - run.start_time).total_seconds()
-                    if run and run.end_time
-                    else None
-                )
-                all_execution_time[str(example_id)] = execution_time
-                all_run_ids[str(example_id)] = run_id
-
-    results: dict = {}
-    for example, output in zip(examples, batch_results):
-        feedback = all_eval_results.get(str(example.id), [])
-        results[str(example.id)] = {
-            "input": example.inputs,
-            "feedback": feedback,
-            "execution_time": all_execution_time.get(str(example.id)),
-            "run_id": all_run_ids.get(str(example.id)),
-        }
-        if isinstance(output, EvalError):
-            results[str(example.id)]["Error"] = output.Error
-        else:
-            results[str(example.id)]["output"] = output
-        if example.outputs:
-            results[str(example.id)]["reference"] = example.outputs
-    return TestResult(
-        project_name=project_name,
-        results=results,
-    )
 
 
+@dataclasses.dataclass
+class _DatasetRunContainer:
"""A container to help manage the state of a eval run."""
+
+    client: Client
+    project: TracerSession
+    wrapped_model: MCF
+    examples: List[Example]
+    configs: List[RunnableConfig]
+
+    def _merge_test_outputs(
+        self,
+        batch_results: list,
+        all_eval_results: Dict[str, _RowResult],
+    ) -> dict:
+        results: dict = {}
+        for example, output in zip(self.examples, batch_results):
+            row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
+            results[str(example.id)] = {
+                "input": example.inputs,
+                "feedback": row_result.get("feedback", []),
+                "execution_time": row_result.get("execution_time"),
+                "run_id": row_result.get("run_id"),
+            }
+            if isinstance(output, EvalError):
+                results[str(example.id)]["Error"] = output.Error
+            else:
+                results[str(example.id)]["output"] = output
+            if example.outputs:
+                results[str(example.id)]["reference"] = example.outputs
+        return results
+
+    def _collect_metrics(self) -> Dict[str, _RowResult]:
+        all_eval_results: dict = {}
+        for c in self.configs:
+            for callback in cast(list, c["callbacks"]):
+                if isinstance(callback, EvaluatorCallbackHandler):
+                    eval_results = callback.logged_eval_results
+                    for (_, example_id), v in eval_results.items():
+                        all_eval_results.setdefault(str(example_id), {}).update(
+                            {"feedback": v}
+                        )
+                elif isinstance(callback, LangChainTracer):
+                    run = callback.latest_run
+                    execution_time = (
+                        (run.end_time - run.start_time).total_seconds()
+                        if run and run.end_time
+                        else None
+                    )
+                    run_id = str(run.id) if run else None
+                    all_eval_results.setdefault(str(callback.example_id), {}).update(
+                        {
+                            "execution_time": execution_time,
+                            "run_id": run_id,
+                        }
+                    )
+        return cast(Dict[str, _RowResult], all_eval_results)
+
+    def _collect_test_results(
+        self,
+        batch_results: List[Union[dict, str, LLMResult, ChatResult]],
+    ) -> TestResult:
+        wait_for_all_evaluators()
+        all_eval_results = self._collect_metrics()
+        results = self._merge_test_outputs(batch_results, all_eval_results)
+        return TestResult(
+            project_name=self.project.name,
+            results=results,
+        )
+
+    def finish(self, batch_results: list, verbose: bool = False) -> TestResult:
+        results = self._collect_test_results(batch_results)
+        if verbose:
+            try:
+                agg_feedback = results.get_aggregate_feedback()
+                _display_aggregate_results(agg_feedback)
+            except Exception as e:
+                logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
+        try:
+            # Closing the project permits name changing and metric optimizations
+            self.client.update_project(self.project.id, end_time=datetime.utcnow())
+        except Exception as e:
+            logger.debug(f"Failed to close project: {repr(e)}")
+        return results
+
+    @classmethod
+    def prepare(
+        cls,
+        client: Client,
+        dataset_name: str,
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+        project_name: Optional[str],
+        evaluation: Optional[smith_eval.RunEvalConfig] = None,
+        tags: Optional[List[str]] = None,
+        input_mapper: Optional[Callable[[Dict], Any]] = None,
+        concurrency_level: int = 5,
+        project_metadata: Optional[Dict[str, Any]] = None,
+    ) -> _DatasetRunContainer:
+        project_name = project_name or name_generation.random_name()
+        wrapped_model, project, dataset, examples = _prepare_eval_run(
+            client,
+            dataset_name,
+            llm_or_chain_factory,
+            project_name,
+            project_metadata=project_metadata,
+            tags=tags,
+        )
+        wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+        run_evaluators = _setup_evaluation(
+            wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
+        )
+        _validate_example_inputs(examples[0], wrapped_model, input_mapper)
+        progress_bar = progress.ProgressBarCallback(len(examples))
+        configs = [
+            RunnableConfig(
+                callbacks=[
+                    LangChainTracer(
+                        project_name=project.name,
+                        client=client,
+                        use_threading=False,
+                        example_id=example.id,
+                    ),
+                    EvaluatorCallbackHandler(
+                        evaluators=run_evaluators or [],
+                        client=client,
+                        example_id=example.id,
+                        max_concurrency=0,
+                    ),
+                    progress_bar,
+                ],
+                tags=tags or [],
+                max_concurrency=concurrency_level,
+            )
+            for example in examples
+        ]
+        return cls(
+            client=client,
+            project=project,
+            wrapped_model=wrapped_model,
+            examples=examples,
+            configs=configs,
+        )
 
 
 def _is_jupyter_environment() -> bool:
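Taken together, the hunk above replaces the `_prepare_run_on_dataset` and `_collect_test_results` free functions with a single `_DatasetRunContainer` that owns the run's state (client, project, wrapped model, examples, and per-example configs) from setup through teardown. A rough sketch of the intended lifecycle, matching the serial path of `run_on_dataset`; `my_chain_factory` and `eval_config` are placeholders, not part of the diff:

```python
# Lifecycle sketch under the refactor, with hypothetical inputs.
container = _DatasetRunContainer.prepare(
    client,
    "my-dataset",
    my_chain_factory,
    project_name=None,  # None falls back to name_generation.random_name()
    evaluation=eval_config,
)
batch_results = [
    _run_llm_or_chain(
        example,
        config,
        llm_or_chain_factory=container.wrapped_model,
    )
    for example, config in zip(container.examples, container.configs)
]
# finish() waits for pending evaluators, merges feedback and run metadata
# into a TestResult, and closes the project by stamping its end_time.
results = container.finish(batch_results, verbose=True)
```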
@@ -1125,7 +1179,7 @@ async def arun_on_dataset(
         removal="0.0.305",
     )
     client = client or Client()
-    wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
+    container = _DatasetRunContainer.prepare(
         client,
         dataset_name,
         llm_or_chain_factory,
@@ -1137,26 +1191,18 @@ async def arun_on_dataset(
         project_metadata=project_metadata,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
-        configs[0].get("max_concurrency"),
+        container.configs[0].get("max_concurrency"),
         *map(
             functools.partial(
                 _arun_llm_or_chain,
-                llm_or_chain_factory=wrapped_model,
+                llm_or_chain_factory=container.wrapped_model,
                 input_mapper=input_mapper,
             ),
-            examples,
-            configs,
+            container.examples,
+            container.configs,
         ),
     )
-    results = _collect_test_results(examples, batch_results, configs, project_name)
-    if verbose:
-        try:
-            agg_feedback = results.get_aggregate_feedback()
-            print("\n Eval quantiles:")
-            print(agg_feedback)
-        except Exception as e:
-            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
-    return results
+    return container.finish(batch_results, verbose=verbose)
 
 
 def run_on_dataset(
@@ -1185,7 +1231,7 @@ def run_on_dataset(
         removal="0.0.305",
     )
     client = client or Client()
-    wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
+    container = _DatasetRunContainer.prepare(
         client,
         dataset_name,
         llm_or_chain_factory,
@@ -1201,33 +1247,26 @@ def run_on_dataset(
             _run_llm_or_chain(
                 example,
                 config,
-                llm_or_chain_factory=wrapped_model,
+                llm_or_chain_factory=container.wrapped_model,
                 input_mapper=input_mapper,
             )
-            for example, config in zip(examples, configs)
+            for example, config in zip(container.examples, container.configs)
         ]
     else:
-        with runnable_config.get_executor_for_config(configs[0]) as executor:
+        with runnable_config.get_executor_for_config(container.configs[0]) as executor:
             batch_results = list(
                 executor.map(
                     functools.partial(
                         _run_llm_or_chain,
-                        llm_or_chain_factory=wrapped_model,
+                        llm_or_chain_factory=container.wrapped_model,
                         input_mapper=input_mapper,
                     ),
-                    examples,
-                    configs,
+                    container.examples,
+                    container.configs,
                )
            )
 
-    results = _collect_test_results(examples, batch_results, configs, project_name)
-    if verbose:
-        try:
-            agg_feedback = results.get_aggregate_feedback()
-            _display_aggregate_results(agg_feedback)
-        except Exception as e:
-            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
-    return results
+    return container.finish(batch_results, verbose=verbose)
 
 
 _RUN_ON_DATASET_DOCSTRING = """
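After these hunks, `run_on_dataset` and `arun_on_dataset` are thin wrappers: each builds the container, runs the batch (a list comprehension at `concurrency_level=0`, an executor map otherwise, or `gather_with_concurrency` in the async path), and delegates reporting and project cleanup to `container.finish`. The public surface is unchanged; a typical call would still look roughly like this, where the chain and evaluator choices are illustrative:

```python
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client

results = run_on_dataset(
    Client(),
    dataset_name="my-dataset",              # hypothetical dataset
    llm_or_chain_factory=lambda: my_chain,  # fresh chain per example
    evaluation=RunEvalConfig(evaluators=["qa"]),
    project_metadata={"version": "1.0.0"},  # illustrative metadata
    verbose=True,
)
```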