diff --git a/langchain/chains/question_answering/map_rerank_prompt.py b/langchain/chains/question_answering/map_rerank_prompt.py index e73439541d..c8041c6c3a 100644 --- a/langchain/chains/question_answering/map_rerank_prompt.py +++ b/langchain/chains/question_answering/map_rerank_prompt.py @@ -3,7 +3,7 @@ from langchain.output_parsers.regex import RegexParser from langchain.prompts import PromptTemplate output_parser = RegexParser( - regex=r"(.*?)\nScore: (.*)", + regex=r"(.*?)\nScore: (\d*)", output_keys=["answer", "score"], ) diff --git a/tests/unit_tests/chains/question_answering/__init__.py b/tests/unit_tests/chains/question_answering/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py b/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py new file mode 100644 index 0000000000..61409af3bb --- /dev/null +++ b/tests/unit_tests/chains/question_answering/test_map_rerank_prompt.py @@ -0,0 +1,17 @@ +"""Test map_rerank parser""" +import pytest + +from langchain.chains.question_answering.map_rerank_prompt import output_parser + +GOOD_SCORE = "foo bar answer.\nScore: 80" +SCORE_WITH_EXPLANATION = "foo bar answer.\nScore: 80 (fully answers the question, but could provide more detail on the specific error message)" # noqa: E501 + + +@pytest.mark.parametrize("answer", (GOOD_SCORE, SCORE_WITH_EXPLANATION)) +def test_parse_scores(answer: str) -> None: + result = output_parser.parse(answer) + + assert result["answer"] == "foo bar answer." + + score = int(result["score"]) + assert score == 80