Normalize Trajectory Eval Score (#7668)

William FH 2023-07-13 09:58:28 -07:00 committed by GitHub
parent 5f03cc3511
commit aab2a7cd4b
3 changed files with 17 additions and 13 deletions


@@ -37,7 +37,7 @@ name of the dataset to load.
 - Grading the accuracy of a response against ground truth answers: :class:`QAEvalChain <langchain.evaluation.qa.eval_chain.QAEvalChain>`
 - Comparing the output of two models: :class:`PairwiseStringEvalChain <langchain.evaluation.comparison.eval_chain.PairwiseStringEvalChain>` or :class:`LabeledPairwiseStringEvalChain <langchain.evaluation.comparison.eval_chain.LabeledPairwiseStringEvalChain>` when there is additionally a reference label.
 - Judging the efficacy of an agent's tool usage: :class:`TrajectoryEvalChain <langchain.evaluation.agents.trajectory_eval_chain.TrajectoryEvalChain>`
-- Checking whether an output complies with a set of criteria: :class:`CriteriaEvalChain <langchain.evaluation.criteria.eval_chain.CriteriaEvalChain>`
+- Checking whether an output complies with a set of criteria: :class:`CriteriaEvalChain <langchain.evaluation.criteria.eval_chain.CriteriaEvalChain>` or :class:`LabeledCriteriaEvalChain <langchain.evaluation.criteria.eval_chain.LabeledCriteriaEvalChain>` when there is additionally a reference label.
 - Computing semantic difference between a prediction and reference: :class:`EmbeddingDistanceEvalChain <langchain.evaluation.embedding_distance.base.EmbeddingDistanceEvalChain>` or between two predictions: :class:`PairwiseEmbeddingDistanceEvalChain <langchain.evaluation.embedding_distance.base.PairwiseEmbeddingDistanceEvalChain>`
 - Measuring the string distance between a prediction and reference :class:`StringDistanceEvalChain <langchain.evaluation.string_distance.base.StringDistanceEvalChain>` or between two predictions :class:`PairwiseStringDistanceEvalChain <langchain.evaluation.string_distance.base.PairwiseStringDistanceEvalChain>`
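All of the chains above implement the shared evaluator interface, so they are constructed from an LLM and invoked the same way. A minimal sketch using the criteria evaluator this hunk documents (the model name and criterion are illustrative assumptions, not part of this commit):

from langchain.chat_models import ChatOpenAI
from langchain.evaluation.criteria.eval_chain import CriteriaEvalChain

# Illustrative model and criterion; any chat model and built-in criterion work.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = CriteriaEvalChain.from_llm(llm=llm, criteria="conciseness")

result = chain.evaluate_strings(
    prediction="Madrid is the capital of Spain.",
    input="What is the capital of Spain?",
)
print(result["score"])  # 1 if the criterion is met, else 0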


@@ -27,8 +27,12 @@ from langchain.tools.base import BaseTool
 class TrajectoryEval(NamedTuple):
-    score: int
+    """A named tuple containing the score and reasoning for a trajectory."""
+
+    score: float
+    """The score for the trajectory, normalized from 0 to 1."""
     reasoning: str
+    """The reasoning for the score."""
 
 
 class TrajectoryOutputParser(BaseOutputParser):
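With this change, TrajectoryEval carries a float score instead of the raw 1-5 integer. A quick illustration of the new shape (a stripped-down copy for demonstration; the values are made up):

from typing import NamedTuple

class TrajectoryEval(NamedTuple):
    score: float  # normalized to the range [0, 1] by this commit
    reasoning: str

ev = TrajectoryEval(score=0.75, reasoning="The agent used the right tool.")
assert ev.score == (4 - 1) / 4  # a raw LLM grade of 4 maps to 0.75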
@@ -43,11 +47,11 @@ class TrajectoryOutputParser(BaseOutputParser):
             text (str): The output text to parse.
 
         Returns:
-            TrajectoryEval: A named tuple containing the score and reasoning.
+            TrajectoryEval: A named tuple containing the normalized score and reasoning.
 
         Raises:
             OutputParserException: If the score is not found in the output text or
-                if the score is not a digit in the range 1-5.
+                if the LLM's score is not a digit in the range 1-5.
         """
         if "Score:" not in text:
             raise OutputParserException(
@@ -66,8 +70,8 @@ class TrajectoryOutputParser(BaseOutputParser):
             raise OutputParserException(
                 f"Score is not a digit in the range 1-5: {text}"
             )
-        return TrajectoryEval(score=int(score_str), reasoning=reasoning)
+        normalized_score = (int(score_str) - 1) / 4
+        return TrajectoryEval(score=normalized_score, reasoning=reasoning)
 
 
 class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
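The new normalized_score is a linear rescaling of the LLM's 1-5 grade onto the unit interval, so 1 maps to 0.0, 3 to 0.5, and 5 to 1.0. A quick check of the arithmetic:

# (raw - 1) / 4 rescales the 1-5 grade to [0, 1]
for raw in range(1, 6):
    print(raw, (raw - 1) / 4)
# prints: 1 0.0, 2 0.25, 3 0.5, 4 0.75, 5 1.0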
@@ -90,7 +94,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
             \"\"\"Very helpful answers to geography questions.\"\"\"
             return f"{country}? IDK - We may never know {question}."
 
-        llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
+        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
         agent = initialize_agent(
             tools=[geography_answers],
             llm=llm,
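The docstring example continues past this hunk; a hedged sketch of how it finishes, based on the surrounding code (the agent call, dict keys, and from_llm arguments are assumptions, not shown in this diff):

# ... continuing the docstring example above ...
from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryEvalChain

question = "How many people live in the largest country?"
response = agent(question)  # assumes return_intermediate_steps=True

eval_chain = TrajectoryEvalChain.from_llm(llm=llm, agent_tools=[geography_answers])
result = eval_chain.evaluate_agent_trajectory(
    input=question,
    agent_trajectory=response["intermediate_steps"],
    prediction=response["output"],
)
print(result["score"])  # a float in [0, 1] after this commit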


@@ -70,7 +70,7 @@ def test_trajectory_eval_chain(
         agent_trajectory=intermediate_steps,
         prediction="I like pie.",
     )
-    assert res["score"] == 5
+    assert res["score"] == 1.0
     # Test when ref is provided
     res = chain.evaluate_agent_trajectory(
         input="What is your favorite food?",
@@ -78,7 +78,7 @@ def test_trajectory_eval_chain(
         prediction="I like pie.",
         reference="Paris",
     )
-    assert res["score"] == 1
+    assert res["score"] == 0.0
 
 
 def test_trajectory_eval_chain_no_tools(
@@ -97,14 +97,14 @@ def test_trajectory_eval_chain_no_tools(
         agent_trajectory=intermediate_steps,
         prediction="I like pie.",
     )
-    assert res["score"] == 5
+    assert res["score"] == 1.0
     res = chain.evaluate_agent_trajectory(
         input="What is your favorite food?",
         agent_trajectory=intermediate_steps,
         prediction="I like pie.",
         reference="Paris",
    )
-    assert res["score"] == 1
+    assert res["score"] == 0.0
 
 
 def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
@@ -123,7 +123,7 @@ def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
             "answer": "I like pie.",
         }
     )
-    assert res["score"] == 5
+    assert res["score"] == 1.0
 
     res = chain(
         {
@@ -133,4 +133,4 @@ def test_old_api_works(intermediate_steps: List[Tuple[AgentAction, str]]) -> None:
             "reference": "Paris",
         }
     )
-    assert res["score"] == 1
+    assert res["score"] == 0.0
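End to end, a raw "Score: 5" completion now yields 1.0 and "Score: 1" yields 0.0, which is exactly what these assertions encode. A hedged round-trip sketch of the parser (the reasoning text is made up; the "reasoning then Score:" layout follows the parser code shown above):

from langchain.evaluation.agents.trajectory_eval_chain import TrajectoryOutputParser

parser = TrajectoryOutputParser()
result = parser.parse("The agent chose an appropriate tool.\nScore: 4")
assert result.score == 0.75  # (4 - 1) / 4
assert "appropriate tool" in result.reasoning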