Support evaluating runnables and arbitrary functions (#8698)
Added a couple of integration tests for these, which I ran. Main design point for feedback: at this point, would it be better to have separate arguments for each type? It is a little confusing what is and isn't supported and what the intended usage is, since I try to wrap the function as a runnable, or pack and unpack chains/LLMs, behind a single argument. For example:

```
run_on_dataset(
    ...
    llm_or_chain_factory=None,
    llm=None,
    chain=None,
    runnable=None,
    function=None,
):
    # raise an error if none is set
```

The downside of supporting runnables and arbitrary functions is that you get much less helpful validation and error messages, but I don't think we should block this usage, at least.
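For a concrete picture of what this enables, here is a hedged sketch mirroring the integration tests added below. The prompt variable, dataset name, and evaluator choice are placeholders, and `client` is assumed to be a LangSmith client (the tests get it from a fixture):

```python
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import EvaluatorType
from langchain.prompts.chat import ChatPromptTemplate
from langchain.smith import RunEvalConfig, run_on_dataset


def evaluate_runnable_and_function(client, dataset_name: str) -> None:
    # A runnable (prompt | chat model) can now be passed directly.
    # The prompt variable is assumed to match the dataset's single input key.
    runnable = (
        ChatPromptTemplate.from_messages([("human", "{question}")]) | ChatOpenAI()
    )

    # An arbitrary single-argument function is also accepted; internally it
    # gets wrapped in a RunnableLambda.
    def answer(inputs: dict) -> str:
        return runnable.invoke(inputs).content

    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA])
    run_on_dataset(client, dataset_name, runnable, evaluation=eval_config)
    run_on_dataset(client, dataset_name, answer, evaluation=eval_config)
```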
parent d00a247da7, commit c8f3615aa6
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import inspect
 import itertools
 import logging
 import uuid
@@ -19,6 +20,7 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 from urllib.parse import urlparse, urlunparse
 
@@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
 from langchain.schema import ChatResult, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage, messages_from_dict
+from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
 from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
 from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
 
 logger = logging.getLogger(__name__)
 
-MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
+MODEL_OR_CHAIN_FACTORY = Union[
+    Callable[[], Union[Chain, Runnable]],
+    BaseLanguageModel,
+    Callable[[dict], Any],
+    Runnable,
+    Chain,
+]
+MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel]
 
 
 class InputFormatError(Exception):
@@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str:
 
 
 def _wrap_in_chain_factory(
-    llm_or_chain_factory: Union[Chain, MODEL_OR_CHAIN_FACTORY],
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     dataset_name: str = "<my_dataset>",
-) -> MODEL_OR_CHAIN_FACTORY:
+) -> MCF:
     """Forgive the user if they pass in a chain without memory instead of a chain
     factory. It's a common mistake. Raise a more helpful error message as well."""
     if isinstance(llm_or_chain_factory, Chain):
@@ -105,11 +115,31 @@ def _wrap_in_chain_factory(
         return lambda: chain
     elif isinstance(llm_or_chain_factory, BaseLanguageModel):
         return llm_or_chain_factory
+    elif isinstance(llm_or_chain_factory, Runnable):
+        # Memory may exist here, but it's not elegant to check all those cases.
+        lcf = llm_or_chain_factory
+        return lambda: lcf
     elif callable(llm_or_chain_factory):
-        _model = llm_or_chain_factory()
+        try:
+            _model = llm_or_chain_factory()  # type: ignore[call-arg]
+        except TypeError:
+            # It's an arbitrary function, wrap it in a RunnableLambda
+            user_func = cast(Callable, llm_or_chain_factory)
+            sig = inspect.signature(user_func)
+            logger.info(f"Wrapping function {sig} as RunnableLambda.")
+            wrapped = RunnableLambda(user_func)
+            return lambda: wrapped
+        constructor = cast(Callable, llm_or_chain_factory)
         if isinstance(_model, BaseLanguageModel):
+            # It's not uncommon to do an LLM constructor instead of raw LLM,
+            # so we'll unpack it for the user.
             return _model
-        return llm_or_chain_factory
+        elif not isinstance(_model, Runnable):
+            # This is unlikely to happen - a constructor for a model function
+            return lambda: RunnableLambda(constructor)
+        else:
+            # Typical correct case
+            return constructor  # noqa
     return llm_or_chain_factory
 
 
@@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
 
 def _get_project_name(
     project_name: Optional[str],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
 ) -> str:
     """
     Get the project name.
@@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain(
 
 def _validate_example_inputs(
     examples: Iterator[Example],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     input_mapper: Optional[Callable[[Dict], Any]],
 ) -> Iterator[Example]:
     """Validate that the example inputs are valid for the model."""
@@ -324,7 +354,11 @@ def _validate_example_inputs(
         _validate_example_inputs_for_language_model(first_example, input_mapper)
     else:
         chain = llm_or_chain_factory()
-        _validate_example_inputs_for_chain(first_example, chain, input_mapper)
+        if isinstance(chain, Chain):
+            # Otherwise it's a runnable
+            _validate_example_inputs_for_chain(first_example, chain, input_mapper)
+        elif isinstance(chain, Runnable):
+            logger.debug(f"Skipping input validation for {chain}")
     return examples
 
 
@@ -332,7 +366,7 @@ def _validate_example_inputs(
 
 
 def _setup_evaluation(
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     examples: Iterator[Example],
     evaluation: Optional[RunEvalConfig],
     data_type: DataType,
@@ -353,8 +387,8 @@
                     "Please specify a dataset with the default 'kv' data type."
                 )
             chain = llm_or_chain_factory()
-            run_inputs = chain.input_keys
-            run_outputs = chain.output_keys
+            run_inputs = chain.input_keys if isinstance(chain, Chain) else None
+            run_outputs = chain.output_keys if isinstance(chain, Chain) else None
         run_evaluators = _load_run_evaluators(
             evaluation,
             run_type,
@@ -372,17 +406,15 @@
 def _determine_input_key(
     config: RunEvalConfig,
     run_inputs: Optional[List[str]],
-    run_type: str,
 ) -> Optional[str]:
+    input_key = None
     if config.input_key:
         input_key = config.input_key
         if run_inputs and input_key not in run_inputs:
             raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
-    elif run_type == "llm":
-        input_key = None
     elif run_inputs and len(run_inputs) == 1:
         input_key = run_inputs[0]
-    else:
+    elif run_inputs is not None and len(run_inputs) > 1:
         raise ValueError(
             f"Must specify input key for model with multiple inputs: {run_inputs}"
         )
@@ -393,19 +425,17 @@ def _determine_input_key(
 def _determine_prediction_key(
     config: RunEvalConfig,
     run_outputs: Optional[List[str]],
-    run_type: str,
 ) -> Optional[str]:
+    prediction_key = None
     if config.prediction_key:
         prediction_key = config.prediction_key
         if run_outputs and prediction_key not in run_outputs:
             raise ValueError(
                 f"Prediction key {prediction_key} not in run outputs {run_outputs}"
             )
-    elif run_type == "llm":
-        prediction_key = None
     elif run_outputs and len(run_outputs) == 1:
         prediction_key = run_outputs[0]
-    else:
+    elif run_outputs is not None and len(run_outputs) > 1:
         raise ValueError(
             f"Must specify prediction key for model"
             f" with multiple outputs: {run_outputs}"
@@ -491,8 +521,8 @@ def _load_run_evaluators(
     """
    eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
     run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs, run_type)
-    prediction_key = _determine_prediction_key(config, run_outputs, run_type)
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
     reference_key = _determine_reference_key(config, example_outputs)
     for eval_config in config.evaluators:
         run_evaluator = _construct_run_evaluator(
@@ -590,7 +620,7 @@ async def _arun_llm(
 
 
 async def _arun_chain(
-    chain: Chain,
+    chain: Union[Chain, Runnable],
     inputs: Dict[str, Any],
     callbacks: Callbacks,
     *,
@@ -598,20 +628,22 @@ async def _arun_chain(
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
-    if input_mapper is not None:
-        inputs_ = input_mapper(inputs)
-        output: Union[dict, str] = await chain.acall(
-            inputs_, callbacks=callbacks, tags=tags
-        )
+    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
+    if isinstance(chain, Chain):
+        if isinstance(inputs_, dict) and len(inputs_) == 1:
+            val = next(iter(inputs_.values()))
+            output = await chain.acall(val, callbacks=callbacks, tags=tags)
+        else:
+            output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
     else:
-        inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
-        output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
+        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        output = await chain.ainvoke(inputs_, config=runnable_config)
     return output
 
 
 async def _arun_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
@@ -810,12 +842,12 @@ async def _arun_on_examples(
     Returns:
         A dictionary mapping example ids to the model outputs.
     """
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
     run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
     )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
     results: Dict[str, List[Any]] = {}
 
     async def process_example(
@@ -824,7 +856,7 @@ async def _arun_on_examples(
         """Process a single example."""
         result = await _arun_llm_or_chain(
             example,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
@@ -911,7 +943,7 @@ def _run_llm(
 
 
 def _run_chain(
-    chain: Chain,
+    chain: Union[Chain, Runnable],
     inputs: Dict[str, Any],
     callbacks: Callbacks,
     *,
@@ -919,18 +951,22 @@ def _run_chain(
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
-    if input_mapper is not None:
-        inputs_ = input_mapper(inputs)
-        output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
+    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
+    if isinstance(chain, Chain):
+        if isinstance(inputs_, dict) and len(inputs_) == 1:
+            val = next(iter(inputs_.values()))
+            output = chain(val, callbacks=callbacks, tags=tags)
+        else:
+            output = chain(inputs_, callbacks=callbacks, tags=tags)
     else:
-        inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
-        output = chain(inputs_, callbacks=callbacks, tags=tags)
+        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        output = chain.invoke(inputs_, config=runnable_config)
     return output
 
 
 def _run_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
@@ -986,7 +1022,8 @@ def _run_llm_or_chain(
             outputs.append(output)
         except Exception as e:
             logger.warning(
-                f"{chain_or_llm} failed for example {example.id}. Error: {e}"
+                f"{chain_or_llm} failed for example {example.id} with inputs:"
+                f" {example.inputs}.\nError: {e}",
             )
             outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
@@ -1080,7 +1117,7 @@ def _prepare_eval_run(
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
-) -> Tuple[MODEL_OR_CHAIN_FACTORY, str, Dataset, Iterator[Example]]:
+) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
     llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
     project_name = _get_project_name(project_name, llm_or_chain_factory)
     try:
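To summarize the widened `MODEL_OR_CHAIN_FACTORY` union above: the evaluation entry points now accept a language model, a memoryless `Chain` instance (with a warning), a `Runnable`, a zero-argument constructor for a chain or runnable, or an arbitrary single-argument function. A small illustrative sketch, not part of the commit (the prompt and helper names are placeholders):

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda

llm = ChatOpenAI()  # BaseLanguageModel

# Runnable: a prompt piped into a chat model.
runnable = ChatPromptTemplate.from_messages([("human", "{input}")]) | llm


def plain_function(inputs: dict) -> str:  # Callable[[dict], Any]
    return runnable.invoke(inputs).content


def factory():  # Callable[[], Union[Chain, Runnable]]
    return runnable


# _wrap_in_chain_factory normalizes each of these to a zero-argument factory
# (the narrower MCF alias); a bare function is first wrapped in a
# RunnableLambda, as in the `except TypeError` branch above.
wrapped = RunnableLambda(plain_function)
```

The integration tests below exercise the runnable and arbitrary-function paths end to end.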
@@ -10,9 +10,11 @@ from langchain.chains.llm import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.evaluation import EvaluatorType
 from langchain.llms.openai import OpenAI
+from langchain.prompts.chat import ChatPromptTemplate
 from langchain.schema.messages import BaseMessage, HumanMessage
 from langchain.smith import RunEvalConfig, run_on_dataset
 from langchain.smith.evaluation import InputFormatError
+from langchain.smith.evaluation.runner_utils import arun_on_dataset
 
 
 def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
@@ -427,3 +429,47 @@ def test_chain_on_kv_singleio_dataset(
         tags=["shouldpass"],
     )
     _check_all_feedback_passed(eval_project_name, client)
+
+
+@pytest.mark.asyncio
+async def test_runnable_on_kv_singleio_dataset(
+    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
+) -> None:
+    runnable = (
+        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
+        | ChatOpenAI()
+    )
+    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
+    await arun_on_dataset(
+        client,
+        kv_singleio_dataset_name,
+        runnable,
+        evaluation=eval_config,
+        project_name=eval_project_name,
+        tags=["shouldpass"],
+    )
+    _check_all_feedback_passed(eval_project_name, client)
+
+
+@pytest.mark.asyncio
+async def test_arb_func_on_kv_singleio_dataset(
+    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
+) -> None:
+    runnable = (
+        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
+        | ChatOpenAI()
+    )
+
+    def my_func(x: dict) -> str:
+        return runnable.invoke(x).content
+
+    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
+    await arun_on_dataset(
+        client,
+        kv_singleio_dataset_name,
+        my_func,
+        evaluation=eval_config,
+        project_name=eval_project_name,
+        tags=["shouldpass"],
+    )
+    _check_all_feedback_passed(eval_project_name, client)