core[patch]: fix beta, deprecated typing (#18877)

**Description:** 

While not technically incorrect, the TypeVar used for the `@beta`
decorator prevented pyright (and thus most VS Code users) from correctly
seeing the types of functions/classes decorated with `@beta`.

This is in part due to a small bug in pyright
(https://github.com/microsoft/pyright/issues/7448) - however, the
`Type` constraint in the TypeVar `T = TypeVar("T", Type, Callable)` was
not adding anything - classes are already `Callable`, so to my
understanding constraining to `Type` does not provide any extra
safety - and the modified annotation still works correctly for
functions, properties, and classes alike.
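
For reference, a minimal sketch of the difference (the decorator below is illustrative, not the actual `@beta` implementation):

```python
from typing import Any, Callable, Type, TypeVar, Union, cast

# Constrained form (old): T may only resolve to the bare types `Type` or
# `Callable`, so a decorated function loses its specific signature.
# T = TypeVar("T", Type, Callable)

# Bound form (new): T resolves to the exact decorated type (the precise
# function signature or class), which pyright propagates to call sites.
T = TypeVar("T", bound=Union[Callable[..., Any], Type])


def beta_like(obj: T) -> T:
    """Illustrative no-op decorator typed the same way as the fixed @beta."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return obj(*args, **kwargs)

    # The inner wrapper is a plain Callable, so it is cast back to T to keep
    # the declared `-> T` return type honest (mirrors cast(T, obj) in the diff).
    return cast(T, wrapper)


@beta_like
def add(x: int, y: int) -> int:
    return x + y


# Hovering `add` in VS Code / pyright now shows "(x: int, y: int) -> int"
# instead of an untyped Callable.
print(add(1, 2))
```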

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Luca Dorigo, 3 months ago, committed by GitHub (pull/18444/head^2)
commit f19229c564, parent 263ee78886

@@ -403,7 +403,7 @@ class _RedisCacheBase(BaseCache, ABC):
         if results:
             for _, text in results.items():
                 try:
-                    generations.append(loads(text))
+                    generations.append(loads(cast(str, text)))
                 except Exception:
                     logger.warning(
                         "Retrieving a cache value that could not be deserialized "

@@ -9,11 +9,12 @@ https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py
     This module is for internal use only. Do not use it in your own code.
     We may change the API at any time with no warning.
 """
+import contextlib
 import functools
 import inspect
 import warnings
-from typing import Any, Callable, Generator, Type, TypeVar
+from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
 
 from langchain_core._api.internal import is_caller_internal
@@ -25,7 +26,7 @@ class LangChainBetaWarning(DeprecationWarning):
 
 # PUBLIC API
 
-T = TypeVar("T", Type, Callable)
+T = TypeVar("T", bound=Union[Callable[..., Any], Type])
 
 
 def beta(
@@ -143,7 +144,7 @@ def beta(
                 obj.__init__ = functools.wraps(obj.__init__)(  # type: ignore[misc]
                     warn_if_direct_instance
                 )
-                return obj
+                return cast(T, obj)
 
         elif isinstance(obj, property):
             if not _obj_type:
@@ -202,7 +203,7 @@ def beta(
                 """
                 wrapper = functools.wraps(wrapped)(wrapper)
                 wrapper.__doc__ = new_doc
-                return wrapper
+                return cast(T, wrapper)
 
         old_doc = inspect.cleandoc(old_doc or "").strip("\n")
@@ -225,9 +226,10 @@ def beta(
         )
 
         if inspect.iscoroutinefunction(obj):
-            return finalize(awarning_emitting_wrapper, new_doc)
+            finalized = finalize(awarning_emitting_wrapper, new_doc)
         else:
-            return finalize(warning_emitting_wrapper, new_doc)
+            finalized = finalize(warning_emitting_wrapper, new_doc)
+        return cast(T, finalized)
 
     return beta

@@ -14,7 +14,7 @@ import contextlib
 import functools
 import inspect
 import warnings
-from typing import Any, Callable, Generator, Type, TypeVar
+from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
 
 from langchain_core._api.internal import is_caller_internal
@@ -30,7 +30,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
 
 # PUBLIC API
 
-T = TypeVar("T", Type, Callable)
+T = TypeVar("T", bound=Union[Type, Callable[..., Any]])
 
 
 def deprecated(
@@ -182,7 +182,7 @@ def deprecated(
                 obj.__init__ = functools.wraps(obj.__init__)(  # type: ignore[misc]
                     warn_if_direct_instance
                 )
-                return obj
+                return cast(T, obj)
 
         elif isinstance(obj, property):
             if not _obj_type:
@@ -241,7 +241,7 @@ def deprecated(
                 """
                 wrapper = functools.wraps(wrapped)(wrapper)
                 wrapper.__doc__ = new_doc
-                return wrapper
+                return cast(T, wrapper)
 
         old_doc = inspect.cleandoc(old_doc or "").strip("\n")
@@ -267,9 +267,10 @@ def deprecated(
         )
 
         if inspect.iscoroutinefunction(obj):
-            return finalize(awarning_emitting_wrapper, new_doc)
+            finalized = finalize(awarning_emitting_wrapper, new_doc)
         else:
-            return finalize(warning_emitting_wrapper, new_doc)
+            finalized = finalize(warning_emitting_wrapper, new_doc)
+        return cast(T, finalized)
 
     return deprecate

@@ -308,7 +308,7 @@ def convert_to_openai_function(
     elif isinstance(function, type) and issubclass(function, BaseModel):
         return cast(Dict, convert_pydantic_to_openai_function(function))
     elif isinstance(function, BaseTool):
-        return format_tool_to_openai_function(function)
+        return cast(Dict, format_tool_to_openai_function(function))
     elif callable(function):
         return convert_python_function_to_openai_function(function)
     else:

@@ -23,7 +23,9 @@ def _fake_runnable(
 class FakeStructuredChatModel(FakeListChatModel):
     """Fake ChatModel for testing purposes."""
 
-    def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable:
+    def with_structured_output(
+        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
+    ) -> Runnable:
         return RunnableLambda(partial(_fake_runnable, schema))
 
     @property
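
Side note on the test change above: the base `with_structured_output` signature takes `**kwargs`, so the fake model's override presumably has to accept them too or stricter checking flags it as an incompatible override. A minimal sketch of the pattern (the classes here are made up, not LangChain APIs):

```python
from typing import Any


class Base:
    # Base method accepts arbitrary keyword options.
    def with_structured_output(self, schema: dict, **kwargs: Any) -> str:
        return "base"


class Fake(Base):
    # Dropping **kwargs here would make the override's signature narrower than
    # the base method's, which mypy/pyright report as a Liskov violation.
    def with_structured_output(self, schema: dict, **kwargs: Any) -> str:
        return "fake"


print(Fake().with_structured_output({}))
```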

@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Type, cast
 
 from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
 from langchain_core.documents import Document
@@ -93,9 +93,14 @@ def optional_enum_field(
     return Field(..., description=description + additional_info, **field_kwargs)
 
 
+class _Graph(BaseModel):
+    nodes: Optional[List]
+    relationships: Optional[List]
+
+
 def create_simple_model(
     node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
-) -> Any:
+) -> Type[_Graph]:
     """
     Simple model allows to limit node and/or relationship types.
     Doesn't have any node or relationship properties.
@@ -128,7 +133,7 @@ def create_simple_model(
             rel_types, description="The type of the relationship.", is_rel=True
         )
 
-    class DynamicGraph(BaseModel):
+    class DynamicGraph(_Graph):
         """Represents a graph document consisting of nodes and relationships."""
 
         nodes: Optional[List[SimpleNode]] = Field(description="List of nodes")
@@ -194,7 +199,7 @@ class LLMGraphTransformer:
         llm: BaseLanguageModel,
         allowed_nodes: List[str] = [],
         allowed_relationships: List[str] = [],
-        prompt: Optional[ChatPromptTemplate] = default_prompt,
+        prompt: ChatPromptTemplate = default_prompt,
         strict_mode: bool = True,
     ) -> None:
         if not hasattr(llm, "with_structured_output"):
@@ -217,7 +222,7 @@ class LLMGraphTransformer:
         an LLM based on the model's schema and constraints.
         """
         text = document.page_content
-        raw_schema = self.chain.invoke({"input": text})
+        raw_schema = cast(_Graph, self.chain.invoke({"input": text}))
         nodes = (
             [map_to_base_node(node) for node in raw_schema.nodes]
             if raw_schema.nodes
@@ -268,7 +273,7 @@ class LLMGraphTransformer:
         graph document.
         """
         text = document.page_content
-        raw_schema = await self.chain.ainvoke({"input": text})
+        raw_schema = cast(_Graph, await self.chain.ainvoke({"input": text}))
         nodes = (
             [map_to_base_node(node) for node in raw_schema.nodes]