From 429f4dbe4d5d38c215ec38f765c8398558183918 Mon Sep 17 00:00:00 2001
From: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Date: Thu, 29 Jun 2023 16:53:49 -0700
Subject: [PATCH] Add Input Mapper in run_on_dataset (#6894)

If you create a dataset from runs and later run the same chain or LLM on it,
the example inputs usually map cleanly. But if you have an agent dataset and
want to run a different agent on it, or your examples use a more complex
schema, it is hard to map those values automatically every time. This PR lets
you pass in an input_mapper function that converts the example inputs to
whatever format your model expects. (A brief usage sketch follows the patch.)
---
 langchain/client/runner_utils.py             | 76 +++++++++++++++++---
 tests/unit_tests/client/test_runner_utils.py | 55 ++++++++++++++
 2 files changed, 120 insertions(+), 11 deletions(-)

diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py
index 7141a07ee1..10f380b445 100644
--- a/langchain/client/runner_utils.py
+++ b/langchain/client/runner_utils.py
@@ -139,6 +139,7 @@ async def _arun_llm(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Asynchronously run the language model.
@@ -148,6 +149,7 @@
         inputs: The input dictionary.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map inputs to the expected format.

     Returns:
         The LLMResult or ChatResult.
@@ -155,7 +157,13 @@
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = await llm.agenerate(
+            input_mapper(inputs), callbacks=callbacks, tags=tags
+        )
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = await llm.agenerate(
@@ -191,6 +199,7 @@ async def _arun_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Asynchronously run the Chain or language model.
@@ -201,6 +210,7 @@ async def _arun_llm_or_chain(
         n_repetitions: The number of times to run the model on each example.
         tags: Optional tags to add to the run.
         callbacks: Optional callbacks to use during the run.
+        input_mapper: Optional function to map the input to the expected format.

     Returns:
         A list of outputs.
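A note on the LLM branch above: when an input_mapper is supplied, its return
value is handed directly to llm.agenerate() (and to llm.generate() in the sync
path further down), bypassing the _get_prompts/_get_messages parsing. The
mapper therefore has to produce what the model's generate method expects: a
list of prompt strings for a BaseLLM, or a list of message lists for a chat
model. A minimal sketch, where "question" is a hypothetical dataset key used
only for illustration:

    from typing import Any, Dict, List

    def llm_input_mapper(inputs: Dict[str, Any]) -> List[str]:
        # The returned value goes straight into llm.generate()/llm.agenerate(),
        # so for a BaseLLM it must be a list of prompt strings. "question" is
        # a hypothetical key; adapt it to your dataset's schema.
        return [f"Answer concisely: {inputs['question']}"]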
@@ -223,12 +233,16 @@ async def _arun_llm_or_chain(
                     example.inputs,
                     tags=tags,
                     callbacks=callbacks,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
@@ -333,6 +347,7 @@ async def arun_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the chain on examples and store traces
@@ -354,6 +369,11 @@ async def arun_on_examples(
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is
+            useful if your model needs to deserialize a more complex schema
+            or if your dataset has inputs with keys that differ from what
+            your chain or agent expects.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -377,6 +397,7 @@
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
+            input_mapper=input_mapper,
         )
         results[str(example.id)] = result
         job_state["num_processed"] += 1
@@ -407,6 +428,7 @@ def run_llm(
     callbacks: Callbacks,
     *,
     tags: Optional[List[str]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """
     Run the language model on the example.
@@ -416,14 +438,18 @@
         inputs: The input dictionary.
         callbacks: The callbacks to use during the run.
         tags: Optional tags to add to the run.
-
+        input_mapper: A function to map inputs to the expected model format.
     Returns:
         The LLMResult or ChatResult.
     Raises:
         ValueError: If the LLM type is unsupported.
         InputFormatError: If the input format is invalid.
     """
-    if isinstance(llm, BaseLLM):
+    if input_mapper is not None:
+        if not isinstance(llm, (BaseLLM, BaseChatModel)):
+            raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
+        llm_output = llm.generate(input_mapper(inputs), callbacks=callbacks, tags=tags)
+    elif isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
             llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
@@ -455,6 +481,7 @@ def run_llm_or_chain(
     *,
     tags: Optional[List[str]] = None,
     callbacks: Optional[List[BaseCallbackHandler]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """
     Run the Chain or language model synchronously.
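On the chain branch (the @@ -223 hunk above, mirrored in the sync version
below): when input_mapper is provided, the mapped dict is passed to
chain.acall()/chain() unchanged, and the single-key unwrapping
(inputs_ = next(iter(inputs_.values()))) is skipped. The mapper must therefore
return exactly the keys the chain declares in its input_keys. A minimal
sketch, assuming hypothetical key names "raw" (in the dataset) and "text"
(expected by the chain):

    def chain_input_mapper(example_inputs: dict) -> dict:
        # The output dict is passed to the chain as-is, so it must contain
        # every key the chain expects. "raw" and "text" are hypothetical
        # names used only for illustration.
        return {"text": example_inputs["raw"]}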
@@ -483,13 +510,20 @@ def run_llm_or_chain(
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
                 output: Any = run_llm(
-                    llm_or_chain_factory, example.inputs, callbacks, tags=tags
+                    llm_or_chain_factory,
+                    example.inputs,
+                    callbacks,
+                    tags=tags,
+                    input_mapper=input_mapper,
                 )
             else:
                 chain = llm_or_chain_factory()
-                inputs_ = example.inputs
-                if len(inputs_) == 1:
-                    inputs_ = next(iter(inputs_.values()))
+                if input_mapper is not None:
+                    inputs_ = input_mapper(example.inputs)
+                else:
+                    inputs_ = example.inputs
+                    if len(inputs_) == 1:
+                        inputs_ = next(iter(inputs_.values()))
                 output = chain(inputs_, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
@@ -512,6 +546,7 @@ def run_on_examples(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on examples and store
@@ -532,6 +567,11 @@ def run_on_examples(
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is
+            useful if your model needs to deserialize a more complex schema
+            or if your dataset has inputs with keys that differ from what
+            your chain or agent expects.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -552,6 +592,7 @@
             num_repetitions,
             tags=tags,
             callbacks=callbacks,
+            input_mapper=input_mapper,
         )
         if verbose:
             print(f"{i+1} processed", flush=True, end="\r")
@@ -599,6 +640,7 @@ async def arun_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Asynchronously run the Chain or language model on a dataset
@@ -620,7 +662,11 @@
             client will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
-
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated. This is
+            useful if your model needs to deserialize a more complex schema
+            or if your dataset has inputs with keys that differ from what
+            your chain or agent expects.
     Returns:
         A dictionary containing the run's project name and the resulting model
         outputs.
@@ -638,6 +684,7 @@
         client=client_,
         tags=tags,
         run_evaluators=run_evaluators,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
@@ -655,6 +702,7 @@ def run_on_dataset(
     client: Optional[LangChainPlusClient] = None,
     tags: Optional[List[str]] = None,
     run_evaluators: Optional[Sequence[RunEvaluator]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Dict[str, Any]:
     """
     Run the Chain or language model on a dataset and store traces
@@ -676,6 +724,11 @@
             will be created using the credentials in the environment.
         tags: Tags to add to each run in the project.
         run_evaluators: Evaluators to run on the results of the chain.
+        input_mapper: A function to map the inputs dictionary from an Example
+            to the format expected by the model to be evaluated.
+            This is useful if your model needs to deserialize a more complex
+            schema or if your dataset has inputs with keys that differ from
+            what your chain or agent expects.

     Returns:
         A dictionary containing the run's project name and the resulting model
         outputs.
@@ -693,6 +746,7 @@ def run_on_dataset(
         tags=tags,
         run_evaluators=run_evaluators,
         client=client_,
+        input_mapper=input_mapper,
     )
     return {
         "project_name": project_name,
diff --git a/tests/unit_tests/client/test_runner_utils.py b/tests/unit_tests/client/test_runner_utils.py
index bada74b79c..3c767cdfcc 100644
--- a/tests/unit_tests/client/test_runner_utils.py
+++ b/tests/unit_tests/client/test_runner_utils.py
@@ -10,13 +10,16 @@
 from langchainplus_sdk.schemas import Dataset, Example

 from langchain.base_language import BaseLanguageModel
 from langchain.chains.base import Chain
+from langchain.chains.transform import TransformChain
 from langchain.client.runner_utils import (
     InputFormatError,
     _get_messages,
     _get_prompts,
     arun_on_dataset,
     run_llm,
+    run_llm_or_chain,
 )
+from langchain.schema import LLMResult
 from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -75,6 +78,57 @@ def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
         _get_prompts(inputs)


+def test_run_llm_or_chain_with_input_mapper() -> None:
+    example = Example(
+        id=uuid.uuid4(),
+        created_at=_CREATED_AT,
+        inputs={"the wrong input": "1", "another key": "2"},
+        outputs={"output": "2"},
+        dataset_id=str(uuid.uuid4()),
+    )
+
+    def run_val(inputs: dict) -> dict:
+        assert "the right input" in inputs
+        return {"output": "2"}
+
+    mock_chain = TransformChain(
+        input_variables=["the right input"],
+        output_variables=["output"],
+        transform=run_val,
+    )
+
+    def input_mapper(inputs: dict) -> dict:
+        assert "the wrong input" in inputs
+        return {"the right input": inputs["the wrong input"]}
+
+    result = run_llm_or_chain(
+        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
+    )
+    assert len(result) == 1
+    assert result[0] == {"output": "2", "the right input": "1"}
+    bad_result = run_llm_or_chain(
+        example,
+        lambda: mock_chain,
+        n_repetitions=1,
+    )
+    assert len(bad_result) == 1
+    assert "Error" in bad_result[0]
+
+    # Try with an LLM
+    def llm_input_mapper(inputs: dict) -> List[str]:
+        assert "the wrong input" in inputs
+        return ["the right input"]
+
+    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
+    result = run_llm_or_chain(
+        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
+    )
+    assert len(result) == 1
+    llm_result = result[0]
+    assert isinstance(llm_result, LLMResult)
+    assert llm_result.generations[0][0].text == "somenumber"
+
+
 @pytest.mark.parametrize(
     "inputs",
     [
@@ -171,6 +225,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         n_repetitions: int,
         tags: Optional[List[str]] = None,
         callbacks: Optional[Any] = None,
+        **kwargs: Any,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
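To tie the pieces together, a usage sketch of the new parameter at the public
entry point. It mirrors the unit test above: the dataset is assumed to store
inputs under "the wrong input" while the chain expects "the right input". The
dataset name is hypothetical, and running this requires a LangChainPlus
dataset plus credentials in the environment, so treat it as a sketch rather
than a runnable script:

    from langchain.chains.transform import TransformChain
    from langchain.client.runner_utils import run_on_dataset

    def to_chain_inputs(example_inputs: dict) -> dict:
        # Hypothetical mapping from the dataset's stored key to the key the
        # chain expects (names mirror the unit test above).
        return {"the right input": example_inputs["the wrong input"]}

    def chain_factory() -> TransformChain:
        # A fresh chain per example, matching the factory pattern used by
        # run_llm_or_chain in the test above.
        return TransformChain(
            input_variables=["the right input"],
            output_variables=["output"],
            transform=lambda inputs: {"output": inputs["the right input"]},
        )

    # "my-dataset" is a hypothetical dataset name; the returned dict contains
    # the run's project name and the model outputs, per the docstring above.
    results = run_on_dataset(
        "my-dataset",
        chain_factory,
        input_mapper=to_chain_inputs,
    )
    print(results["project_name"])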