Refactored `formatting` (#8191)

Refactored `formatting.py`. The same as
https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099
formatting.py is in the root code folder. This creates the
`langchain.formatting: Formatting` group on the API Reference navigation
ToC, on the same level as Chains and Agents which is incorrect.

Refactoring:

- moved formatting.py content into utils/formatting.py
- I did not add the backwards compatibility ref in the original
formatting.py. It seems unnecessary.
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/8103/head^2
Leonid Ganeline 1 year ago committed by GitHub
parent 4928f7a9f5
commit 848454d1e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,38 +1,4 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
"""DEPRECATED: Kept for backwards compatibility."""
from langchain.utils.formatting import StrictFormatter, formatter
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()
__all__ = ["StrictFormatter", "formatter"]

@ -5,10 +5,10 @@ import warnings
from abc import ABC
from typing import Any, Callable, Dict, List, Set
from langchain.formatting import formatter
from langchain.schema import BasePromptTemplate
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue
from langchain.utils import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:

@ -5,6 +5,7 @@ These functions do not depend on any other langchain modules.
"""
from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.formatting import StrictFormatter, formatter
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
@ -17,10 +18,12 @@ from langchain.utils.utils import (
)
__all__ = [
"StrictFormatter",
"check_package_version",
"comma_list",
"cosine_similarity",
"cosine_similarity_top_k",
"formatter",
"get_from_dict_or_env",
"get_from_env",
"get_pydantic_field_names",

@ -0,0 +1,38 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
"everything should be passed as keyword arguments."
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()

@ -1,7 +1,7 @@
"""Test formatting functionality."""
import pytest
from langchain.formatting import formatter
from langchain.utils import formatter
def test_valid_formatting() -> None:

Loading…
Cancel
Save