Support evaluating runnables and arbitrary functions (#8698)

Added a couple of "integration tests" for these, which I ran.

Main design point for feedback: at this point, would it be better to
have separate arguments for each type? It's a little confusing what is
or isn't supported and what the intended usage is, since I try to wrap
the function as a runnable and pack or unpack chains/LLMs. Something like:

```
run_on_dataset(
    ...
    llm_or_chain_factory=None,
    llm=None,
    chain=None,
    runnable=None,
    function=None,
):
    # raise an error if none is set
```

The downside of supporting runnables and arbitrary functions is that you get
much less helpful validation and error messages, but I don't think we should
block this usage, at least.
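
For concreteness, a minimal sketch of the two new call patterns, mirroring the integration tests added below; the dataset name, the `{input}` prompt variable (which must match the dataset's input key), and the LangSmith `Client` setup are placeholders rather than part of this diff. As noted above, input validation is skipped when the target is a runnable or a plain function:

```
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType
from langchain.prompts.chat import ChatPromptTemplate
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client

client = Client()

# A runnable: prompt piped into a chat model.
runnable = ChatPromptTemplate.from_messages([("human", "{input}")]) | ChatOpenAI()

def my_func(inputs: dict) -> str:
    # Arbitrary single-argument function; internally it gets wrapped in a RunnableLambda.
    return runnable.invoke(inputs).content

eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA])

# Pass a runnable directly...
run_on_dataset(client, "<my_dataset>", runnable, evaluation=eval_config)
# ...or an arbitrary function that takes the example's inputs dict.
run_on_dataset(client, "<my_dataset>", my_func, evaluation=eval_config)
```
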
William FH 2023-08-04 16:39:04 -07:00 committed by GitHub
parent d00a247da7
commit c8f3615aa6
2 changed files with 127 additions and 44 deletions


@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import itertools
import logging
import uuid
@@ -19,6 +20,7 @@ from typing import (
Sequence,
Tuple,
Union,
cast,
)
from urllib.parse import urlparse, urlunparse
@@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
from langchain.schema import ChatResult, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, messages_from_dict
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
MODEL_OR_CHAIN_FACTORY = Union[
Callable[[], Union[Chain, Runnable]],
BaseLanguageModel,
Callable[[dict], Any],
Runnable,
Chain,
]
MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel]
class InputFormatError(Exception):
@@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str:
def _wrap_in_chain_factory(
llm_or_chain_factory: Union[Chain, MODEL_OR_CHAIN_FACTORY],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str = "<my_dataset>",
) -> MODEL_OR_CHAIN_FACTORY:
) -> MCF:
"""Forgive the user if they pass in a chain without memory instead of a chain
factory. It's a common mistake. Raise a more helpful error message as well."""
if isinstance(llm_or_chain_factory, Chain):
@@ -105,11 +115,31 @@ def _wrap_in_chain_factory(
return lambda: chain
elif isinstance(llm_or_chain_factory, BaseLanguageModel):
return llm_or_chain_factory
elif isinstance(llm_or_chain_factory, Runnable):
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda: lcf
elif callable(llm_or_chain_factory):
_model = llm_or_chain_factory()
try:
_model = llm_or_chain_factory() # type: ignore[call-arg]
except TypeError:
# It's an arbitrary function, wrap it in a RunnableLambda
user_func = cast(Callable, llm_or_chain_factory)
sig = inspect.signature(user_func)
logger.info(f"Wrapping function {sig} as RunnableLambda.")
wrapped = RunnableLambda(user_func)
return lambda: wrapped
constructor = cast(Callable, llm_or_chain_factory)
if isinstance(_model, BaseLanguageModel):
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
return llm_or_chain_factory
elif not isinstance(_model, Runnable):
# This is unlikely to happen - a constructor for a model function
return lambda: RunnableLambda(constructor)
else:
# Typical correct case
return constructor # noqa
return llm_or_chain_factory
@@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
def _get_project_name(
project_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MCF,
) -> str:
"""
Get the project name.
@@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs(
examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MCF,
input_mapper: Optional[Callable[[Dict], Any]],
) -> Iterator[Example]:
"""Validate that the example inputs are valid for the model."""
@@ -324,7 +354,11 @@ def _validate_example_inputs(
_validate_example_inputs_for_language_model(first_example, input_mapper)
else:
chain = llm_or_chain_factory()
_validate_example_inputs_for_chain(first_example, chain, input_mapper)
if isinstance(chain, Chain):
# Otherwise it's a runnable
_validate_example_inputs_for_chain(first_example, chain, input_mapper)
elif isinstance(chain, Runnable):
logger.debug(f"Skipping input validation for {chain}")
return examples
@@ -332,7 +366,7 @@
def _setup_evaluation(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MCF,
examples: Iterator[Example],
evaluation: Optional[RunEvalConfig],
data_type: DataType,
@@ -353,8 +387,8 @@
"Please specify a dataset with the default 'kv' data type."
)
chain = llm_or_chain_factory()
run_inputs = chain.input_keys
run_outputs = chain.output_keys
run_inputs = chain.input_keys if isinstance(chain, Chain) else None
run_outputs = chain.output_keys if isinstance(chain, Chain) else None
run_evaluators = _load_run_evaluators(
evaluation,
run_type,
@@ -372,17 +406,15 @@
def _determine_input_key(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_type: str,
) -> Optional[str]:
input_key = None
if config.input_key:
input_key = config.input_key
if run_inputs and input_key not in run_inputs:
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
elif run_type == "llm":
input_key = None
elif run_inputs and len(run_inputs) == 1:
input_key = run_inputs[0]
else:
elif run_inputs is not None and len(run_inputs) > 1:
raise ValueError(
f"Must specify input key for model with multiple inputs: {run_inputs}"
)
@@ -393,19 +425,17 @@
def _determine_prediction_key(
config: RunEvalConfig,
run_outputs: Optional[List[str]],
run_type: str,
) -> Optional[str]:
prediction_key = None
if config.prediction_key:
prediction_key = config.prediction_key
if run_outputs and prediction_key not in run_outputs:
raise ValueError(
f"Prediction key {prediction_key} not in run outputs {run_outputs}"
)
elif run_type == "llm":
prediction_key = None
elif run_outputs and len(run_outputs) == 1:
prediction_key = run_outputs[0]
else:
elif run_outputs is not None and len(run_outputs) > 1:
raise ValueError(
f"Must specify prediction key for model"
f" with multiple outputs: {run_outputs}"
@@ -491,8 +521,8 @@ def _load_run_evaluators(
"""
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
run_evaluators = []
input_key = _determine_input_key(config, run_inputs, run_type)
prediction_key = _determine_prediction_key(config, run_outputs, run_type)
input_key = _determine_input_key(config, run_inputs)
prediction_key = _determine_prediction_key(config, run_outputs)
reference_key = _determine_reference_key(config, example_outputs)
for eval_config in config.evaluators:
run_evaluator = _construct_run_evaluator(
@@ -590,7 +620,7 @@ async def _arun_llm(
async def _arun_chain(
chain: Chain,
chain: Union[Chain, Runnable],
inputs: Dict[str, Any],
callbacks: Callbacks,
*,
@@ -598,20 +628,22 @@ async def _arun_chain(
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str]:
"""Run a chain asynchronously on inputs."""
if input_mapper is not None:
inputs_ = input_mapper(inputs)
output: Union[dict, str] = await chain.acall(
inputs_, callbacks=callbacks, tags=tags
)
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = await chain.acall(val, callbacks=callbacks, tags=tags)
else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else:
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = await chain.ainvoke(inputs_, config=runnable_config)
return output
async def _arun_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MCF,
n_repetitions: int,
*,
tags: Optional[List[str]] = None,
@@ -810,12 +842,12 @@ async def _arun_on_examples(
Returns:
A dictionary mapping example ids to the model outputs.
"""
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, wrapped_model)
run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type
wrapped_model, examples, evaluation, data_type
)
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
results: Dict[str, List[Any]] = {}
async def process_example(
@@ -824,7 +856,7 @@ async def _arun_on_examples(
"""Process a single example."""
result = await _arun_llm_or_chain(
example,
llm_or_chain_factory,
wrapped_model,
num_repetitions,
tags=tags,
callbacks=callbacks,
@@ -911,7 +943,7 @@ def _run_llm(
def _run_chain(
chain: Chain,
chain: Union[Chain, Runnable],
inputs: Dict[str, Any],
callbacks: Callbacks,
*,
@@ -919,18 +951,22 @@ def _run_chain(
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[Dict, str]:
"""Run a chain on inputs."""
if input_mapper is not None:
inputs_ = input_mapper(inputs)
output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
inputs_ = inputs if input_mapper is None else input_mapper(inputs)
if isinstance(chain, Chain):
if isinstance(inputs_, dict) and len(inputs_) == 1:
val = next(iter(inputs_.values()))
output = chain(val, callbacks=callbacks, tags=tags)
else:
output = chain(inputs_, callbacks=callbacks, tags=tags)
else:
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
output = chain(inputs_, callbacks=callbacks, tags=tags)
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
output = chain.invoke(inputs_, config=runnable_config)
return output
def _run_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MCF,
n_repetitions: int,
*,
tags: Optional[List[str]] = None,
@@ -986,7 +1022,8 @@ def _run_llm_or_chain(
outputs.append(output)
except Exception as e:
logger.warning(
f"{chain_or_llm} failed for example {example.id}. Error: {e}"
f"{chain_or_llm} failed for example {example.id} with inputs:"
f" {example.inputs}.\nError: {e}",
)
outputs.append({"Error": str(e)})
if callbacks and previous_example_ids:
@@ -1080,7 +1117,7 @@ def _prepare_eval_run(
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
) -> Tuple[MODEL_OR_CHAIN_FACTORY, str, Dataset, Iterator[Example]]:
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
project_name = _get_project_name(project_name, llm_or_chain_factory)
try:


@@ -10,9 +10,11 @@ from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType
from langchain.llms.openai import OpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset
def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
@@ -427,3 +429,47 @@ def test_chain_on_kv_singleio_dataset(
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
@pytest.mark.asyncio
async def test_runnable_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
await arun_on_dataset(
client,
kv_singleio_dataset_name,
runnable,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
@pytest.mark.asyncio
async def test_arb_func_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
def my_func(x: dict) -> str:
return runnable.invoke(x).content
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
await arun_on_dataset(
client,
kv_singleio_dataset_name,
my_func,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)