From 848454d1e7fa5e492ee74fbef7a7b3fd95485f5d Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Mon, 24 Jul 2023 11:34:15 -0700 Subject: [PATCH] Refactored `formatting` (#8191) Refactored `formatting.py`. The same as https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099 formatting.py is in the root code folder. This creates the `langchain.formatting: Formatting` group on the API Reference navigation ToC, on the same level as Chains and Agents which is incorrect. Refactoring: - moved formatting.py content into utils/formatting.py - I did not add the backwards compatibility ref in the original formatting.py. It seems unnecessary. --------- Co-authored-by: Bagatur --- libs/langchain/langchain/formatting.py | 40 ++----------------- libs/langchain/langchain/prompts/base.py | 2 +- libs/langchain/langchain/utils/__init__.py | 3 ++ libs/langchain/langchain/utils/formatting.py | 38 ++++++++++++++++++ .../tests/unit_tests/test_formatting.py | 2 +- 5 files changed, 46 insertions(+), 39 deletions(-) create mode 100644 libs/langchain/langchain/utils/formatting.py diff --git a/libs/langchain/langchain/formatting.py b/libs/langchain/langchain/formatting.py index 3b3b597b08..ebb865c957 100644 --- a/libs/langchain/langchain/formatting.py +++ b/libs/langchain/langchain/formatting.py @@ -1,38 +1,4 @@ -"""Utilities for formatting strings.""" -from string import Formatter -from typing import Any, List, Mapping, Sequence, Union +"""DEPRECATED: Kept for backwards compatibility.""" +from langchain.utils.formatting import StrictFormatter, formatter - -class StrictFormatter(Formatter): - """A subclass of formatter that checks for extra keys.""" - - def check_unused_args( - self, - used_args: Sequence[Union[int, str]], - args: Sequence, - kwargs: Mapping[str, Any], - ) -> None: - """Check to see if extra parameters are passed.""" - extra = set(kwargs).difference(used_args) - if extra: - raise KeyError(extra) - - def vformat( - self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] - ) -> str: - """Check that no arguments are provided.""" - if len(args) > 0: - raise ValueError( - "No arguments should be provided, " - "everything should be passed as keyword arguments." - ) - return super().vformat(format_string, args, kwargs) - - def validate_input_variables( - self, format_string: str, input_variables: List[str] - ) -> None: - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} - super().format(format_string, **dummy_inputs) - - -formatter = StrictFormatter() +__all__ = ["StrictFormatter", "formatter"] diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index 0e39316488..d5426cca6c 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -5,10 +5,10 @@ import warnings from abc import ABC from typing import Any, Callable, Dict, List, Set -from langchain.formatting import formatter from langchain.schema import BasePromptTemplate from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.prompt import PromptValue +from langchain.utils import formatter def jinja2_formatter(template: str, **kwargs: Any) -> str: diff --git a/libs/langchain/langchain/utils/__init__.py b/libs/langchain/langchain/utils/__init__.py index e3db0ddcac..74b0bb87d7 100644 --- a/libs/langchain/langchain/utils/__init__.py +++ b/libs/langchain/langchain/utils/__init__.py @@ -5,6 +5,7 @@ These functions do not depend on any other langchain modules. """ from langchain.utils.env import get_from_dict_or_env, get_from_env +from langchain.utils.formatting import StrictFormatter, formatter from langchain.utils.math import cosine_similarity, cosine_similarity_top_k from langchain.utils.strings import comma_list, stringify_dict, stringify_value from langchain.utils.utils import ( @@ -17,10 +18,12 @@ from langchain.utils.utils import ( ) __all__ = [ + "StrictFormatter", "check_package_version", "comma_list", "cosine_similarity", "cosine_similarity_top_k", + "formatter", "get_from_dict_or_env", "get_from_env", "get_pydantic_field_names", diff --git a/libs/langchain/langchain/utils/formatting.py b/libs/langchain/langchain/utils/formatting.py new file mode 100644 index 0000000000..3b3b597b08 --- /dev/null +++ b/libs/langchain/langchain/utils/formatting.py @@ -0,0 +1,38 @@ +"""Utilities for formatting strings.""" +from string import Formatter +from typing import Any, List, Mapping, Sequence, Union + + +class StrictFormatter(Formatter): + """A subclass of formatter that checks for extra keys.""" + + def check_unused_args( + self, + used_args: Sequence[Union[int, str]], + args: Sequence, + kwargs: Mapping[str, Any], + ) -> None: + """Check to see if extra parameters are passed.""" + extra = set(kwargs).difference(used_args) + if extra: + raise KeyError(extra) + + def vformat( + self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] + ) -> str: + """Check that no arguments are provided.""" + if len(args) > 0: + raise ValueError( + "No arguments should be provided, " + "everything should be passed as keyword arguments." + ) + return super().vformat(format_string, args, kwargs) + + def validate_input_variables( + self, format_string: str, input_variables: List[str] + ) -> None: + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + super().format(format_string, **dummy_inputs) + + +formatter = StrictFormatter() diff --git a/libs/langchain/tests/unit_tests/test_formatting.py b/libs/langchain/tests/unit_tests/test_formatting.py index 168e580b7b..482615608e 100644 --- a/libs/langchain/tests/unit_tests/test_formatting.py +++ b/libs/langchain/tests/unit_tests/test_formatting.py @@ -1,7 +1,7 @@ """Test formatting functionality.""" import pytest -from langchain.formatting import formatter +from langchain.utils import formatter def test_valid_formatting() -> None: