Add Input Mapper in run_on_dataset (#6894)

If you create a dataset from runs and later run the same chain or LLM on it, it usually works great.

If you have an agent dataset and want to run a different agent on it, or your dataset has a more complex schema, it's hard for us to map these values automatically every time. This PR lets you pass in an input_mapper function that converts the example inputs to whatever format your model expects.
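
As a rough illustration (not part of this PR), an input_mapper for the chain case can simply rename dataset keys before they reach the chain. The dataset name, chain factory, and key names below are hypothetical, and the positional argument order is assumed from the signature in this diff:

from langchain.client.runner_utils import run_on_dataset

def rename_keys(example_inputs: dict) -> dict:
    # Hypothetical: examples store the question under "query", but the chain
    # expects an input named "question".
    return {"question": example_inputs["query"]}

results = run_on_dataset(
    "my-agent-dataset",        # hypothetical dataset name
    lambda: build_my_chain(),  # hypothetical factory returning a fresh chain per example
    input_mapper=rename_keys,
)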
Zander Chase committed 1 year ago (via GitHub)
parent 76d03f398d
commit 429f4dbe4d

@@ -139,6 +139,7 @@ async def _arun_llm(
*,
tags: Optional[List[str]] = None,
callbacks: Callbacks = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[LLMResult, ChatResult]:
"""
Asynchronously run the language model.
@@ -148,6 +149,7 @@ async def _arun_llm(
inputs: The input dictionary.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
input_mapper: Optional function to map inputs to the expected format.
Returns:
The LLMResult or ChatResult.
@@ -155,7 +157,13 @@ async def _arun_llm(
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM):
if input_mapper is not None:
if not isinstance(llm, (BaseLLM, BaseChatModel)):
raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
llm_output = await llm.agenerate(
input_mapper(inputs), callbacks=callbacks, tags=tags
)
elif isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
llm_output = await llm.agenerate(
@@ -191,6 +199,7 @@ async def _arun_llm_or_chain(
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""
Asynchronously run the Chain or language model.
@@ -201,6 +210,7 @@ async def _arun_llm_or_chain(
n_repetitions: The number of times to run the model on each example.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
input_mapper: Optional function to map the input to the expected format.
Returns:
A list of outputs.
@@ -223,12 +233,16 @@ async def _arun_llm_or_chain(
example.inputs,
tags=tags,
callbacks=callbacks,
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
inputs_ = example.inputs
if len(inputs_) == 1:
inputs_ = next(iter(inputs_.values()))
if input_mapper is not None:
inputs_ = input_mapper(example.inputs)
else:
inputs_ = example.inputs
if len(inputs_) == 1:
inputs_ = next(iter(inputs_.values()))
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
outputs.append(output)
except Exception as e:
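
In the chain branch of _arun_llm_or_chain above, a provided input_mapper bypasses the single-key unwrapping fallback, and its output is passed to chain.acall unchanged. A hedged sketch of a mapper that flattens a richer example schema into the single string a hypothetical agent expects (the "input" key and the dataset field names are assumptions):

def agent_input_mapper(example_inputs: dict) -> dict:
    # Hypothetical: combine two dataset fields into the agent's single input string.
    question = example_inputs["question"]
    context = example_inputs["context"]
    return {"input": f"{question}\n\nContext: {context}"}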
@@ -333,6 +347,7 @@ async def arun_on_examples(
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Dict[str, Any]:
"""
Asynchronously run the chain on examples and store traces
@@ -354,6 +369,11 @@ async def arun_on_examples(
client will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
input_mapper: A function to map the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize a more complex schema or if your dataset
has inputs with keys that differ from what your chain or agent expects.
Returns:
A dictionary mapping example ids to the model outputs.
@@ -377,6 +397,7 @@ async def arun_on_examples(
num_repetitions,
tags=tags,
callbacks=callbacks,
input_mapper=input_mapper,
)
results[str(example.id)] = result
job_state["num_processed"] += 1
@@ -407,6 +428,7 @@ def run_llm(
callbacks: Callbacks,
*,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[LLMResult, ChatResult]:
"""
Run the language model on the example.
@@ -416,14 +438,18 @@ def run_llm(
inputs: The input dictionary.
callbacks: The callbacks to use during the run.
tags: Optional tags to add to the run.
input_mapper: Optional function to map the inputs dictionary from an Example to the format expected by the model.
Returns:
The LLMResult or ChatResult.
Raises:
ValueError: If the LLM type is unsupported.
InputFormatError: If the input format is invalid.
"""
if isinstance(llm, BaseLLM):
if input_mapper is not None:
if not isinstance(llm, (BaseLLM, BaseChatModel)):
raise ValueError(f"Unsupported LLM type {type(llm).__name__}")
llm_output = llm.generate(input_mapper(inputs), callbacks=callbacks, tags=tags)
elif isinstance(llm, BaseLLM):
try:
llm_prompts = _get_prompts(inputs)
llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
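
In run_llm (and its async counterpart _arun_llm above), when an input_mapper is supplied its return value is passed straight to generate/agenerate. For a BaseLLM that call expects a list of prompt strings; for a BaseChatModel it expects a list of message lists. A minimal sketch of such a mapper, assuming a hypothetical dataset key "text":

def llm_input_mapper(example_inputs: dict) -> list:
    # BaseLLM.generate / agenerate take a list of prompt strings.
    return [example_inputs["text"]]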
@@ -455,6 +481,7 @@ def run_llm_or_chain(
*,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""
Run the Chain or language model synchronously.
@@ -483,13 +510,20 @@ def run_llm_or_chain(
try:
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = run_llm(
llm_or_chain_factory, example.inputs, callbacks, tags=tags
llm_or_chain_factory,
example.inputs,
callbacks,
tags=tags,
input_mapper=input_mapper,
)
else:
chain = llm_or_chain_factory()
inputs_ = example.inputs
if len(inputs_) == 1:
inputs_ = next(iter(inputs_.values()))
if input_mapper is not None:
inputs_ = input_mapper(example.inputs)
else:
inputs_ = example.inputs
if len(inputs_) == 1:
inputs_ = next(iter(inputs_.values()))
output = chain(inputs_, callbacks=callbacks, tags=tags)
outputs.append(output)
except Exception as e:
@@ -512,6 +546,7 @@ def run_on_examples(
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Dict[str, Any]:
"""
Run the Chain or language model on examples and store
@@ -532,6 +567,11 @@ def run_on_examples(
will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
input_mapper: A function to map the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize a more complex schema or if your dataset
has inputs with keys that differ from what your chain or agent expects.
Returns:
A dictionary mapping example ids to the model outputs.
@@ -552,6 +592,7 @@ def run_on_examples(
num_repetitions,
tags=tags,
callbacks=callbacks,
input_mapper=input_mapper,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
@@ -599,6 +640,7 @@ async def arun_on_dataset(
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Dict[str, Any]:
"""
Asynchronously run the Chain or language model on a dataset
@@ -620,7 +662,11 @@ async def arun_on_dataset(
client will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
input_mapper: A function to map the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize a more complex schema or if your dataset
has inputs with keys that differ from what your chain or agent expects.
Returns:
A dictionary containing the run's project name and the resulting model outputs.
"""
@@ -638,6 +684,7 @@ async def arun_on_dataset(
client=client_,
tags=tags,
run_evaluators=run_evaluators,
input_mapper=input_mapper,
)
return {
"project_name": project_name,
@@ -655,6 +702,7 @@ def run_on_dataset(
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
run_evaluators: Optional[Sequence[RunEvaluator]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Dict[str, Any]:
"""
Run the Chain or language model on a dataset and store traces
@@ -676,6 +724,11 @@ def run_on_dataset(
will be created using the credentials in the environment.
tags: Tags to add to each run in the project.
run_evaluators: Evaluators to run on the results of the chain.
input_mapper: A function to map the inputs dictionary from an Example
to the format expected by the model to be evaluated. This is useful if
your model needs to deserialize a more complex schema or if your dataset
has inputs with keys that differ from what your chain or agent expects.
Returns:
A dictionary containing the run's project name and the resulting model outputs.
@@ -693,6 +746,7 @@ def run_on_dataset(
tags=tags,
run_evaluators=run_evaluators,
client=client_,
input_mapper=input_mapper,
)
return {
"project_name": project_name,

@@ -10,13 +10,16 @@ from langchainplus_sdk.schemas import Dataset, Example
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.transform import TransformChain
from langchain.client.runner_utils import (
InputFormatError,
_get_messages,
_get_prompts,
arun_on_dataset,
run_llm,
run_llm_or_chain,
)
from langchain.schema import LLMResult
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -75,6 +78,57 @@ def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
_get_prompts(inputs)
def test_run_llm_or_chain_with_input_mapper() -> None:
example = Example(
id=uuid.uuid4(),
created_at=_CREATED_AT,
inputs={"the wrong input": "1", "another key": "2"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
)
def run_val(inputs: dict) -> dict:
assert "the right input" in inputs
return {"output": "2"}
mock_chain = TransformChain(
input_variables=["the right input"],
output_variables=["output"],
transform=run_val,
)
def input_mapper(inputs: dict) -> dict:
assert "the wrong input" in inputs
return {"the right input": inputs["the wrong input"]}
result = run_llm_or_chain(
example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
)
assert len(result) == 1
assert result[0] == {"output": "2", "the right input": "1"}
bad_result = run_llm_or_chain(
example,
lambda: mock_chain,
n_repetitions=1,
)
assert len(bad_result) == 1
assert "Error" in bad_result[0]
# Try with LLM
def llm_input_mapper(inputs: dict) -> List[str]:
assert "the wrong input" in inputs
return ["the right input"]
mock_llm = FakeLLM(queries={"the right input": "somenumber"})
result = run_llm_or_chain(
example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
)
assert len(result) == 1
llm_result = result[0]
assert isinstance(llm_result, LLMResult)
assert llm_result.generations[0][0].text == "somenumber"
@pytest.mark.parametrize(
"inputs",
[
@@ -171,6 +225,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
n_repetitions: int,
tags: Optional[List[str]] = None,
callbacks: Optional[Any] = None,
**kwargs: Any,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
