mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)
After this PR chat models will correctly handle pydantic 2 with bind_tools and with_structured_output. ```python import pydantic print(pydantic.__version__) ``` 2.8.2 ```python from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ``` [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}] <class '__main__.Add'> ``` ```python from langchain_openai import ChatOpenAI from pydantic.v1 import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ```python [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}] <class '__main__.Add'> ``` Addresses issues: https://github.com/langchain-ai/langchain/issues/22782 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
199e64d372
commit
236e957abb
@ -24,6 +24,7 @@ from langchain_core.output_parsers import (
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
from langchain_community.output_parsers.ernie_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
@ -94,7 +95,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg == "return":
|
||||
continue
|
||||
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
|
||||
if isinstance(arg_type, type) and is_basemodel_subclass(arg_type):
|
||||
# Mypy error:
|
||||
# "type" has no attribute "schema"
|
||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
||||
@ -156,7 +157,7 @@ def convert_to_ernie_function(
|
||||
"""
|
||||
if isinstance(function, dict):
|
||||
return function
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||
return cast(Dict, convert_pydantic_to_ernie_function(function))
|
||||
elif callable(function):
|
||||
return convert_python_function_to_ernie_function(function)
|
||||
@ -185,7 +186,7 @@ def get_ernie_output_parser(
|
||||
only the function arguments and not the function name.
|
||||
"""
|
||||
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
|
||||
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
|
||||
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
|
||||
if len(functions) > 1:
|
||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||
name: fn for name, fn in zip(function_names, functions)
|
||||
|
@ -40,11 +40,17 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -769,7 +775,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
llm = self.bind_tools([schema])
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
|
@ -57,6 +57,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
@ -443,7 +444,7 @@ class ChatEdenAI(BaseChatModel):
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
llm = self.bind_tools([schema], tool_choice="required")
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
|
@ -46,10 +46,15 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
class ChatLlamaCpp(BaseChatModel):
|
||||
@ -525,7 +530,7 @@ class ChatLlamaCpp(BaseChatModel):
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
|
@ -53,11 +53,16 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from requests.exceptions import HTTPError
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@ -865,7 +870,7 @@ class ChatTongyi(BaseChatModel):
|
||||
"""
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
llm = self.bind_tools([schema])
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
|
@ -55,11 +55,16 @@ from langchain_core.outputs import (
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
@ -1162,7 +1167,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"with_structured_output is not implemented for this model."
|
||||
)
|
||||
llm = self.bind_tools([schema], tool_choice="any")
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
|
@ -82,6 +82,7 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
@ -300,7 +301,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.InputType
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
@ -332,7 +333,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
|
@ -22,6 +22,7 @@ from typing import (
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables.base import Runnable as RunnableType
|
||||
@ -229,7 +230,7 @@ def node_data_json(
|
||||
"name": node_data_str(node.id, node.data),
|
||||
},
|
||||
}
|
||||
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
|
||||
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
|
||||
json = (
|
||||
{
|
||||
"type": "schema",
|
||||
|
@ -28,6 +28,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tools import BaseTool
|
||||
@ -100,7 +101,11 @@ def convert_pydantic_to_openai_function(
|
||||
Returns:
|
||||
The function description.
|
||||
"""
|
||||
schema = dereference_refs(model.schema())
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema() # Pydantic 2
|
||||
else:
|
||||
schema = model.schema() # Pydantic 1
|
||||
schema = dereference_refs(schema)
|
||||
schema.pop("definitions", None)
|
||||
title = schema.pop("title", "")
|
||||
default_description = schema.pop("description", "")
|
||||
@ -272,7 +277,7 @@ def convert_to_openai_function(
|
||||
"description": function.pop("description"),
|
||||
"parameters": function,
|
||||
}
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
elif isinstance(function, BaseTool):
|
||||
return cast(Dict, format_tool_to_openai_function(function))
|
||||
|
@ -8,12 +8,13 @@ from langchain_core.load.load import loads
|
||||
from langchain_core.prompts.structured import StructuredPrompt
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
def _fake_runnable(
|
||||
schema: Union[Dict, Type[BaseModel]], _: Any
|
||||
) -> Union[BaseModel, Dict]:
|
||||
if isclass(schema) and issubclass(schema, BaseModel):
|
||||
if isclass(schema) and is_basemodel_subclass(schema):
|
||||
return schema(name="yo", value=42)
|
||||
else:
|
||||
params = cast(Dict, schema)["parameters"]
|
||||
|
@ -34,11 +34,14 @@ from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.prompts import SystemMessagePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableMap
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||
|
||||
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
||||
|
||||
@ -75,14 +78,10 @@ _DictOrPydantic = Union[Dict, _BM]
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and (
|
||||
issubclass(obj, BaseModel) or BaseModel in obj.__bases__
|
||||
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||
)
|
||||
|
||||
|
||||
def _is_pydantic_object(obj: Any) -> bool:
|
||||
return isinstance(obj, BaseModel)
|
||||
|
||||
|
||||
def convert_to_ollama_tool(tool: Any) -> Dict:
|
||||
"""Convert a tool to an Ollama tool."""
|
||||
description = None
|
||||
@ -93,7 +92,7 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
|
||||
schema = tool.tool_call_schema.schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
elif _is_pydantic_object(tool):
|
||||
elif is_basemodel_instance(tool):
|
||||
schema = tool.get_input_schema().schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
|
@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain_core.utils.pydantic import is_basemodel_instance
|
||||
|
||||
|
||||
class SyntheticDataGenerator(BaseModel):
|
||||
@ -63,8 +64,10 @@ class SyntheticDataGenerator(BaseModel):
|
||||
"""Prevents duplicates by adding previously generated examples to the few shot
|
||||
list."""
|
||||
if self.template and self.template.examples:
|
||||
if isinstance(example, BaseModel):
|
||||
formatted_example = self._format_dict_to_string(example.dict())
|
||||
if is_basemodel_instance(example):
|
||||
formatted_example = self._format_dict_to_string(
|
||||
cast(BaseModel, example).dict()
|
||||
)
|
||||
elif isinstance(example, dict):
|
||||
formatted_example = self._format_dict_to_string(example)
|
||||
else:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
@ -10,6 +10,7 @@ from langchain_core.output_parsers.openai_functions import (
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
||||
@ -45,7 +46,7 @@ def create_qa_with_structure_chain(
|
||||
|
||||
"""
|
||||
if output_parser == "pydantic":
|
||||
if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
|
||||
if not (isinstance(schema, type) and is_basemodel_subclass(schema)):
|
||||
raise ValueError(
|
||||
"Must provide a pydantic class for schema when output_parser is "
|
||||
"'pydantic'."
|
||||
@ -60,10 +61,10 @@ def create_qa_with_structure_chain(
|
||||
f"Got unexpected output_parser: {output_parser}. "
|
||||
f"Should be one of `pydantic` or `base`."
|
||||
)
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
schema_dict = schema.schema()
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
schema_dict = cast(dict, schema.schema())
|
||||
else:
|
||||
schema_dict = schema
|
||||
schema_dict = cast(dict, schema)
|
||||
function = {
|
||||
"name": schema_dict["title"],
|
||||
"description": schema_dict["description"],
|
||||
|
@ -24,6 +24,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
@deprecated(
|
||||
@ -465,7 +466,7 @@ def _get_openai_tool_output_parser(
|
||||
*,
|
||||
first_tool_only: bool = False,
|
||||
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
||||
if isinstance(tool, type) and issubclass(tool, BaseModel):
|
||||
if isinstance(tool, type) and is_basemodel_subclass(tool):
|
||||
output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
|
||||
PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only)
|
||||
)
|
||||
@ -493,7 +494,7 @@ def get_openai_output_parser(
|
||||
not a Pydantic class, then the output parser will automatically extract
|
||||
only the function arguments and not the function name.
|
||||
"""
|
||||
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
|
||||
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
|
||||
if len(functions) > 1:
|
||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||
convert_to_openai_function(fn)["name"]: fn for fn in functions
|
||||
@ -516,7 +517,7 @@ def _create_openai_json_runnable(
|
||||
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
|
||||
) -> Runnable:
|
||||
""""""
|
||||
if isinstance(output_schema, type) and issubclass(output_schema, BaseModel):
|
||||
if isinstance(output_schema, type) and is_basemodel_subclass(output_schema):
|
||||
output_parser = output_parser or PydanticOutputParser(
|
||||
pydantic_object=output_schema, # type: ignore
|
||||
)
|
||||
|
@ -50,7 +50,12 @@ from langchain_core.output_parsers import (
|
||||
)
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableMap,
|
||||
@ -63,6 +68,7 @@ from langchain_core.utils import (
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
from langchain_anthropic.output_parsers import extract_tool_calls
|
||||
|
||||
@ -994,7 +1000,7 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
tool_name = convert_to_anthropic_tool(schema)["name"]
|
||||
llm = self.bind_tools([schema], tool_choice=tool_name)
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
|
@ -69,7 +69,12 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
@ -81,6 +86,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -938,7 +944,7 @@ class ChatFireworks(BaseChatModel):
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
|
||||
|
@ -66,7 +66,12 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
@ -78,6 +83,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
class ChatGroq(BaseChatModel):
|
||||
@ -1053,7 +1059,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
class _FunctionCall(TypedDict):
|
||||
|
@ -388,7 +388,7 @@ def test_json_mode_structured_output() -> None:
|
||||
result = chat.invoke(
|
||||
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
|
||||
)
|
||||
assert type(result) == Joke
|
||||
assert type(result) is Joke
|
||||
assert len(result.setup) != 0
|
||||
assert len(result.punchline) != 0
|
||||
|
||||
|
@ -173,7 +173,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
|
||||
):
|
||||
res = llm.invoke("bar")
|
||||
assert res.content == "Bar Baz"
|
||||
assert type(res) == AIMessage
|
||||
assert type(res) is AIMessage
|
||||
assert completed
|
||||
|
||||
|
||||
@ -195,7 +195,7 @@ async def test_groq_ainvoke(mock_completion: dict) -> None:
|
||||
):
|
||||
res = await llm.ainvoke("bar")
|
||||
assert res.content == "Bar Baz"
|
||||
assert type(res) == AIMessage
|
||||
assert type(res) is AIMessage
|
||||
assert completed
|
||||
|
||||
|
||||
|
@ -63,11 +63,17 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -779,7 +785,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
|
@ -36,6 +36,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
@ -54,7 +55,7 @@ class _AllReturnType(TypedDict):
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
class AzureChatOpenAI(BaseChatOpenAI):
|
||||
|
@ -86,6 +86,7 @@ from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1765,7 +1766,7 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
||||
|
@ -1,9 +1,14 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Dict, List, Set, Tuple, Union, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
create_model,
|
||||
)
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
from langchain_core.utils.pydantic import is_basemodel_instance
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -160,8 +165,8 @@ def get_param_fields(endpoint_spec: dict) -> dict:
|
||||
def model_to_dict(
|
||||
item: Union[BaseModel, List, Dict[str, Any]],
|
||||
) -> Any:
|
||||
if isinstance(item, BaseModel):
|
||||
return item.dict()
|
||||
if is_basemodel_instance(item):
|
||||
return cast(BaseModel, item).dict()
|
||||
elif isinstance(item, dict):
|
||||
return {key: model_to_dict(value) for key, value in item.items()}
|
||||
elif isinstance(item, list):
|
||||
|
@ -1,20 +1,58 @@
|
||||
"""Unit tests for chat models."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Literal, Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
class Person(BaseModel):
|
||||
|
||||
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
|
||||
def generate_schema_pydantic_v1_from_2() -> Any:
|
||||
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||
if PYDANTIC_MAJOR_VERSION != 2:
|
||||
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
class PersonB(BaseModel):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
return PersonB
|
||||
|
||||
|
||||
def generate_schema_pydantic() -> Any:
|
||||
"""Works with either pydantic 1 or 2"""
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class PersonA(BaseModelProper):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = FieldProper(..., description="The name of the person.")
|
||||
age: int = FieldProper(..., description="The age of the person.")
|
||||
|
||||
return PersonA
|
||||
|
||||
|
||||
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
|
||||
|
||||
|
||||
@tool
|
||||
def my_adder_tool(a: int, b: int) -> int:
|
||||
"""Takes two integers, a and b, and returns their sum."""
|
||||
@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
tool_model = model.bind_tools(
|
||||
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
|
||||
)
|
||||
tools = [my_adder_tool, my_adder]
|
||||
|
||||
for pydantic_model in TEST_PYDANTIC_MODELS:
|
||||
tools.extend([pydantic_model, pydantic_model.schema()])
|
||||
|
||||
# Doing a mypy ignore here since some of the tools are from pydantic
|
||||
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
|
||||
# so type checking does not become annoying to users.
|
||||
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
||||
@pytest.mark.parametrize("schema", [Person, Person.schema()])
|
||||
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
|
||||
def test_with_structured_output(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
assert model.with_structured_output(schema) is not None
|
||||
|
||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
class ExpectedParams(BaseModel):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
|
@ -0,0 +1,14 @@
|
||||
"""Utilities for working with pydantic models."""
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
"""Get the major version of Pydantic."""
|
||||
try:
|
||||
import pydantic
|
||||
|
||||
return int(pydantic.__version__.split(".")[0])
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
|
||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
Loading…
Reference in New Issue
Block a user