diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index 2464f7c741..585c3f5028 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import inspect
 import itertools
 import logging
 import uuid
@@ -19,6 +20,7 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 from urllib.parse import urlparse, urlunparse
 
@@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
 from langchain.schema import ChatResult, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage, messages_from_dict
+from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
 from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
 from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
 
 logger = logging.getLogger(__name__)
 
-MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
+MODEL_OR_CHAIN_FACTORY = Union[
+    Callable[[], Union[Chain, Runnable]],
+    BaseLanguageModel,
+    Callable[[dict], Any],
+    Runnable,
+    Chain,
+]
+MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel]
 
 
 class InputFormatError(Exception):
@@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str:
 
 
 def _wrap_in_chain_factory(
-    llm_or_chain_factory: Union[Chain, MODEL_OR_CHAIN_FACTORY],
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     dataset_name: str = "",
-) -> MODEL_OR_CHAIN_FACTORY:
+) -> MCF:
     """Forgive the user if they pass in a chain without memory instead of a chain
     factory. It's a common mistake. Raise a more helpful error message as well."""
     if isinstance(llm_or_chain_factory, Chain):
@@ -105,11 +115,31 @@
         return lambda: chain
     elif isinstance(llm_or_chain_factory, BaseLanguageModel):
         return llm_or_chain_factory
+    elif isinstance(llm_or_chain_factory, Runnable):
+        # Memory may exist here, but it's not elegant to check all those cases.
+        lcf = llm_or_chain_factory
+        return lambda: lcf
     elif callable(llm_or_chain_factory):
-        _model = llm_or_chain_factory()
+        try:
+            _model = llm_or_chain_factory()  # type: ignore[call-arg]
+        except TypeError:
+            # It's an arbitrary function, wrap it in a RunnableLambda
+            user_func = cast(Callable, llm_or_chain_factory)
+            sig = inspect.signature(user_func)
+            logger.info(f"Wrapping function {sig} as RunnableLambda.")
+            wrapped = RunnableLambda(user_func)
+            return lambda: wrapped
+        constructor = cast(Callable, llm_or_chain_factory)
         if isinstance(_model, BaseLanguageModel):
+            # It's not uncommon to do an LLM constructor instead of raw LLM,
+            # so we'll unpack it for the user.
             return _model
-        return llm_or_chain_factory
+        elif not isinstance(_model, Runnable):
+            # This is unlikely to happen - a constructor for a model function
+            return lambda: RunnableLambda(constructor)
+        else:
+            # Typical correct case
+            return constructor  # noqa
     return llm_or_chain_factory
 
 
@@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
 
 def _get_project_name(
     project_name: Optional[str],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
 ) -> str:
     """
    Get the project name.
@@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain(
 
 def _validate_example_inputs(
     examples: Iterator[Example],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     input_mapper: Optional[Callable[[Dict], Any]],
 ) -> Iterator[Example]:
     """Validate that the example inputs are valid for the model."""
@@ -324,7 +354,11 @@
         _validate_example_inputs_for_language_model(first_example, input_mapper)
     else:
         chain = llm_or_chain_factory()
-        _validate_example_inputs_for_chain(first_example, chain, input_mapper)
+        if isinstance(chain, Chain):
+            # Otherwise it's a runnable
+            _validate_example_inputs_for_chain(first_example, chain, input_mapper)
+        elif isinstance(chain, Runnable):
+            logger.debug(f"Skipping input validation for {chain}")
     return examples
 
 
@@ -332,7 +366,7 @@
 
 
 def _setup_evaluation(
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     examples: Iterator[Example],
     evaluation: Optional[RunEvalConfig],
     data_type: DataType,
@@ -353,8 +387,8 @@
                     "Please specify a dataset with the default 'kv' data type."
                 )
             chain = llm_or_chain_factory()
-            run_inputs = chain.input_keys
-            run_outputs = chain.output_keys
+            run_inputs = chain.input_keys if isinstance(chain, Chain) else None
+            run_outputs = chain.output_keys if isinstance(chain, Chain) else None
         run_evaluators = _load_run_evaluators(
             evaluation,
             run_type,
@@ -372,17 +406,15 @@
 def _determine_input_key(
     config: RunEvalConfig,
     run_inputs: Optional[List[str]],
-    run_type: str,
 ) -> Optional[str]:
+    input_key = None
     if config.input_key:
         input_key = config.input_key
         if run_inputs and input_key not in run_inputs:
             raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
-    elif run_type == "llm":
-        input_key = None
     elif run_inputs and len(run_inputs) == 1:
         input_key = run_inputs[0]
-    else:
+    elif run_inputs is not None and len(run_inputs) > 1:
         raise ValueError(
             f"Must specify input key for model with multiple inputs: {run_inputs}"
         )
@@ -393,19 +425,17 @@
 def _determine_prediction_key(
     config: RunEvalConfig,
     run_outputs: Optional[List[str]],
-    run_type: str,
 ) -> Optional[str]:
+    prediction_key = None
     if config.prediction_key:
         prediction_key = config.prediction_key
         if run_outputs and prediction_key not in run_outputs:
             raise ValueError(
                 f"Prediction key {prediction_key} not in run outputs {run_outputs}"
             )
-    elif run_type == "llm":
-        prediction_key = None
     elif run_outputs and len(run_outputs) == 1:
         prediction_key = run_outputs[0]
-    else:
+    elif run_outputs is not None and len(run_outputs) > 1:
         raise ValueError(
             f"Must specify prediction key for model"
             f" with multiple outputs: {run_outputs}"
         )
@@ -491,8 +521,8 @@
     """
     eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
     run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs, run_type)
-    prediction_key = _determine_prediction_key(config, run_outputs, run_type)
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
     reference_key = _determine_reference_key(config, example_outputs)
     for eval_config in config.evaluators:
         run_evaluator = _construct_run_evaluator(
@@ -590,7 +620,7 @@ async def _arun_llm(
 
 
 async def _arun_chain(
-    chain: Chain,
+    chain: Union[Chain, Runnable],
     inputs: Dict[str, Any],
     callbacks: Callbacks,
     *,
@@ -598,20 +628,22 @@ async def _arun_chain(
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
-    if input_mapper is not None:
-        inputs_ = input_mapper(inputs)
-        output: Union[dict, str] = await chain.acall(
-            inputs_, callbacks=callbacks, tags=tags
-        )
+    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
+    if isinstance(chain, Chain):
+        if isinstance(inputs_, dict) and len(inputs_) == 1:
+            val = next(iter(inputs_.values()))
+            output = await chain.acall(val, callbacks=callbacks, tags=tags)
+        else:
+            output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
     else:
-        inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
-        output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
+        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        output = await chain.ainvoke(inputs_, config=runnable_config)
     return output
 
 
 async def _arun_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
@@ -810,12 +842,12 @@ async def _arun_on_examples(
     Returns:
         A dictionary mapping example ids to the model outputs.
     """
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
     run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
     )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
     results: Dict[str, List[Any]] = {}
 
     async def process_example(
@@ -824,7 +856,7 @@
         """Process a single example."""
         result = await _arun_llm_or_chain(
             example,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
@@ -911,7 +943,7 @@ def _run_llm(
 
 
 def _run_chain(
-    chain: Chain,
+    chain: Union[Chain, Runnable],
     inputs: Dict[str, Any],
     callbacks: Callbacks,
     *,
@@ -919,18 +951,22 @@ def _run_chain(
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
-    if input_mapper is not None:
-        inputs_ = input_mapper(inputs)
-        output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
+    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
+    if isinstance(chain, Chain):
+        if isinstance(inputs_, dict) and len(inputs_) == 1:
+            val = next(iter(inputs_.values()))
+            output = chain(val, callbacks=callbacks, tags=tags)
+        else:
+            output = chain(inputs_, callbacks=callbacks, tags=tags)
     else:
-        inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
-        output = chain(inputs_, callbacks=callbacks, tags=tags)
+        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        output = chain.invoke(inputs_, config=runnable_config)
     return output
 
 
 def _run_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    llm_or_chain_factory: MCF,
     n_repetitions: int,
     *,
     tags: Optional[List[str]] = None,
@@ -986,7 +1022,8 @@
             outputs.append(output)
         except Exception as e:
             logger.warning(
-                f"{chain_or_llm} failed for example {example.id}. Error: {e}"
+                f"{chain_or_llm} failed for example {example.id} with inputs:"
+                f" {example.inputs}.\nError: {e}",
             )
             outputs.append({"Error": str(e)})
     if callbacks and previous_example_ids:
@@ -1080,7 +1117,7 @@ def _prepare_eval_run(
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
-) -> Tuple[MODEL_OR_CHAIN_FACTORY, str, Dataset, Iterator[Example]]:
+) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
     llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
     project_name = _get_project_name(project_name, llm_or_chain_factory)
     try:
diff --git a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py
index acda8885e4..9696515a9c 100644
--- a/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py
+++ b/libs/langchain/tests/integration_tests/smith/evaluation/test_runner_utils.py
@@ -10,9 +10,11 @@ from langchain.chains.llm import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.evaluation import EvaluatorType
 from langchain.llms.openai import OpenAI
+from langchain.prompts.chat import ChatPromptTemplate
 from langchain.schema.messages import BaseMessage, HumanMessage
 from langchain.smith import RunEvalConfig, run_on_dataset
 from langchain.smith.evaluation import InputFormatError
+from langchain.smith.evaluation.runner_utils import arun_on_dataset
 
 
 def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
@@ -427,3 +429,47 @@ def test_chain_on_kv_singleio_dataset(
         tags=["shouldpass"],
     )
     _check_all_feedback_passed(eval_project_name, client)
+
+
+@pytest.mark.asyncio
+async def test_runnable_on_kv_singleio_dataset(
+    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
+) -> None:
+    runnable = (
+        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
+        | ChatOpenAI()
+    )
+    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
+    await arun_on_dataset(
+        client,
+        kv_singleio_dataset_name,
+        runnable,
+        evaluation=eval_config,
+        project_name=eval_project_name,
+        tags=["shouldpass"],
+    )
+    _check_all_feedback_passed(eval_project_name, client)
+
+
+@pytest.mark.asyncio
+async def test_arb_func_on_kv_singleio_dataset(
+    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
+) -> None:
+    runnable = (
+        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
+        | ChatOpenAI()
+    )
+
+    def my_func(x: dict) -> str:
+        return runnable.invoke(x).content
+
+    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
+    await arun_on_dataset(
+        client,
+        kv_singleio_dataset_name,
+        my_func,
+        evaluation=eval_config,
+        project_name=eval_project_name,
+        tags=["shouldpass"],
+    )
+    _check_all_feedback_passed(eval_project_name, client)