Refactored formatting (#8191)

Refactored `formatting.py`. The same as
https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099
formatting.py is in the root code folder. This creates the
`langchain.formatting: Formatting` group on the API Reference navigation
ToC, on the same level as Chains and Agents which is incorrect.

Refactoring:

- moved formatting.py content into utils/formatting.py
- I did not add the backwards compatibility ref in the original
formatting.py. It seems unnecessary.
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Leonid Ganeline 2023-07-24 11:34:15 -07:00 committed by GitHub
parent 4928f7a9f5
commit 848454d1e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 39 deletions

View File

@ -1,38 +1,4 @@
"""Utilities for formatting strings.""" """DEPRECATED: Kept for backwards compatibility."""
from string import Formatter from langchain.utils.formatting import StrictFormatter, formatter
from typing import Any, List, Mapping, Sequence, Union
__all__ = ["StrictFormatter", "formatter"]
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()

View File

@ -5,10 +5,10 @@ import warnings
from abc import ABC from abc import ABC
from typing import Any, Callable, Dict, List, Set from typing import Any, Callable, Dict, List, Set
from langchain.formatting import formatter
from langchain.schema import BasePromptTemplate from langchain.schema import BasePromptTemplate
from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.utils import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str: def jinja2_formatter(template: str, **kwargs: Any) -> str:

View File

@ -5,6 +5,7 @@ These functions do not depend on any other langchain modules.
""" """
from langchain.utils.env import get_from_dict_or_env, get_from_env from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.formatting import StrictFormatter, formatter
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import ( from langchain.utils.utils import (
@ -17,10 +18,12 @@ from langchain.utils.utils import (
) )
__all__ = [ __all__ = [
"StrictFormatter",
"check_package_version", "check_package_version",
"comma_list", "comma_list",
"cosine_similarity", "cosine_similarity",
"cosine_similarity_top_k", "cosine_similarity_top_k",
"formatter",
"get_from_dict_or_env", "get_from_dict_or_env",
"get_from_env", "get_from_env",
"get_pydantic_field_names", "get_pydantic_field_names",

View File

@ -0,0 +1,38 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()

View File

@ -1,7 +1,7 @@
"""Test formatting functionality.""" """Test formatting functionality."""
import pytest import pytest
from langchain.formatting import formatter from langchain.utils import formatter
def test_valid_formatting() -> None: def test_valid_formatting() -> None: