Mirror of https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Re-use Trajectory Evaluator (#7248)
Use the trajectory eval chain in the run evaluation implementation, and update the prepare-inputs method so it applies to both the async and sync code paths.
This commit is contained in:
parent
e8f24164f0
commit
576880abc5
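
The first hunk below removes the chain's sync-only `__call__` override and moves the reference-defaulting into `prep_inputs`, which the base `Chain` runs from both its synchronous and asynchronous entry points. A toy analogue of that pattern (not LangChain's actual `Chain` class; all names here are illustrative):

# Toy analogue of the prep_inputs hook: both the sync and async entry points
# funnel through prep_inputs, so input normalization only has to be written once.
import asyncio
from typing import Any, Dict


class ToyChain:
    def prep_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {"received": inputs}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return self._run(self.prep_inputs(inputs))

    async def acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return self._run(self.prep_inputs(inputs))


class ToyEvalChain(ToyChain):
    def prep_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Defaulting the reference here reaches __call__ and acall alike,
        # unlike the __call__-only override this commit deletes.
        inputs.setdefault("reference", "")
        return super().prep_inputs(inputs)


chain = ToyEvalChain()
print(chain({"question": "2 + 2?"}))                      # reference defaulted
print(asyncio.run(chain.acall({"question": "2 + 2?"})))   # same normalization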
@@ -244,41 +244,11 @@ The following is the expected answer. Use this to measure correctness:
             return ["score", "reasoning"]
         return ["score"]
 
-    def __call__(
-        self,
-        inputs: Union[Dict[str, Any], Any],
-        return_only_outputs: bool = False,
-        callbacks: Callbacks = None,
-        *,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        include_run_info: bool = False,
-    ) -> Dict[str, Any]:
-        """Run the logic of this chain and add to output if desired.
-
-        Args:
-            inputs: Dictionary of inputs, or single input if chain expects
-                only one param.
-            return_only_outputs: boolean for whether to return only outputs in the
-                response. If True, only new keys generated by this chain will be
-                returned. If False, both input keys and new keys generated by this
-                chain will be returned. Defaults to False.
-            callbacks: Callbacks to use for this chain run. If not provided, will
-                use the callbacks provided to the chain.
-            tags: Tags to add to the chain run.
-            metadata: Metadata to add to the chain run.
-            include_run_info: Whether to include run info in the response. Defaults
-                to False.
-        """
+    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+        """Validate and prep inputs."""
         if "reference" not in inputs:
-            inputs["reference"] = ""
-        return super().__call__(
-            inputs=inputs,
-            return_only_outputs=return_only_outputs,
-            callbacks=callbacks,
-            tags=tags,
-            include_run_info=include_run_info,
-        )
+            inputs["reference"] = self._format_reference(inputs.get("reference"))
+        return super().prep_inputs(inputs)
 
     def _call(
         self,
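
Since `prep_inputs` now calls `self._format_reference(inputs.get("reference"))` whenever the key is missing, the chain can be invoked directly without a reference on either code path. A rough usage sketch, not taken from the diff; the model, tool, and input strings are placeholders, and it assumes `TrajectoryEvalChain.from_llm` accepts the same arguments the refactored `get_trajectory_evaluator` passes further down:

# Hedged usage sketch (assumed setup, illustrative values).
from langchain.agents import Tool
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain

search = Tool(name="search", func=lambda q: "Paris", description="Look things up.")
eval_chain = TrajectoryEvalChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    agent_tools=[search],
    return_reasoning=True,
)

inputs = {
    "question": "What is the capital of France?",
    "agent_trajectory": "Action 1: search('capital of France') -> Paris",
    "answer": "Paris",
    # no "reference" key: prep_inputs fills it in for sync and async calls alike
}
result = eval_chain(inputs)
# result = await eval_chain.acall(inputs)  # the async path shares prep_inputs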
@@ -298,7 +268,10 @@ The following is the expected answer. Use this to measure correctness:
         chain_input = {**inputs}
         if self.agent_tools:
             chain_input["tool_descriptions"] = self._tools_description
-        raw_output = self.eval_chain.run(chain_input)
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        raw_output = self.eval_chain.run(
+            chain_input, callbacks=_run_manager.get_child()
+        )
         parsed_output = self.output_parser.parse(raw_output)
 
         if self.return_reasoning:
@@ -324,7 +297,10 @@ The following is the expected answer. Use this to measure correctness:
         chain_input = {**inputs}
         if self.agent_tools:
             chain_input["tool_descriptions"] = self._tools_description
-        raw_output = await self.eval_chain.arun(chain_input)
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        raw_output = await self.eval_chain.arun(
+            chain_input, callbacks=_run_manager.get_child()
+        )
         parsed_output = self.output_parser.parse(raw_output)
 
         if self.return_reasoning:
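
The two hunks above also thread the run manager's child callbacks into the nested `eval_chain` call, so handlers and tracing attached to the outer evaluation run now see the inner grading run as well. A hedged illustration, reusing the `eval_chain` built in the previous sketch; `StdOutCallbackHandler` is a real handler, the rest is illustrative:

# With callbacks=_run_manager.get_child() inside _call/_acall, handlers passed
# to the outer call propagate to the nested eval_chain run instead of being dropped.
from langchain.callbacks import StdOutCallbackHandler

result = eval_chain(
    {
        "question": "What is the capital of France?",
        "agent_trajectory": "Action 1: search('capital of France') -> Paris",
        "answer": "Paris",
    },
    callbacks=[StdOutCallbackHandler()],
)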
@@ -358,7 +334,7 @@ The following is the expected answer. Use this to measure correctness:
             "question": input,
             "agent_trajectory": self.get_agent_trajectory(agent_trajectory),
             "answer": prediction,
-            "reference": self._format_reference(reference),
+            "reference": reference,
         }
         return self(inputs=inputs, callbacks=callbacks, **kwargs)
 
@@ -388,7 +364,7 @@ The following is the expected answer. Use this to measure correctness:
             "question": input,
             "agent_trajectory": self.get_agent_trajectory(agent_trajectory),
             "answer": prediction,
-            "reference": self._format_reference(reference),
+            "reference": reference,
         }
         return await self.acall(
             inputs=inputs,
@@ -1,14 +1,14 @@
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, Dict, Mapping, Optional, Sequence, Union
 
 from langchainplus_sdk.evaluation import EvaluationResult
 from langchainplus_sdk.schemas import Example, Run, RunTypeEnum
 from pydantic import BaseModel, Field
 
 from langchain.base_language import BaseLanguageModel
-from langchain.chains.llm import LLMChain
 from langchain.chat_models.base import BaseChatModel
-from langchain.evaluation.agents.trajectory_eval_prompt import (
-    EVAL_CHAT_PROMPT as TRAJECTORY_PROMPT,
+from langchain.evaluation.agents.trajectory_eval_chain import (
+    TrajectoryEvalChain,
+    TrajectoryOutputParser,
 )
 from langchain.evaluation.criteria.eval_chain import (
     CriteriaEvalChain,
@@ -24,7 +24,7 @@ from langchain.evaluation.run_evaluators.base import (
     RunEvaluatorOutputParser,
 )
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate, OutputParserException
+from langchain.schema import BasePromptTemplate
 from langchain.tools.base import BaseTool
 
 _QA_PROMPTS = {
@@ -185,7 +185,7 @@ def get_criteria_evaluator(
     )
 
 
-class TrajectoryEvalOutputParser(RunEvaluatorOutputParser):
+class TrajectoryRunEvalOutputParser(RunEvaluatorOutputParser, TrajectoryOutputParser):
     evaluation_name: str = "Agent Trajectory"
     """The name assigned to the evaluation feedback."""
     evaluator_info: dict = Field(default_factory=dict)
@@ -195,29 +195,12 @@ class TrajectoryEvalOutputParser(RunEvaluatorOutputParser):
     def _type(self) -> str:
         return "agent_trajectory_run_eval"
 
-    def parse(self, text: str) -> EvaluationResult:
-        if "Score:" not in text:
-            raise OutputParserException(
-                f"Could not find score in model eval output: {text}"
-            )
-
-        reasoning, score_str = text.split("Score: ")
-
-        reasoning, score_str = reasoning.strip(), score_str.strip()
-
-        score_str = next(
-            (char for char in score_str if char.isdigit()), "0"
-        )  # Scan for first digit
-
-        if not 1 <= int(score_str) <= 5:
-            raise OutputParserException(
-                f"Score is not a digit in the range 1-5: {text}"
-            )
-
+    def parse_chain_output(self, output: Dict[str, Any]) -> EvaluationResult:
+        """Parse the output of a run."""
         return EvaluationResult(
             key=self.evaluation_name,
-            score=int(score_str),
-            comment=reasoning,
+            score=int(output["score"]),
+            comment=output["reasoning"],
             evaluator_info=self.evaluator_info,
         )
 
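
With the eval chain now doing the raw-text parsing, the run-level parser above only repackages the chain's already-parsed output as tracing feedback. A small hedged illustration; it assumes the class lives in the run-evaluator implementations module (the diff does not show the file path) and uses a made-up chain output:

# Assumed import path; the parser converts the chain's {"score", "reasoning"}
# dict into an EvaluationResult keyed by evaluation_name.
from langchain.evaluation.run_evaluators.implementations import (
    TrajectoryRunEvalOutputParser,
)

parser = TrajectoryRunEvalOutputParser(evaluation_name="Agent Trajectory")
feedback = parser.parse_chain_output(
    {"score": 4, "reasoning": "The agent picked the right tool and verified the result."}
)
# feedback.key == "Agent Trajectory", feedback.score == 4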
@@ -225,8 +208,6 @@
 class TrajectoryInputMapper(RunEvaluatorInputMapper, BaseModel):
     """Maps the Run and Optional[Example] to a dictionary."""
 
-    tool_descriptions: List[str]
-    """The descriptions for each of the tools available to the agent."""
     agent_input_key: str = "input"
     """The key to load from the agent executor's run input dictionary."""
     agent_output_key: str = "output"
@@ -235,6 +216,8 @@
     """The key to load from the tool executor's run input dictionary."""
     tool_output_key: str = "output"
     """The key to load from the tool executor's run output dictionary."""
+    reference_output_key: Optional[str] = None
+    """The key to use for selecting the reference answer."""
 
     def map(self, run: Run, example: Optional[Example] = None) -> Dict[str, str]:
         """Maps the Run and Optional[Example] to a dictionary"""
@@ -242,6 +225,17 @@ class TrajectoryInputMapper(RunEvaluatorInputMapper, BaseModel):
             raise ValueError("Run must have child runs to be evaluated.")
         if run.outputs is None:
             raise ValueError("Run must have outputs to be evaluated.")
+        reference = ""
+        if example is not None and example.outputs:
+            if self.reference_output_key is not None:
+                reference = example.outputs[self.reference_output_key]
+            elif "output" in example.outputs:
+                reference = example.outputs["output"]
+            elif len(example.outputs) == 1:
+                reference = next(iter(example.outputs.values()))
+            else:
+                raise ValueError("Could not infer the reference answer from ")
+
         question = run.inputs[self.agent_input_key]
         tool_runs = [
             run_ for run_ in run.child_runs if run_.run_type == RunTypeEnum.tool
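
The block added above gives the mapper an explicit precedence for picking the reference answer off a dataset example: the configured `reference_output_key`, then an `output` key, then a lone output value, otherwise an error. A hypothetical standalone restatement of that precedence, with `outputs` standing in for `example.outputs`:

# Hypothetical helper mirroring the mapper's reference-selection logic.
from typing import Any, Dict, Optional


def infer_reference(outputs: Dict[str, Any], reference_output_key: Optional[str] = None) -> Any:
    if not outputs:
        return ""
    if reference_output_key is not None:
        return outputs[reference_output_key]
    if "output" in outputs:
        return outputs["output"]
    if len(outputs) == 1:
        return next(iter(outputs.values()))
    raise ValueError("Could not infer the reference answer from the example outputs.")


assert infer_reference({"output": "Paris"}) == "Paris"
assert infer_reference({"answer": "Paris"}) == "Paris"            # single value
assert infer_reference({"gold": "Paris", "notes": "n/a"}, "gold") == "Paris"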
@@ -261,33 +255,26 @@ Tool output: {tool_output}"""
             )
 
         return {
-            "tool_descriptions": "\n\n".join(self.tool_descriptions),
             "question": question,
             "agent_trajectory": "\n\n".join(agent_steps),
             "answer": run.outputs[self.agent_output_key],
+            "reference": reference,
         }
 
 
 def get_trajectory_evaluator(
     llm: BaseChatModel,
-    agent_tools: Union[Sequence[str], Sequence[BaseTool]],
+    agent_tools: Sequence[BaseTool],
     *,
     input_key: str = "input",
     prediction_key: str = "output",
     tool_input_key: str = "input",
     tool_output_key: str = "output",
-    prompt: BasePromptTemplate = TRAJECTORY_PROMPT,
+    reference_output_key: Optional[str] = None,
     evaluation_name: str = "Agent Trajectory",
     **kwargs: Any,
 ) -> RunEvaluatorChain:
     """Get an eval chain for grading a model's response against a map of criteria."""
-    tool_descriptions = [
-        f"Tool {i}: {tool.name}\nDescription: {tool.description}"
-        if isinstance(tool, BaseTool)
-        else f"Tool {i}: {tool}"
-        for i, tool in enumerate(agent_tools, 1)
-    ]
-
     input_mapper = kwargs.pop(
         "input_mapper",
         TrajectoryInputMapper(
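
After this hunk the mapper emits only the four variables the trajectory eval chain consumes; tool descriptions are no longer precomputed because `TrajectoryEvalChain` derives them from `agent_tools` itself, and `agent_tools` is now strictly a sequence of `BaseTool` objects. Illustrative shape of the mapped dictionary (values are made up):

# Illustrative output of TrajectoryInputMapper.map after this change.
mapped_inputs = {
    "question": "What is the capital of France?",
    "agent_trajectory": "Tool used: search\nTool input: capital of France\nTool output: Paris",
    "answer": "Paris",
    "reference": "Paris",
}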
@@ -295,14 +282,16 @@ def get_trajectory_evaluator(
             agent_output_key=prediction_key,
             tool_input_key=tool_input_key,
             tool_output_key=tool_output_key,
-            tool_descriptions=tool_descriptions,
+            reference_output_key=reference_output_key,
         ),
     )
     parser = kwargs.pop(
         "output_parser",
-        TrajectoryEvalOutputParser(evaluation_name=evaluation_name),
+        TrajectoryRunEvalOutputParser(evaluation_name=evaluation_name),
     )
-    eval_chain = LLMChain(llm=llm, prompt=prompt, **kwargs)
+    eval_chain = TrajectoryEvalChain.from_llm(
+        llm=llm, agent_tools=agent_tools, return_reasoning=True, **kwargs
+    )
     tags = kwargs.pop("tags", [])
     return RunEvaluatorChain(
         eval_chain=eval_chain,
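
A hedged end-to-end sketch of the refactored factory: grading is now delegated to `TrajectoryEvalChain.from_llm` instead of a bare `LLMChain`, and the new `reference_output_key` flows into the input mapper. Import paths, model, and tool are assumptions for illustration; `run` and `example` would come from the tracing SDK:

# Hedged usage sketch of get_trajectory_evaluator after this commit.
from langchain.agents import Tool
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.run_evaluators.implementations import get_trajectory_evaluator

search = Tool(name="search", func=lambda q: "Paris", description="Look things up.")
evaluator = get_trajectory_evaluator(
    llm=ChatOpenAI(temperature=0),
    agent_tools=[search],              # must be BaseTool objects now
    reference_output_key="output",     # new parameter added in this commit
)
# feedback = evaluator.evaluate_run(run, example)  # Run/Example from langchainplus_sdk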
@@ -28,6 +28,7 @@ _PARSERS_TO_SKIP = {
     "BaseOutputParser",
     "FinishedOutputParser",
     "RouterOutputParser",
+    "TrajectoryRunEvalOutputParser",
 }
 _NON_ABSTRACT_PARSERS = non_abstract_subclasses(
     BaseOutputParser, to_skip=_PARSERS_TO_SKIP