core[patch]: fix deprecation pydantic bug (#25204)

#25004 is incompatible with pydantic < 1.10.17. Introduces fix for this.
This commit is contained in:
Bagatur 2024-08-08 16:39:38 -07:00 committed by GitHub
parent dc7423e88f
commit 7040013140
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 24 deletions

View File

@ -30,7 +30,8 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# PUBLIC API
T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
# Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
def _validate_deprecation_params(
@ -133,7 +134,7 @@ def deprecated(
_package: str = package,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
from pydantic.v1.fields import FieldInfo # pydantic: ignore
from langchain_core.utils.pydantic import FieldInfoV1
def emit_warning() -> None:
"""Emit the warning."""
@ -208,9 +209,7 @@ def deprecated(
)
return cast(T, obj)
elif isinstance(obj, FieldInfo):
from langchain_core.pydantic_v1 import Field
elif isinstance(obj, FieldInfoV1):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
@ -219,58 +218,64 @@ def deprecated(
old_doc = obj.description
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
return Field(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
alias=obj.alias,
exclude=obj.exclude,
return cast(
T,
FieldInfoV1(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
alias=obj.alias,
exclude=obj.exclude,
),
)
elif isinstance(obj, property):
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or obj.fget.__qualname__
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__
class _deprecated_property(property):
"""A deprecated property."""
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def]
super().__init__(fget, fset, fdel, doc)
self.__orig_fget = fget
self.__orig_fset = fset
self.__orig_fdel = fdel
def __get__(self, instance, owner=None):
def __get__(self, instance, owner=None): # type: ignore[no-untyped-def]
if instance is not None or owner is not None:
emit_warning()
return self.fget(instance)
def __set__(self, instance, value):
def __set__(self, instance, value): # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fset(instance, value)
def __delete__(self, instance):
def __delete__(self, instance): # type: ignore[no-untyped-def]
if instance is not None:
emit_warning()
return self.fdel(instance)
def __set_name__(self, owner, set_name):
def __set_name__(self, owner, set_name): # type: ignore[no-untyped-def]
nonlocal _name
if _name == "<lambda>":
_name = set_name
def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any:
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
"""Finalize the property."""
return _deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
return cast(
T,
_deprecated_property(
fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc
),
)
else:
_name = _name or obj.__qualname__
_name = _name or cast(Union[Type, Callable], obj).__qualname__
if not _obj_type:
# edge case: when a function is within another function
# within a test, this will call it a "method" not a "function"

View File

@ -26,9 +26,13 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic.fields import FieldInfo as FieldInfoV1
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
@ -272,7 +276,6 @@ if PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.fields import FieldInfo as FieldInfoV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1.fields import FieldInfo as FieldInfoV1
@overload
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
@ -304,11 +307,10 @@ if PYDANTIC_MAJOR_VERSION == 2:
raise TypeError(f"Expected a Pydantic model. Got {type(model)}")
elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_
from pydantic.fields import FieldInfo as FieldInfoV1_
def get_fields( # type: ignore[no-redef]
model: Union[Type[BaseModelV1_], BaseModelV1_],
) -> Dict[str, FieldInfoV1_]:
) -> Dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore
else: