Add progress bar + runner fixes (#10348)

- Add progress bar to eval runs
- Use thread pool for concurrency
- Update some error messages
- Friendlier project name
- Print quantiles of the final stats (see the usage sketch below)

Closes LS-902
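For reference, a minimal end-to-end sketch of the updated runner. It assumes a LangSmith dataset named "my-dataset" already exists, that LangSmith/OpenAI credentials are configured, and that the dataset inputs use an "input" key; the chain and evaluator choices are illustrative, not part of this change.

from langsmith import Client

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, run_on_dataset


def chain_factory() -> LLMChain:
    # A fresh chain per example avoids cross-contamination between runs.
    return LLMChain.from_string(ChatOpenAI(temperature=0), "Answer briefly: {input}")


results = run_on_dataset(
    Client(),
    dataset_name="my-dataset",            # assumed to already exist
    llm_or_chain_factory=chain_factory,
    evaluation=RunEvalConfig(evaluators=["qa"]),
    concurrency_level=5,                  # examples are dispatched on a thread pool
    verbose=True,                         # progress bar while running, quantile table at the end
)
print(results.get_aggregate_feedback())   # 0.25 / 0.5 / 0.75 quantiles per feedback key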
pull/10371/head
William FH 1 year ago committed by GitHub
parent 0672533b3e
commit 46e9abdc75

@ -2,29 +2,20 @@
from __future__ import annotations
import logging
from concurrent.futures import Future, ThreadPoolExecutor, wait
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from uuid import UUID
import langsmith
from langsmith import schemas as langsmith_schemas
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks import manager
from langchain.callbacks.tracers import langchain as langchain_tracer
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.langchain import _get_client
from langchain.callbacks.tracers.schemas import Run
logger = logging.getLogger(__name__)
_TRACERS: List[EvaluatorCallbackHandler] = []
def wait_for_all_evaluators() -> None:
"""Wait for all tracers to finish."""
global _TRACERS
for tracer in _TRACERS:
tracer.wait_for_futures()
class EvaluatorCallbackHandler(BaseTracer):
"""A tracer that runs a run evaluator whenever a run is persisted.
@ -79,17 +70,13 @@ class EvaluatorCallbackHandler(BaseTracer):
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.client = client or _get_client()
self.client = client or langchain_tracer.get_client()
self.evaluators = evaluators
self.executor = ThreadPoolExecutor(
max_workers=max(max_workers or len(evaluators), 1)
)
self.max_workers = max_workers or len(evaluators)
self.futures: Set[Future] = set()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
global _TRACERS
_TRACERS.append(self)
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
"""Evaluate the run in the project.
@ -105,7 +92,7 @@ class EvaluatorCallbackHandler(BaseTracer):
try:
if self.project_name is None:
feedback = self.client.evaluate_run(run, evaluator)
with tracing_v2_enabled(
with manager.tracing_v2_enabled(
project_name=self.project_name, tags=["eval"], client=self.client
):
feedback = self.client.evaluate_run(run, evaluator)
@ -133,14 +120,15 @@ class EvaluatorCallbackHandler(BaseTracer):
return
run_ = run.copy()
run_.reference_example_id = self.example_id
for evaluator in self.evaluators:
self.futures.add(
self.executor.submit(self._evaluate_in_project, run_, evaluator)
)
def wait_for_futures(self) -> None:
"""Wait for all futures to complete."""
futures = list(self.futures)
wait(futures)
for future in futures:
self.futures.remove(future)
if self.max_workers > 0:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
list(
executor.map(
self._evaluate_in_project,
[run_ for _ in range(len(self.evaluators))],
self.evaluators,
)
)
else:
for evaluator in self.evaluators:
self._evaluate_in_project(run_, evaluator)
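The dispatch above switches between a short-lived thread pool and an in-thread loop depending on max_workers (the runner later in this diff passes max_workers=0 so evaluation rides on the runner's own concurrency). A self-contained sketch of that pattern, with a placeholder task rather than the real evaluator call:

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def dispatch(task: Callable[[int], None], items: List[int], max_workers: int) -> None:
    """Run `task` over `items`, mirroring the handler's two dispatch modes."""
    if max_workers > 0:
        # Short-lived pool; the context manager joins all workers before returning.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            list(executor.map(task, items))
    else:
        # max_workers == 0: run inline in the calling thread.
        for item in items:
            task(item)


dispatch(lambda i: print(f"evaluated run {i}"), [1, 2, 3], max_workers=2)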

@ -42,7 +42,7 @@ def wait_for_all_tracers() -> None:
tracer.wait_for_futures()
def _get_client() -> Client:
def get_client() -> Client:
"""Get the client."""
global _CLIENT
if _CLIENT is None:
@ -83,7 +83,7 @@ class LangChainTracer(BaseTracer):
_EXECUTORS.append(self.executor)
else:
self.executor = None
self.client = client or _get_client()
self.client = client or get_client()
self._futures: Set[Future] = set()
self.tags = tags or []
global _TRACERS

@ -0,0 +1,729 @@
import random
adjectives = [
"abandoned",
"aching",
"advanced",
"ample",
"artistic",
"back",
"best",
"bold",
"brief",
"clear",
"cold",
"complicated",
"cooked",
"crazy",
"crushing",
"damp",
"dear",
"definite",
"dependable",
"diligent",
"drab",
"earnest",
"elderly",
"enchanted",
"essential",
"excellent",
"extraneous",
"fixed",
"flowery",
"formal",
"fresh",
"frosty",
"giving",
"glossy",
"healthy",
"helpful",
"impressionable",
"kind",
"large",
"left",
"long",
"loyal",
"mealy",
"memorable",
"monthly",
"new",
"notable",
"only",
"ordinary",
"passionate",
"perfect",
"pertinent",
"proper",
"puzzled",
"reflecting",
"respectful",
"roasted",
"scholarly",
"shiny",
"slight",
"sparkling",
"spotless",
"stupendous",
"sunny",
"tart",
"terrific",
"timely",
"unique",
"upbeat",
"vacant",
"virtual",
"warm",
"weary",
"whispered",
"worthwhile",
"yellow",
]
nouns = [
"account",
"acknowledgment",
"address",
"advertising",
"airplane",
"animal",
"appointment",
"arrival",
"artist",
"attachment",
"attitude",
"availability",
"backpack",
"bag",
"balance",
"bass",
"bean",
"beauty",
"bibliography",
"bill",
"bite",
"blossom",
"boat",
"book",
"box",
"boy",
"bread",
"bridge",
"broccoli",
"building",
"butter",
"button",
"cabbage",
"cake",
"camera",
"camp",
"candle",
"candy",
"canvas",
"car",
"card",
"carrot",
"cart",
"case",
"cat",
"chain",
"chair",
"chalk",
"chance",
"change",
"channel",
"character",
"charge",
"charm",
"chart",
"check",
"cheek",
"cheese",
"chef",
"cherry",
"chicken",
"child",
"church",
"circle",
"class",
"clay",
"click",
"clock",
"cloth",
"cloud",
"clove",
"club",
"coach",
"coal",
"coast",
"coat",
"cod",
"coffee",
"collar",
"color",
"comb",
"comfort",
"comic",
"committee",
"community",
"company",
"comparison",
"competition",
"condition",
"connection",
"control",
"cook",
"copper",
"copy",
"corn",
"cough",
"country",
"cover",
"crate",
"crayon",
"cream",
"creator",
"crew",
"crown",
"current",
"curtain",
"curve",
"cushion",
"dad",
"daughter",
"day",
"death",
"debt",
"decision",
"deer",
"degree",
"design",
"desire",
"desk",
"detail",
"development",
"digestion",
"dime",
"dinner",
"direction",
"dirt",
"discovery",
"discussion",
"disease",
"disgust",
"distance",
"distribution",
"division",
"doctor",
"dog",
"door",
"drain",
"drawer",
"dress",
"drink",
"driving",
"dust",
"ear",
"earth",
"edge",
"education",
"effect",
"egg",
"end",
"energy",
"engine",
"error",
"event",
"example",
"exchange",
"existence",
"expansion",
"experience",
"expert",
"eye",
"face",
"fact",
"fall",
"family",
"farm",
"father",
"fear",
"feeling",
"field",
"finger",
"fire",
"fish",
"flag",
"flight",
"floor",
"flower",
"fold",
"food",
"football",
"force",
"form",
"frame",
"friend",
"frog",
"fruit",
"fuel",
"furniture",
"game",
"garden",
"gate",
"girl",
"glass",
"glove",
"goat",
"gold",
"government",
"grade",
"grain",
"grass",
"green",
"grip",
"group",
"growth",
"guide",
"guitar",
"hair",
"hall",
"hand",
"harbor",
"harmony",
"hat",
"head",
"health",
"heart",
"heat",
"hill",
"history",
"hobbies",
"hole",
"hope",
"horn",
"horse",
"hospital",
"hour",
"house",
"humor",
"idea",
"impulse",
"income",
"increase",
"industry",
"ink",
"insect",
"instrument",
"insurance",
"interest",
"invention",
"iron",
"island",
"jelly",
"jet",
"jewel",
"join",
"judge",
"juice",
"jump",
"kettle",
"key",
"kick",
"kiss",
"kitten",
"knee",
"knife",
"knowledge",
"land",
"language",
"laugh",
"law",
"lead",
"learning",
"leather",
"leg",
"lettuce",
"level",
"library",
"lift",
"light",
"limit",
"line",
"linen",
"lip",
"liquid",
"list",
"look",
"loss",
"love",
"lunch",
"machine",
"man",
"manager",
"map",
"marble",
"mark",
"market",
"mass",
"match",
"meal",
"measure",
"meat",
"meeting",
"memory",
"metal",
"middle",
"milk",
"mind",
"mine",
"minute",
"mist",
"mitten",
"mom",
"money",
"monkey",
"month",
"moon",
"morning",
"mother",
"motion",
"mountain",
"mouth",
"muscle",
"music",
"nail",
"name",
"nation",
"neck",
"need",
"news",
"night",
"noise",
"note",
"number",
"nut",
"observation",
"offer",
"oil",
"operation",
"opinion",
"orange",
"order",
"organization",
"ornament",
"oven",
"page",
"pail",
"pain",
"paint",
"pan",
"pancake",
"paper",
"parcel",
"parent",
"part",
"passenger",
"paste",
"payment",
"peace",
"pear",
"pen",
"pencil",
"person",
"pest",
"pet",
"picture",
"pie",
"pin",
"pipe",
"pizza",
"place",
"plane",
"plant",
"plastic",
"plate",
"play",
"pleasure",
"plot",
"plough",
"pocket",
"point",
"poison",
"police",
"pollution",
"popcorn",
"porter",
"position",
"pot",
"potato",
"powder",
"power",
"price",
"print",
"process",
"produce",
"product",
"profit",
"property",
"prose",
"protest",
"pull",
"pump",
"punishment",
"purpose",
"push",
"quarter",
"question",
"quiet",
"quill",
"quilt",
"quince",
"rabbit",
"rail",
"rain",
"range",
"rat",
"rate",
"ray",
"reaction",
"reading",
"reason",
"record",
"regret",
"relation",
"religion",
"representative",
"request",
"respect",
"rest",
"reward",
"rhythm",
"rice",
"river",
"road",
"roll",
"room",
"root",
"rose",
"route",
"rub",
"rule",
"run",
"sack",
"sail",
"salt",
"sand",
"scale",
"scarecrow",
"scarf",
"scene",
"scent",
"school",
"science",
"scissors",
"screw",
"sea",
"seat",
"secretary",
"seed",
"selection",
"self",
"sense",
"servant",
"shade",
"shake",
"shame",
"shape",
"sheep",
"sheet",
"shelf",
"ship",
"shirt",
"shock",
"shoe",
"shop",
"show",
"side",
"sign",
"silk",
"sink",
"sister",
"size",
"sky",
"slave",
"sleep",
"smash",
"smell",
"smile",
"smoke",
"snail",
"snake",
"sneeze",
"snow",
"soap",
"society",
"sock",
"soda",
"sofa",
"son",
"song",
"sort",
"sound",
"soup",
"space",
"spark",
"speed",
"sponge",
"spoon",
"spray",
"spring",
"spy",
"square",
"stamp",
"star",
"start",
"statement",
"station",
"steam",
"steel",
"stem",
"step",
"stew",
"stick",
"stitch",
"stocking",
"stomach",
"stone",
"stop",
"store",
"story",
"stove",
"stranger",
"straw",
"stream",
"street",
"stretch",
"string",
"structure",
"substance",
"sugar",
"suggestion",
"suit",
"summer",
"sun",
"support",
"surprise",
"sweater",
"swim",
"system",
"table",
"tail",
"talk",
"tank",
"taste",
"tax",
"tea",
"teaching",
"team",
"tendency",
"test",
"texture",
"theory",
"thing",
"thought",
"thread",
"throat",
"thumb",
"thunder",
"ticket",
"time",
"tin",
"title",
"toad",
"toe",
"tooth",
"toothpaste",
"touch",
"town",
"toy",
"trade",
"train",
"transport",
"tray",
"treatment",
"tree",
"trick",
"trip",
"trouble",
"trousers",
"truck",
"tub",
"turkey",
"turn",
"twist",
"umbrella",
"uncle",
"underwear",
"unit",
"use",
"vacation",
"value",
"van",
"vase",
"vegetable",
"veil",
"vein",
"verse",
"vessel",
"view",
"visitor",
"voice",
"volcano",
"walk",
"wall",
"war",
"wash",
"waste",
"watch",
"water",
"wave",
"wax",
"way",
"wealth",
"weather",
"week",
"weight",
"wheel",
"whip",
"whistle",
"window",
"wine",
"wing",
"winter",
"wire",
"wish",
"woman",
"wood",
"wool",
"word",
"work",
"worm",
"wound",
"wrist",
"writer",
"yard",
"yoke",
"zebra",
"zinc",
"zipper",
"zone",
]
def random_name(prefix: str = "test") -> str:
"""Generate a random name."""
adjective = random.choice(adjectives)
noun = random.choice(nouns)
number = random.randint(1, 100)
return f"{prefix}-{adjective}-{noun}-{number}"
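Sample output, assuming the module path langchain.smith.evaluation.name_generation used by the runner later in this diff:

from langchain.smith.evaluation import name_generation

print(name_generation.random_name())                # e.g. "test-warm-cabbage-42"
print(name_generation.random_name(prefix="eval"))   # e.g. "eval-bold-river-7"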

@ -0,0 +1,82 @@
"""A simple progress bar for the console."""
import threading
from typing import Any, Dict, Optional, Sequence
from uuid import UUID
from langchain.callbacks import base as base_callbacks
from langchain.schema.document import Document
from langchain.schema.output import LLMResult
class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
"""A simple progress bar for the console."""
def __init__(self, total: int, ncols: int = 50, **kwargs: Any):
"""Initialize the progress bar.
Args:
total: int, the total number of items to be processed.
ncols: int, the character width of the progress bar.
"""
self.total = total
self.ncols = ncols
self.counter = 0
self.lock = threading.Lock()
self._print_bar()
def increment(self) -> None:
"""Increment the counter and update the progress bar."""
with self.lock:
self.counter += 1
self._print_bar()
def _print_bar(self) -> None:
"""Print the progress bar to the console."""
progress = self.counter / self.total
arrow = "-" * int(round(progress * self.ncols) - 1) + ">"
spaces = " " * (self.ncols - len(arrow))
print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="")
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
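A quick sketch of the bar on its own (ticked manually here; in an eval run the on_*_end hooks above call increment() once per top-level run). It assumes the module lands at langchain.smith.evaluation.progress, matching the import used by the runner below.

import time

from langchain.smith.evaluation import progress

bar = progress.ProgressBarCallback(total=3, ncols=30)
for _ in range(3):
    time.sleep(0.1)   # stand-in for running one example
    bar.increment()   # redraws a bar like [--------->          ] 1/3 in place
print()               # move past the carriage-returned bar line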

@ -2,21 +2,16 @@
from __future__ import annotations
import asyncio
import functools
import inspect
import itertools
import logging
import uuid
import warnings
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterator,
List,
Optional,
Sequence,
@ -24,16 +19,13 @@ from typing import (
Union,
cast,
)
from urllib.parse import urlparse, urlunparse
from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
from langchain.chains.base import Chain
from langchain.evaluation.loading import load_evaluator
from langchain.evaluation.schema import EvaluatorType, StringEvaluator
@ -41,8 +33,11 @@ from langchain.schema import ChatResult, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, messages_from_dict
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
from langchain.schema.runnable import config as runnable_config
from langchain.schema.runnable import utils as runnable_utils
from langchain.smith import evaluation as smith_eval
from langchain.smith.evaluation import config as smith_eval_config
from langchain.smith.evaluation import name_generation, progress
if TYPE_CHECKING:
import pandas as pd
@ -69,6 +64,26 @@ class InputFormatError(Exception):
class TestResult(dict):
"""A dictionary of the results of a single test run."""
def get_aggregate_feedback(
self, quantiles: Optional[Sequence[float]] = None
) -> pd.DataFrame:
"""Return quantiles for the feedback scores.
This method computes quantiles of the numeric feedback scores
across all feedback keys.
Args:
quantiles: The quantiles to compute. Defaults to [0.25, 0.5, 0.75].
Returns:
A DataFrame containing the quantiles for each feedback key.
"""
df = self.to_dataframe()
feedback_cols = [
col for col in df.columns if col not in ["input", "output", "reference"]
]
_quantiles = df[feedback_cols].quantile(
quantiles or [0.25, 0.5, 0.75], numeric_only=True
)
return _quantiles.transpose()
def to_dataframe(self) -> pd.DataFrame:
"""Convert the results to a dataframe."""
try:
@ -83,27 +98,19 @@ class TestResult(dict):
records = []
for example_id, result in self["results"].items():
feedback = result["feedback"]
records.append(
{**{f.key: f.score for f in feedback}, "output": result["output"]}
)
r = {
**{f.key: f.score for f in feedback},
"input": result["input"],
"output": result["output"],
}
if "reference" in result:
r["reference"] = result["reference"]
records.append(r)
indices.append(example_id)
return pd.DataFrame(records, index=indices)
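A hedged sketch of these helpers on a hand-built TestResult. The feedback objects are stand-ins (to_dataframe only reads .key and .score; real runs attach langsmith Feedback records), and the import assumes this file's path, langchain.smith.evaluation.runner_utils.

from types import SimpleNamespace

from langchain.smith.evaluation.runner_utils import TestResult


def fb(key: str, score: float) -> SimpleNamespace:
    # Minimal stand-in for a langsmith Feedback record.
    return SimpleNamespace(key=key, score=score)


demo = TestResult(
    project_name="demo",
    results={
        "ex-1": {"input": {"q": "2+2?"}, "output": {"a": "4"},
                 "feedback": [fb("correctness", 1)]},
        "ex-2": {"input": {"q": "3+3?"}, "output": {"a": "7"},
                 "feedback": [fb("correctness", 0)]},
    },
)
print(demo.to_dataframe())             # one row per example, plus a `correctness` column
print(demo.get_aggregate_feedback())   # 0.25 / 0.5 / 0.75 quantiles per feedback key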
def _get_eval_project_url(api_url: str, project_id: str) -> str:
"""Get the project url from the api url."""
parsed = urlparse(api_url)
hostname = parsed.hostname or ""
if "api." in hostname:
hostname = hostname.replace("api.", "", 1)
if "localhost" in hostname:
# Remove the port
hostname = "localhost"
url = urlunparse(parsed._replace(netloc=hostname))
return f"{url}/projects/p/{project_id}?eval=true"
def _wrap_in_chain_factory(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str = "<my_dataset>",
@ -172,15 +179,6 @@ def _wrap_in_chain_factory(
return llm_or_chain_factory
def _first_example(examples: Iterator[Example]) -> Tuple[Example, Iterator[Example]]:
"""Get the first example while chaining it back and preserving the iterator."""
try:
example: Example = next(examples)
except StopIteration:
raise ValueError("No examples provided.")
return example, itertools.chain([example], examples)
def _get_prompt(inputs: Dict[str, Any]) -> str:
"""Get prompt from inputs.
@ -277,31 +275,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
)
def _get_project_name(
project_name: Optional[str],
llm_or_chain_factory: MCF,
) -> str:
"""
Get the project name.
Args:
project_name: The project name if manually specified.
llm_or_chain_factory: The Chain or language model constructor.
Returns:
The project name.
"""
if project_name is not None:
return project_name
if isinstance(llm_or_chain_factory, BaseLanguageModel):
model_name = llm_or_chain_factory.__class__.__name__
else:
model_name = llm_or_chain_factory().__class__.__name__
hex = uuid.uuid4().hex
return f"{hex}-{model_name}"
## Shared Validation Utilities
## Shared data validation utilities
def _validate_example_inputs_for_language_model(
first_example: Example,
input_mapper: Optional[Callable[[Dict], Any]],
@ -373,22 +347,20 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs(
examples: Iterator[Example],
example: Example,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]],
) -> Iterator[Example]:
) -> None:
"""Validate that the example inputs are valid for the model."""
first_example, examples = _first_example(examples)
if isinstance(llm_or_chain_factory, BaseLanguageModel):
_validate_example_inputs_for_language_model(first_example, input_mapper)
_validate_example_inputs_for_language_model(example, input_mapper)
else:
chain = llm_or_chain_factory()
if isinstance(chain, Chain):
# Otherwise it's a runnable
_validate_example_inputs_for_chain(first_example, chain, input_mapper)
_validate_example_inputs_for_chain(example, chain, input_mapper)
elif isinstance(chain, Runnable):
logger.debug(f"Skipping input validation for {chain}")
return examples
## Shared Evaluator Setup Utilities
@ -396,13 +368,12 @@ def _validate_example_inputs(
def _setup_evaluation(
llm_or_chain_factory: MCF,
examples: Iterator[Example],
evaluation: Optional[RunEvalConfig],
examples: List[Example],
evaluation: Optional[smith_eval.RunEvalConfig],
data_type: DataType,
) -> Tuple[Optional[List[RunEvaluator]], Iterator[Example]]:
) -> Optional[List[RunEvaluator]]:
"""Configure the evaluators to run on the results of the chain."""
if evaluation:
first_example, examples = _first_example(examples)
if isinstance(llm_or_chain_factory, BaseLanguageModel):
run_inputs, run_outputs = None, None
run_type = "llm"
@ -422,18 +393,18 @@ def _setup_evaluation(
evaluation,
run_type,
data_type,
list(first_example.outputs) if first_example.outputs else None,
list(examples[0].outputs) if examples[0].outputs else None,
run_inputs,
run_outputs,
)
else:
# TODO: Create a default helpfulness evaluator
run_evaluators = None
return run_evaluators, examples
return run_evaluators
def _determine_input_key(
config: RunEvalConfig,
config: smith_eval.RunEvalConfig,
run_inputs: Optional[List[str]],
) -> Optional[str]:
input_key = None
@ -452,7 +423,7 @@ def _determine_input_key(
def _determine_prediction_key(
config: RunEvalConfig,
config: smith_eval.RunEvalConfig,
run_outputs: Optional[List[str]],
) -> Optional[str]:
prediction_key = None
@ -473,7 +444,7 @@ def _determine_prediction_key(
def _determine_reference_key(
config: RunEvalConfig,
config: smith_eval.RunEvalConfig,
example_outputs: Optional[List[str]],
) -> Optional[str]:
if config.reference_key:
@ -491,7 +462,7 @@ def _determine_reference_key(
def _construct_run_evaluator(
eval_config: Union[EvaluatorType, str, EvalConfig],
eval_config: Union[EvaluatorType, str, smith_eval_config.EvalConfig],
eval_llm: Optional[BaseLanguageModel],
run_type: str,
data_type: DataType,
@ -513,11 +484,11 @@ def _construct_run_evaluator(
if isinstance(evaluator_, StringEvaluator):
if evaluator_.requires_reference and reference_key is None:
raise ValueError(
f"Must specify reference_key in RunEvalConfig to use"
f"Must specify reference_key in smith_eval.RunEvalConfig to use"
f" evaluator of type {eval_type_tag} with"
f" dataset with multiple output keys: {example_outputs}."
)
run_evaluator = StringRunEvaluatorChain.from_run_and_data_type(
run_evaluator = smith_eval.StringRunEvaluatorChain.from_run_and_data_type(
evaluator_,
run_type,
data_type,
@ -534,7 +505,7 @@ def _construct_run_evaluator(
def _get_keys(
config: RunEvalConfig,
config: smith_eval.RunEvalConfig,
run_inputs: Optional[List[str]],
run_outputs: Optional[List[str]],
example_outputs: Optional[List[str]],
@ -546,7 +517,7 @@ def _get_keys(
def _load_run_evaluators(
config: RunEvalConfig,
config: smith_eval.RunEvalConfig,
run_type: str,
data_type: DataType,
example_outputs: Optional[List[str]],
@ -593,7 +564,7 @@ def _load_run_evaluators(
run_evaluators.append(custom_evaluator)
elif isinstance(custom_evaluator, StringEvaluator):
run_evaluators.append(
StringRunEvaluatorChain.from_run_and_data_type(
smith_eval.StringRunEvaluatorChain.from_run_and_data_type(
custom_evaluator,
run_type,
data_type,
@ -694,10 +665,9 @@ async def _arun_chain(
async def _arun_llm_or_chain(
example: Example,
llm_or_chain_factory: MCF,
config: RunnableConfig,
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""Asynchronously run the Chain or language model.
@ -712,15 +682,6 @@ async def _arun_llm_or_chain(
Returns:
A list of outputs.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else:
previous_example_ids = None
chain_or_llm = (
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
)
@ -730,8 +691,8 @@ async def _arun_llm_or_chain(
output: Any = await _arun_llm(
llm_or_chain_factory,
example.inputs,
tags=tags,
callbacks=callbacks,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
else:
@ -739,198 +700,19 @@ async def _arun_llm_or_chain(
output = await _arun_chain(
chain,
example.inputs,
tags=tags,
callbacks=callbacks,
tags=config["tags"],
callbacks=config["callbacks"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
result = {"Error": str(e)}
if callbacks and previous_example_ids:
for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
return result
async def _gather_with_concurrency(
n: int,
initializer: Callable[[], Coroutine[Any, Any, Any]],
*async_funcs: Callable[
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
],
) -> List[Any]:
"""Run coroutines with a concurrency limit.
Args:
n: The maximum number of concurrent tasks.
initializer: A coroutine that initializes shared resources for the tasks.
async_funcs: The async_funcs to be run concurrently.
Returns:
A list of results from the coroutines.
"""
semaphore = asyncio.Semaphore(n)
job_state = {"num_processed": 0}
callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
for _ in range(n):
callback_queue.put_nowait(await initializer())
async def run_coroutine_with_semaphore(
async_func: Callable[
[Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
]
) -> Any:
async with semaphore:
callbacks = await callback_queue.get()
try:
result = await async_func(callbacks, job_state)
finally:
callback_queue.put_nowait(callbacks)
return result
results = await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs)
)
while callback_queue:
try:
callbacks = callback_queue.get_nowait()
except asyncio.QueueEmpty:
break
for callback in callbacks:
if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
callback.wait_for_futures()
return results
async def _callbacks_initializer(
project_name: Optional[str],
client: Client,
run_evaluators: Sequence[RunEvaluator],
evaluation_handler_collector: List[EvaluatorCallbackHandler],
) -> List[BaseTracer]:
"""
Initialize a tracer to share across tasks.
Args:
project_name: The project name for the tracer.
client: The client to use for the tracer.
run_evaluators: The evaluators to run.
evaluation_handler_collector: A list to collect the evaluators.
Used to wait for the evaluators to finish.
Returns:
The callbacks for this thread.
"""
callbacks: List[BaseTracer] = []
if project_name:
callbacks.append(
LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
)
if run_evaluators:
callback = EvaluatorCallbackHandler(
client=client,
evaluators=run_evaluators,
# We already have concurrency, don't want to overload the machine
max_workers=1,
)
callbacks.append(callback)
evaluation_handler_collector.append(callback)
return callbacks
async def _arun_on_examples(
client: Client,
examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[RunEvalConfig] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
data_type: DataType = DataType.kv,
) -> Dict[str, Any]:
"""
Asynchronously run the chain on examples and store traces
to the specified project name.
Args:
client: LangSmith client to use to log feedback and runs.
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
evaluation: Optional evaluation configuration to use when evaluating
concurrency_level: The number of async tasks to run concurrently.
project_name: Project name to use when tracing runs.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: function to map to the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize more complex schema or if your dataset
has inputs with keys that differ from what is expected by your chain
or agent.
data_type: The dataset's data type. This is used to determine
how to deserialize the reference data and check model compatibility.
Returns:
A dictionary mapping example ids to the model outputs.
"""
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
run_evaluators, examples = _setup_evaluation(
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
results: Dict[str, dict] = {}
async def process_example(
example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
) -> None:
"""Process a single example."""
result = await _arun_llm_or_chain(
example,
wrapped_model,
tags=tags,
callbacks=callbacks,
input_mapper=input_mapper,
logger.warning(
f"{chain_or_llm} failed for example {example.id} "
f"with inputs {example.inputs}"
f"\n{repr(e)}"
)
results[str(example.id)] = {"output": result}
job_state["num_processed"] += 1
if verbose:
print(
f"Processed examples: {job_state['num_processed']}",
end="\r",
flush=True,
)
evaluation_handlers: List[EvaluatorCallbackHandler] = []
await _gather_with_concurrency(
concurrency_level,
functools.partial(
_callbacks_initializer,
project_name=project_name,
client=client,
evaluation_handler_collector=evaluation_handlers,
run_evaluators=run_evaluators or [],
),
*(functools.partial(process_example, e) for e in examples),
)
all_feedback = {}
for handler in evaluation_handlers:
handler.wait_for_futures()
all_feedback.update(handler.logged_feedback)
# join the results and feedback on the example id
for example_id, output_dict in results.items():
feedback = all_feedback.get(example_id, [])
output_dict["feedback"] = feedback
return results
result = {"Error": repr(e)}
return result
## Sync Utilities
@ -1011,10 +793,9 @@ def _run_chain(
def _run_llm_or_chain(
example: Example,
llm_or_chain_factory: MCF,
config: RunnableConfig,
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str, LLMResult, ChatResult]:
"""
@ -1030,15 +811,6 @@ def _run_llm_or_chain(
Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
The outputs of the model or chain.
"""
if callbacks:
previous_example_ids = [
getattr(tracer, "example_id", None) for tracer in callbacks
]
for tracer in callbacks:
if hasattr(tracer, "example_id"):
tracer.example_id = example.id
else:
previous_example_ids = None
chain_or_llm = (
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
)
@ -1048,8 +820,8 @@ def _run_llm_or_chain(
output: Any = _run_llm(
llm_or_chain_factory,
example.inputs,
callbacks,
tags=tags,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
else:
@ -1057,98 +829,22 @@ def _run_llm_or_chain(
output = _run_chain(
chain,
example.inputs,
callbacks,
tags=tags,
config["callbacks"],
tags=config["tags"],
input_mapper=input_mapper,
)
result = output
except Exception as e:
error_type = type(e).__name__
logger.warning(
f"{chain_or_llm} failed for example {example.id} with inputs:"
f" {example.inputs}.\nError: {e}",
f"{chain_or_llm} failed for example {example.id} "
f"with inputs {example.inputs}"
f"\nError Type: {error_type}, Message: {e}"
)
result = {"Error": str(e)}
if callbacks and previous_example_ids:
for example_id, tracer in zip(previous_example_ids, callbacks):
if hasattr(tracer, "example_id"):
tracer.example_id = example_id
result = {"Error": repr(e)}
return result
def _run_on_examples(
client: Client,
examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[RunEvalConfig] = None,
project_name: Optional[str] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
data_type: DataType = DataType.kv,
) -> Dict[str, Any]:
"""
Run the Chain or language model on examples and store
traces to the specified project name.
Args:
client: LangSmith client to use to log feedback and runs.
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
evaluation: Optional evaluation configuration to use when evaluating
project_name: Name of the project to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
tags: Tags to add to each run in the project.
input_mapper: A function to map to the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize more complex schema or if your dataset
has inputs with keys that differ from what is expected by your chain
or agent.
data_type: The dataset's data type. This is used to determine
how to deserialize the reference data and check model compatibility.
Returns:
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, dict] = {}
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
tracer = LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
run_evaluators, examples = _setup_evaluation(
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
evaluation_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
)
callbacks: List[BaseCallbackHandler] = [tracer, evaluation_handler]
for i, example in enumerate(examples):
result = _run_llm_or_chain(
example,
wrapped_model,
tags=tags,
callbacks=callbacks,
input_mapper=input_mapper,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = {"output": result}
tracer.wait_for_futures()
evaluation_handler.wait_for_futures()
all_feedback = evaluation_handler.logged_feedback
# join the results and feedback on the example id
for example_id, output_dict in results.items():
feedback = all_feedback.get(example_id, [])
output_dict["feedback"] = feedback
return results
## Public API
@ -1156,10 +852,9 @@ def _prepare_eval_run(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
project_name: str,
) -> Tuple[MCF, str, Dataset, List[Example]]:
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, wrapped_model)
try:
project = client.create_project(project_name)
except ValueError as e:
@ -1168,21 +863,95 @@ def _prepare_eval_run(
raise ValueError(
f"Project {project_name} already exists. Please use a different name."
)
project_url = _get_eval_project_url(client.api_url, project.id)
print(
f"View the evaluation results for project '{project_name}' at:\n{project_url}"
f"View the evaluation results for project '{project_name}' at:\n{project.url}"
)
dataset = client.read_dataset(dataset_name=dataset_name)
examples = client.list_examples(dataset_id=str(dataset.id))
examples = list(client.list_examples(dataset_id=dataset.id))
if not examples:
raise ValueError(f"Dataset {dataset_name} has no example rows.")
return wrapped_model, project_name, dataset, examples
def _prepare_run_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
evaluation: Optional[smith_eval.RunEvalConfig] = None,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
project_name = project_name or name_generation.random_name()
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type
)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
progress_bar = progress.ProgressBarCallback(len(examples))
configs = [
RunnableConfig(
callbacks=[
LangChainTracer(
project_name=project_name,
client=client,
use_threading=False,
example_id=example.id,
),
EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
max_workers=0,
example_id=example.id,
),
progress_bar,
],
tags=tags or [],
max_concurrency=concurrency_level,
)
for example in examples
]
return wrapped_model, project_name, examples, configs
def _collect_test_results(
examples: List[Example],
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
configs: List[RunnableConfig],
project_name: str,
) -> TestResult:
wait_for_all_tracers()
all_feedback = {}
for c in configs:
for callback in cast(list, c["callbacks"]):
if isinstance(callback, EvaluatorCallbackHandler):
all_feedback.update(callback.logged_feedback)
results = {}
for example, output in zip(examples, batch_results):
feedback = all_feedback.get(str(example.id), [])
results[str(example.id)] = {
"output": output,
"input": example.inputs,
"feedback": feedback,
}
if example.outputs:
results[str(example.id)]["reference"] = example.outputs
return TestResult(
project_name=project_name,
results=results,
)
async def arun_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[RunEvalConfig] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
@ -1227,7 +996,7 @@ async def arun_on_dataset(
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.smith import RunEvalConfig, arun_on_dataset
from langchain.smith import arun_on_dataset
from langchain.smith import evaluation as smith_eval
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
@ -1240,12 +1009,12 @@ async def arun_on_dataset(
return chain
# Load off-the-shelf evaluators via config or the EvaluatorType (string or enum)
evaluation_config = RunEvalConfig(
evaluation_config = smith_eval.RunEvalConfig(
evaluators=[
"qa", # "Correctness" against a reference answer
"embedding_distance",
RunEvalConfig.Criteria("helpfulness"),
RunEvalConfig.Criteria({
smith_eval.RunEvalConfig.Criteria("helpfulness"),
smith_eval.RunEvalConfig.Criteria({
"fifth-grader-score": "Do you have to be smarter than a fifth grader to answer this question?"
}),
]
@ -1286,7 +1055,7 @@ async def arun_on_dataset(
return {"score": prediction == reference}
evaluation_config = RunEvalConfig(
evaluation_config = smith_eval.RunEvalConfig(
custom_evaluators = [MyStringEvaluator()],
)
@ -1299,51 +1068,43 @@ async def arun_on_dataset(
""" # noqa: E501
if kwargs:
warnings.warn(
"The following arguments are deprecated and will "
"be removed in a future release: "
"The following arguments are deprecated and "
"will be removed in a future release: "
f"{kwargs.keys()}.",
DeprecationWarning,
)
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
)
results = await _arun_on_examples(
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
examples,
wrapped_model,
concurrency_level=concurrency_level,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
)
return TestResult(
project_name=project_name,
results=results,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
tags,
input_mapper,
concurrency_level,
)
def _handle_coroutine(coro: Coroutine) -> Any:
"""
Handles a coroutine from a sync context.
Args:
coro (asyncio.coroutine): The coroutine to be handled.
Returns:
any: The result of the executed coroutine.
"""
# Check if there's a running event loop
try:
loop = asyncio.get_event_loop()
except RuntimeError: # No event loop
return asyncio.run(coro)
if loop.is_running():
return loop.run_until_complete(coro)
else:
return asyncio.run(coro)
batch_results = await runnable_utils.gather_with_concurrency(
configs[0].get("max_concurrency"),
*map(
functools.partial(
_arun_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
),
)
results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
print("\n Eval quantiles:")
print(agg_feedback)
except Exception as e:
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
return results
def run_on_dataset(
@ -1351,7 +1112,7 @@ def run_on_dataset(
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[RunEvalConfig] = None,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
verbose: bool = False,
@ -1397,7 +1158,7 @@ def run_on_dataset(
from langsmith import Client
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith import run_on_dataset
from langchain.smith import evaluation as smith_eval
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
@ -1410,12 +1171,12 @@ def run_on_dataset(
return chain
# Load off-the-shelf evaluators via config or the EvaluatorType (string or enum)
evaluation_config = RunEvalConfig(
evaluation_config = smith_eval.RunEvalConfig(
evaluators=[
"qa", # "Correctness" against a reference answer
"embedding_distance",
RunEvalConfig.Criteria("helpfulness"),
RunEvalConfig.Criteria({
smith_eval.RunEvalConfig.Criteria("helpfulness"),
smith_eval.RunEvalConfig.Criteria({
"fifth-grader-score": "Do you have to be smarter than a fifth grader to answer this question?"
}),
]
@ -1456,7 +1217,7 @@ def run_on_dataset(
return {"score": prediction == reference}
evaluation_config = RunEvalConfig(
evaluation_config = smith_eval.RunEvalConfig(
custom_evaluators = [MyStringEvaluator()],
)
@ -1474,37 +1235,35 @@ def run_on_dataset(
f"{kwargs.keys()}.",
DeprecationWarning,
)
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client, dataset_name, llm_or_chain_factory, project_name
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
client,
dataset_name,
llm_or_chain_factory,
project_name,
evaluation,
tags,
input_mapper,
concurrency_level,
)
if concurrency_level in (0, 1):
results = _run_on_examples(
client,
examples,
wrapped_model,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
)
else:
# TODO: Use runnables and the batch method
coro = _arun_on_examples(
client,
examples,
wrapped_model,
concurrency_level=concurrency_level,
project_name=project_name,
verbose=verbose,
tags=tags,
evaluation=evaluation,
input_mapper=input_mapper,
data_type=dataset.data_type,
with runnable_config.get_executor_for_config(configs[0]) as executor:
batch_results = list(
executor.map(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
)
)
results = _handle_coroutine(coro)
return TestResult(
project_name=project_name,
results=results,
)
results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
print("\n Eval quantiles:")
print(agg_feedback)
except Exception as e:
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
return results

@ -148,13 +148,27 @@ class ChainStringRunMapper(StringRunMapper):
def map(self, run: Run) -> Dict[str, str]:
"""Maps the Run to a dictionary."""
if not run.outputs:
raise ValueError(f"Run {run.id} has no outputs to evaluate.")
raise ValueError(
f"Run with ID {run.id} lacks outputs required for evaluation."
" Ensure the Run has valid outputs."
)
if self.input_key is not None and self.input_key not in run.inputs:
raise ValueError(f"Run {run.id} does not have input key {self.input_key}.")
raise ValueError(
f"Run with ID {run.id} is missing the expected input key"
f" '{self.input_key}'.\nAvailable input keys in this Run"
f" are: {run.inputs.keys()}.\nAdjust the evaluator's"
f" input_key or ensure your input data includes key"
f" '{self.input_key}'."
)
elif self.prediction_key is not None and self.prediction_key not in run.outputs:
available_keys = ", ".join(run.outputs.keys())
raise ValueError(
f"Run {run.id} does not have prediction key {self.prediction_key}."
f"Run with ID {run.id} doesn't have the expected prediction key"
f" '{self.prediction_key}'. Available prediction keys in this Run are:"
f" {available_keys}. Adjust the evaluator's prediction_key or"
" ensure the Run's outputs include the expected key."
)
else:
input_ = self._get_key(run.inputs, self.input_key, "input")
prediction = self._get_key(run.outputs, self.prediction_key, "prediction")

@ -5,7 +5,6 @@ import pytest
from langsmith import Client as Client
from langsmith.schemas import DataType
from langchain.callbacks.tracers.evaluation import wait_for_all_evaluators
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType
@ -22,7 +21,6 @@ def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
# chain or llm passes for the feedback provided.
runs = list(client.list_runs(project_name=_project_name, execution_order=1))
assert len(runs) == 4
wait_for_all_evaluators()
feedback = list(client.list_feedback(run_ids=[run.id for run in runs]))
assert len(feedback) == 8
assert all([f.score == 1 for f in feedback])

@ -181,11 +181,15 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
assert "the wrong input" in inputs
return {"the right input": inputs["the wrong input"]}
result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
result = _run_llm_or_chain(
example,
{"callbacks": [], "tags": []},
llm_or_chain_factory=lambda: mock_chain,
input_mapper=input_mapper,
)
assert result == {"output": "2", "the right input": "1"}
bad_result = _run_llm_or_chain(
example,
lambda: mock_chain,
example, {"callbacks": [], "tags": []}, llm_or_chain_factory=lambda: mock_chain
)
assert "Error" in bad_result
@ -195,7 +199,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
return "the right input"
mock_llm = FakeLLM(queries={"the right input": "somenumber"})
llm_result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
llm_result = _run_llm_or_chain(
example,
{"callbacks": [], "tags": []},
llm_or_chain_factory=mock_llm,
input_mapper=llm_input_mapper,
)
assert isinstance(llm_result, str)
assert llm_result == "somenumber"
@ -324,10 +333,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
)
expected = {
uuid_: {
"output": {"result": f"Result for example {uuid.UUID(uuid_)}"},
str(example.id): {
"output": {
"result": f"Result for example {uuid.UUID(str(example.id))}"
},
"input": {"input": example.inputs["input"]},
"reference": {"output": example.outputs["output"]},
"feedback": [],
}
for uuid_ in uuids
for example in examples
}
assert results["results"] == expected
