From 2aaf86ddae1363332c58eb11462aed8fbb310b75 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 10 Jun 2024 14:00:12 -0700 Subject: [PATCH] core: fix mustache falsy cases (#22747) --- libs/core/langchain_core/utils/mustache.py | 2 + .../tests/unit_tests/prompts/test_prompt.py | 54 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 06258375f5..09d0284ece 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -353,6 +353,8 @@ def _get_key( if scope._CHEVRON_return_scope_when_falsy: # type: ignore return scope except AttributeError: + if scope in (0, False): + return scope return scope or "" except (AttributeError, KeyError, IndexError, ValueError): # We couldn't find the key in the current scope diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index b68316986b..3bb062d6ac 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -1,5 +1,6 @@ """Test functionality related to prompts.""" +from typing import Any, Dict, Union from unittest import mock import pytest @@ -499,3 +500,56 @@ async def test_prompt_ainvoke_with_metadata() -> None: assert len(tracer.traced_runs) == 1 assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore + + +@pytest.mark.parametrize( + "value, expected", + [ + ("0", "0"), + (0, "0"), + (0.0, "0.0"), + (False, "False"), + ("", ""), + ( + None, + { + "mustache": "", + "f-string": "None", + }, + ), + ( + [], + { + "mustache": "", + "f-string": "[]", + }, + ), + ( + {}, + { + "mustache": "", + "f-string": "{}", + }, + ), + ], +) +@pytest.mark.parametrize("template_format", ["f-string", "mustache"]) +def test_prompt_falsy_vars( + template_format: str, value: Any, expected: Union[str, Dict[str, str]] +) -> None: + # each line is value, f-string, mustache + if template_format == "f-string": + template = "{my_var}" + elif template_format == "mustache": + template = "{{my_var}}" + else: + raise ValueError(f"Invalid template format: {template_format}") + + prompt = PromptTemplate.from_template(template, template_format=template_format) + + result = prompt.invoke({"my_var": value}) + + expected_output = ( + expected if not isinstance(expected, dict) else expected[template_format] + ) + assert result.to_string() == expected_output