From c8f3615aa6db0e40ba234e45c486ccf698dcf4d0 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 4 Aug 2023 16:39:04 -0700 Subject: [PATCH] Support evaluating runnables and arbitrary functions (#8698) Added a couple of "integration tests" for these that I ran. Main design point of feedback: at this point, would it just be better to have separate arguments for each type? Little confusing what is or isn't supported and what is the intended usage at this point since I try to wrap the function as runnable or pack or unpack chains/llms. ``` run_on_dataset( ... llm_or_chain_factory = None, llm = None, chain = NOne, runnable=None, function=None ): # raise error if none set ``` Downside with runnables and arbitrary function support is that you get much less helpful validation and error messages, but I don't think we should block you from this, at least. --- .../smith/evaluation/runner_utils.py | 125 ++++++++++++------ .../smith/evaluation/test_runner_utils.py | 46 +++++++ 2 files changed, 127 insertions(+), 44 deletions(-) diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 2464f7c741..585c3f5028 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import functools +import inspect import itertools import logging import uuid @@ -19,6 +20,7 @@ from typing import ( Sequence, Tuple, Union, + cast, ) from urllib.parse import urlparse, urlunparse @@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator from langchain.schema import ChatResult, LLMResult from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage, messages_from_dict +from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain logger = logging.getLogger(__name__) -MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] +MODEL_OR_CHAIN_FACTORY = Union[ + Callable[[], Union[Chain, Runnable]], + BaseLanguageModel, + Callable[[dict], Any], + Runnable, + Chain, +] +MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel] class InputFormatError(Exception): @@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str: def _wrap_in_chain_factory( - llm_or_chain_factory: Union[Chain, MODEL_OR_CHAIN_FACTORY], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, dataset_name: str = "", -) -> MODEL_OR_CHAIN_FACTORY: +) -> MCF: """Forgive the user if they pass in a chain without memory instead of a chain factory. It's a common mistake. Raise a more helpful error message as well.""" if isinstance(llm_or_chain_factory, Chain): @@ -105,11 +115,31 @@ def _wrap_in_chain_factory( return lambda: chain elif isinstance(llm_or_chain_factory, BaseLanguageModel): return llm_or_chain_factory + elif isinstance(llm_or_chain_factory, Runnable): + # Memory may exist here, but it's not elegant to check all those cases. + lcf = llm_or_chain_factory + return lambda: lcf elif callable(llm_or_chain_factory): - _model = llm_or_chain_factory() + try: + _model = llm_or_chain_factory() # type: ignore[call-arg] + except TypeError: + # It's an arbitrary function, wrap it in a RunnableLambda + user_func = cast(Callable, llm_or_chain_factory) + sig = inspect.signature(user_func) + logger.info(f"Wrapping function {sig} as RunnableLambda.") + wrapped = RunnableLambda(user_func) + return lambda: wrapped + constructor = cast(Callable, llm_or_chain_factory) if isinstance(_model, BaseLanguageModel): + # It's not uncommon to do an LLM constructor instead of raw LLM, + # so we'll unpack it for the user. return _model - return llm_or_chain_factory + elif not isinstance(_model, Runnable): + # This is unlikely to happen - a constructor for a model function + return lambda: RunnableLambda(constructor) + else: + # Typical correct case + return constructor # noqa return llm_or_chain_factory @@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]: def _get_project_name( project_name: Optional[str], - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + llm_or_chain_factory: MCF, ) -> str: """ Get the project name. @@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain( def _validate_example_inputs( examples: Iterator[Example], - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + llm_or_chain_factory: MCF, input_mapper: Optional[Callable[[Dict], Any]], ) -> Iterator[Example]: """Validate that the example inputs are valid for the model.""" @@ -324,7 +354,11 @@ def _validate_example_inputs( _validate_example_inputs_for_language_model(first_example, input_mapper) else: chain = llm_or_chain_factory() - _validate_example_inputs_for_chain(first_example, chain, input_mapper) + if isinstance(chain, Chain): + # Otherwise it's a runnable + _validate_example_inputs_for_chain(first_example, chain, input_mapper) + elif isinstance(chain, Runnable): + logger.debug(f"Skipping input validation for {chain}") return examples @@ -332,7 +366,7 @@ def _validate_example_inputs( def _setup_evaluation( - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + llm_or_chain_factory: MCF, examples: Iterator[Example], evaluation: Optional[RunEvalConfig], data_type: DataType, @@ -353,8 +387,8 @@ def _setup_evaluation( "Please specify a dataset with the default 'kv' data type." ) chain = llm_or_chain_factory() - run_inputs = chain.input_keys - run_outputs = chain.output_keys + run_inputs = chain.input_keys if isinstance(chain, Chain) else None + run_outputs = chain.output_keys if isinstance(chain, Chain) else None run_evaluators = _load_run_evaluators( evaluation, run_type, @@ -372,17 +406,15 @@ def _setup_evaluation( def _determine_input_key( config: RunEvalConfig, run_inputs: Optional[List[str]], - run_type: str, ) -> Optional[str]: + input_key = None if config.input_key: input_key = config.input_key if run_inputs and input_key not in run_inputs: raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}") - elif run_type == "llm": - input_key = None elif run_inputs and len(run_inputs) == 1: input_key = run_inputs[0] - else: + elif run_inputs is not None and len(run_inputs) > 1: raise ValueError( f"Must specify input key for model with multiple inputs: {run_inputs}" ) @@ -393,19 +425,17 @@ def _determine_input_key( def _determine_prediction_key( config: RunEvalConfig, run_outputs: Optional[List[str]], - run_type: str, ) -> Optional[str]: + prediction_key = None if config.prediction_key: prediction_key = config.prediction_key if run_outputs and prediction_key not in run_outputs: raise ValueError( f"Prediction key {prediction_key} not in run outputs {run_outputs}" ) - elif run_type == "llm": - prediction_key = None elif run_outputs and len(run_outputs) == 1: prediction_key = run_outputs[0] - else: + elif run_outputs is not None and len(run_outputs) > 1: raise ValueError( f"Must specify prediction key for model" f" with multiple outputs: {run_outputs}" @@ -491,8 +521,8 @@ def _load_run_evaluators( """ eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0) run_evaluators = [] - input_key = _determine_input_key(config, run_inputs, run_type) - prediction_key = _determine_prediction_key(config, run_outputs, run_type) + input_key = _determine_input_key(config, run_inputs) + prediction_key = _determine_prediction_key(config, run_outputs) reference_key = _determine_reference_key(config, example_outputs) for eval_config in config.evaluators: run_evaluator = _construct_run_evaluator( @@ -590,7 +620,7 @@ async def _arun_llm( async def _arun_chain( - chain: Chain, + chain: Union[Chain, Runnable], inputs: Dict[str, Any], callbacks: Callbacks, *, @@ -598,20 +628,22 @@ async def _arun_chain( input_mapper: Optional[Callable[[Dict], Any]] = None, ) -> Union[dict, str]: """Run a chain asynchronously on inputs.""" - if input_mapper is not None: - inputs_ = input_mapper(inputs) - output: Union[dict, str] = await chain.acall( - inputs_, callbacks=callbacks, tags=tags - ) + inputs_ = inputs if input_mapper is None else input_mapper(inputs) + if isinstance(chain, Chain): + if isinstance(inputs_, dict) and len(inputs_) == 1: + val = next(iter(inputs_.values())) + output = await chain.acall(val, callbacks=callbacks, tags=tags) + else: + output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) else: - inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs - output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) + runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + output = await chain.ainvoke(inputs_, config=runnable_config) return output async def _arun_llm_or_chain( example: Example, - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + llm_or_chain_factory: MCF, n_repetitions: int, *, tags: Optional[List[str]] = None, @@ -810,12 +842,12 @@ async def _arun_on_examples( Returns: A dictionary mapping example ids to the model outputs. """ - llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory) - project_name = _get_project_name(project_name, llm_or_chain_factory) + wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory) + project_name = _get_project_name(project_name, wrapped_model) run_evaluators, examples = _setup_evaluation( - llm_or_chain_factory, examples, evaluation, data_type + wrapped_model, examples, evaluation, data_type ) - examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper) + examples = _validate_example_inputs(examples, wrapped_model, input_mapper) results: Dict[str, List[Any]] = {} async def process_example( @@ -824,7 +856,7 @@ async def _arun_on_examples( """Process a single example.""" result = await _arun_llm_or_chain( example, - llm_or_chain_factory, + wrapped_model, num_repetitions, tags=tags, callbacks=callbacks, @@ -911,7 +943,7 @@ def _run_llm( def _run_chain( - chain: Chain, + chain: Union[Chain, Runnable], inputs: Dict[str, Any], callbacks: Callbacks, *, @@ -919,18 +951,22 @@ def _run_chain( input_mapper: Optional[Callable[[Dict], Any]] = None, ) -> Union[Dict, str]: """Run a chain on inputs.""" - if input_mapper is not None: - inputs_ = input_mapper(inputs) - output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags) + inputs_ = inputs if input_mapper is None else input_mapper(inputs) + if isinstance(chain, Chain): + if isinstance(inputs_, dict) and len(inputs_) == 1: + val = next(iter(inputs_.values())) + output = chain(val, callbacks=callbacks, tags=tags) + else: + output = chain(inputs_, callbacks=callbacks, tags=tags) else: - inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs - output = chain(inputs_, callbacks=callbacks, tags=tags) + runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + output = chain.invoke(inputs_, config=runnable_config) return output def _run_llm_or_chain( example: Example, - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + llm_or_chain_factory: MCF, n_repetitions: int, *, tags: Optional[List[str]] = None, @@ -986,7 +1022,8 @@ def _run_llm_or_chain( outputs.append(output) except Exception as e: logger.warning( - f"{chain_or_llm} failed for example {example.id}. Error: {e}" + f"{chain_or_llm} failed for example {example.id} with inputs:" + f" {example.inputs}.\nError: {e}", ) outputs.append({"Error": str(e)}) if callbacks and previous_example_ids: @@ -1080,7 +1117,7 @@ def _prepare_eval_run( dataset_name: str, llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, project_name: Optional[str], -) -> Tuple[MODEL_OR_CHAIN_FACTORY, str, Dataset, Iterator[Example]]: +) -> Tuple[MCF, str, Dataset, Iterator[Example]]: llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name) project_name = _get_project_name(project_name, llm_or_chain_factory) try: diff --git a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py index acda8885e4..9696515a9c 100644 --- a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py +++ b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py @@ -10,9 +10,11 @@ from langchain.chains.llm import LLMChain from langchain.chat_models import ChatOpenAI from langchain.evaluation import EvaluatorType from langchain.llms.openai import OpenAI +from langchain.prompts.chat import ChatPromptTemplate from langchain.schema.messages import BaseMessage, HumanMessage from langchain.smith import RunEvalConfig, run_on_dataset from langchain.smith.evaluation import InputFormatError +from langchain.smith.evaluation.runner_utils import arun_on_dataset def _check_all_feedback_passed(_project_name: str, client: Client) -> None: @@ -427,3 +429,47 @@ def test_chain_on_kv_singleio_dataset( tags=["shouldpass"], ) _check_all_feedback_passed(eval_project_name, client) + + +@pytest.mark.asyncio +async def test_runnable_on_kv_singleio_dataset( + kv_singleio_dataset_name: str, eval_project_name: str, client: Client +) -> None: + runnable = ( + ChatPromptTemplate.from_messages([("human", "{the wackiest input}")]) + | ChatOpenAI() + ) + eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA]) + await arun_on_dataset( + client, + kv_singleio_dataset_name, + runnable, + evaluation=eval_config, + project_name=eval_project_name, + tags=["shouldpass"], + ) + _check_all_feedback_passed(eval_project_name, client) + + +@pytest.mark.asyncio +async def test_arb_func_on_kv_singleio_dataset( + kv_singleio_dataset_name: str, eval_project_name: str, client: Client +) -> None: + runnable = ( + ChatPromptTemplate.from_messages([("human", "{the wackiest input}")]) + | ChatOpenAI() + ) + + def my_func(x: dict) -> str: + return runnable.invoke(x).content + + eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA]) + await arun_on_dataset( + client, + kv_singleio_dataset_name, + my_func, + evaluation=eval_config, + project_name=eval_project_name, + tags=["shouldpass"], + ) + _check_all_feedback_passed(eval_project_name, client)