mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Refactored formatting
(#8191)
Refactored `formatting.py`. The same as https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099 formatting.py is in the root code folder. This creates the `langchain.formatting: Formatting` group on the API Reference navigation ToC, on the same level as Chains and Agents which is incorrect. Refactoring: - moved formatting.py content into utils/formatting.py - I did not add the backwards compatibility ref in the original formatting.py. It seems unnecessary. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4928f7a9f5
commit
848454d1e7
@ -1,38 +1,4 @@
|
|||||||
"""Utilities for formatting strings."""
|
"""DEPRECATED: Kept for backwards compatibility."""
|
||||||
from string import Formatter
|
from langchain.utils.formatting import StrictFormatter, formatter
|
||||||
from typing import Any, List, Mapping, Sequence, Union
|
|
||||||
|
|
||||||
|
__all__ = ["StrictFormatter", "formatter"]
|
||||||
class StrictFormatter(Formatter):
|
|
||||||
"""A subclass of formatter that checks for extra keys."""
|
|
||||||
|
|
||||||
def check_unused_args(
|
|
||||||
self,
|
|
||||||
used_args: Sequence[Union[int, str]],
|
|
||||||
args: Sequence,
|
|
||||||
kwargs: Mapping[str, Any],
|
|
||||||
) -> None:
|
|
||||||
"""Check to see if extra parameters are passed."""
|
|
||||||
extra = set(kwargs).difference(used_args)
|
|
||||||
if extra:
|
|
||||||
raise KeyError(extra)
|
|
||||||
|
|
||||||
def vformat(
|
|
||||||
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Check that no arguments are provided."""
|
|
||||||
if len(args) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"No arguments should be provided, "
|
|
||||||
"everything should be passed as keyword arguments."
|
|
||||||
)
|
|
||||||
return super().vformat(format_string, args, kwargs)
|
|
||||||
|
|
||||||
def validate_input_variables(
|
|
||||||
self, format_string: str, input_variables: List[str]
|
|
||||||
) -> None:
|
|
||||||
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
|
|
||||||
super().format(format_string, **dummy_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
formatter = StrictFormatter()
|
|
||||||
|
@ -5,10 +5,10 @@ import warnings
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Any, Callable, Dict, List, Set
|
from typing import Any, Callable, Dict, List, Set
|
||||||
|
|
||||||
from langchain.formatting import formatter
|
|
||||||
from langchain.schema import BasePromptTemplate
|
from langchain.schema import BasePromptTemplate
|
||||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
|
from langchain.utils import formatter
|
||||||
|
|
||||||
|
|
||||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||||
|
@ -5,6 +5,7 @@ These functions do not depend on any other langchain modules.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
||||||
|
from langchain.utils.formatting import StrictFormatter, formatter
|
||||||
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
||||||
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
||||||
from langchain.utils.utils import (
|
from langchain.utils.utils import (
|
||||||
@ -17,10 +18,12 @@ from langchain.utils.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"StrictFormatter",
|
||||||
"check_package_version",
|
"check_package_version",
|
||||||
"comma_list",
|
"comma_list",
|
||||||
"cosine_similarity",
|
"cosine_similarity",
|
||||||
"cosine_similarity_top_k",
|
"cosine_similarity_top_k",
|
||||||
|
"formatter",
|
||||||
"get_from_dict_or_env",
|
"get_from_dict_or_env",
|
||||||
"get_from_env",
|
"get_from_env",
|
||||||
"get_pydantic_field_names",
|
"get_pydantic_field_names",
|
||||||
|
38
libs/langchain/langchain/utils/formatting.py
Normal file
38
libs/langchain/langchain/utils/formatting.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Utilities for formatting strings."""
|
||||||
|
from string import Formatter
|
||||||
|
from typing import Any, List, Mapping, Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
class StrictFormatter(Formatter):
|
||||||
|
"""A subclass of formatter that checks for extra keys."""
|
||||||
|
|
||||||
|
def check_unused_args(
|
||||||
|
self,
|
||||||
|
used_args: Sequence[Union[int, str]],
|
||||||
|
args: Sequence,
|
||||||
|
kwargs: Mapping[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Check to see if extra parameters are passed."""
|
||||||
|
extra = set(kwargs).difference(used_args)
|
||||||
|
if extra:
|
||||||
|
raise KeyError(extra)
|
||||||
|
|
||||||
|
def vformat(
|
||||||
|
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""Check that no arguments are provided."""
|
||||||
|
if len(args) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"No arguments should be provided, "
|
||||||
|
"everything should be passed as keyword arguments."
|
||||||
|
)
|
||||||
|
return super().vformat(format_string, args, kwargs)
|
||||||
|
|
||||||
|
def validate_input_variables(
|
||||||
|
self, format_string: str, input_variables: List[str]
|
||||||
|
) -> None:
|
||||||
|
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
|
||||||
|
super().format(format_string, **dummy_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
formatter = StrictFormatter()
|
@ -1,7 +1,7 @@
|
|||||||
"""Test formatting functionality."""
|
"""Test formatting functionality."""
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.formatting import formatter
|
from langchain.utils import formatter
|
||||||
|
|
||||||
|
|
||||||
def test_valid_formatting() -> None:
|
def test_valid_formatting() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user