core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)

After this PR chat models will correctly handle pydantic 2 with
bind_tools and with_structured_output.


```python
import pydantic
print(pydantic.__version__)
```
2.8.2

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}]
<class '__main__.Add'>
```


```python
from langchain_openai import ChatOpenAI
from pydantic.v1 import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```python
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}]
<class '__main__.Add'>
```

Addresses issues: https://github.com/langchain-ai/langchain/issues/22782

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Bagatur 2024-07-22 13:07:39 -07:00 committed by GitHub
parent 199e64d372
commit 236e957abb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 185 additions and 59 deletions

View File

@ -24,6 +24,7 @@ from langchain_core.output_parsers import (
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_community.output_parsers.ernie_functions import (
JsonOutputFunctionsParser,
@ -94,7 +95,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
for arg, arg_type in annotations.items():
if arg == "return":
continue
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
if isinstance(arg_type, type) and is_basemodel_subclass(arg_type):
# Mypy error:
# "type" has no attribute "schema"
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
@ -156,7 +157,7 @@ def convert_to_ernie_function(
"""
if isinstance(function, dict):
return function
elif isinstance(function, type) and issubclass(function, BaseModel):
elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_ernie_function(function))
elif callable(function):
return convert_python_function_to_ernie_function(function)
@ -185,7 +186,7 @@ def get_ernie_output_parser(
only the function arguments and not the function name.
"""
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
name: fn for name, fn in zip(function_names, functions)

View File

@ -40,11 +40,17 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
logger = logging.getLogger(__name__)
@ -769,7 +775,7 @@ class QianfanChatEndpoint(BaseChatModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
llm = self.bind_tools([schema])
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(

View File

@ -57,6 +57,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_community.utilities.requests import Requests
@ -443,7 +444,7 @@ class ChatEdenAI(BaseChatModel):
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
llm = self.bind_tools([schema], tool_choice="required")
if isinstance(schema, type) and issubclass(schema, BaseModel):
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)

View File

@ -46,10 +46,15 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
class ChatLlamaCpp(BaseChatModel):
@ -525,7 +530,7 @@ class ChatLlamaCpp(BaseChatModel):
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "

View File

@ -53,11 +53,16 @@ from langchain_core.outputs import (
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
@ -865,7 +870,7 @@ class ChatTongyi(BaseChatModel):
"""
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
llm = self.bind_tools([schema])
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(

View File

@ -55,11 +55,16 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.output_parsers.base import OutputParserLike
@ -1162,7 +1167,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"with_structured_output is not implemented for this model."
)
llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and issubclass(schema, BaseModel):
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)

View File

@ -82,6 +82,7 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@ -300,7 +301,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
return root_type
return create_model(
@ -332,7 +333,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
return root_type
return create_model(

View File

@ -22,6 +22,7 @@ from typing import (
from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType
@ -229,7 +230,7 @@ def node_data_json(
"name": node_data_str(node.id, node.data),
},
}
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
json = (
{
"type": "schema",

View File

@ -28,6 +28,7 @@ from langchain_core.messages import (
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.tools import BaseTool
@ -100,7 +101,11 @@ def convert_pydantic_to_openai_function(
Returns:
The function description.
"""
schema = dereference_refs(model.schema())
if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2
else:
schema = model.schema() # Pydantic 1
schema = dereference_refs(schema)
schema.pop("definitions", None)
title = schema.pop("title", "")
default_description = schema.pop("description", "")
@ -272,7 +277,7 @@ def convert_to_openai_function(
"description": function.pop("description"),
"parameters": function,
}
elif isinstance(function, type) and issubclass(function, BaseModel):
elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_openai_function(function))
elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function))

View File

@ -8,12 +8,13 @@ from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
schema: Union[Dict, Type[BaseModel]], _: Any
) -> Union[BaseModel, Dict]:
if isclass(schema) and issubclass(schema, BaseModel):
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=42)
else:
params = cast(Dict, schema)["parameters"]

View File

@ -34,11 +34,14 @@ from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import (
BaseModel,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables.base import RunnableMap
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
@ -75,14 +78,10 @@ _DictOrPydantic = Union[Dict, _BM]
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and (
issubclass(obj, BaseModel) or BaseModel in obj.__bases__
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
)
def _is_pydantic_object(obj: Any) -> bool:
return isinstance(obj, BaseModel)
def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
@ -93,7 +92,7 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
schema = tool.tool_call_schema.schema()
name = tool.get_name()
description = tool.description
elif _is_pydantic_object(tool):
elif is_basemodel_instance(tool):
schema = tool.get_input_schema().schema()
name = tool.get_name()
description = tool.description

View File

@ -1,11 +1,12 @@
import asyncio
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.utils.pydantic import is_basemodel_instance
class SyntheticDataGenerator(BaseModel):
@ -63,8 +64,10 @@ class SyntheticDataGenerator(BaseModel):
"""Prevents duplicates by adding previously generated examples to the few shot
list."""
if self.template and self.template.examples:
if isinstance(example, BaseModel):
formatted_example = self._format_dict_to_string(example.dict())
if is_basemodel_instance(example):
formatted_example = self._format_dict_to_string(
cast(BaseModel, example).dict()
)
elif isinstance(example, dict):
formatted_example = self._format_dict_to_string(example)
else:

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union
from typing import Any, List, Optional, Type, Union, cast
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
@ -10,6 +10,7 @@ from langchain_core.output_parsers.openai_functions import (
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
@ -45,7 +46,7 @@ def create_qa_with_structure_chain(
"""
if output_parser == "pydantic":
if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
if not (isinstance(schema, type) and is_basemodel_subclass(schema)):
raise ValueError(
"Must provide a pydantic class for schema when output_parser is "
"'pydantic'."
@ -60,10 +61,10 @@ def create_qa_with_structure_chain(
f"Got unexpected output_parser: {output_parser}. "
f"Should be one of `pydantic` or `base`."
)
if isinstance(schema, type) and issubclass(schema, BaseModel):
schema_dict = schema.schema()
if isinstance(schema, type) and is_basemodel_subclass(schema):
schema_dict = cast(dict, schema.schema())
else:
schema_dict = schema
schema_dict = cast(dict, schema)
function = {
"name": schema_dict["title"],
"description": schema_dict["description"],

View File

@ -24,6 +24,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
@deprecated(
@ -465,7 +466,7 @@ def _get_openai_tool_output_parser(
*,
first_tool_only: bool = False,
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
if isinstance(tool, type) and issubclass(tool, BaseModel):
if isinstance(tool, type) and is_basemodel_subclass(tool):
output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only)
)
@ -493,7 +494,7 @@ def get_openai_output_parser(
not a Pydantic class, then the output parser will automatically extract
only the function arguments and not the function name.
"""
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
convert_to_openai_function(fn)["name"]: fn for fn in functions
@ -516,7 +517,7 @@ def _create_openai_json_runnable(
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
) -> Runnable:
""""""
if isinstance(output_schema, type) and issubclass(output_schema, BaseModel):
if isinstance(output_schema, type) and is_basemodel_subclass(output_schema):
output_parser = output_parser or PydanticOutputParser(
pydantic_object=output_schema, # type: ignore
)

View File

@ -50,7 +50,12 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import (
Runnable,
RunnableMap,
@ -63,6 +68,7 @@ from langchain_core.utils import (
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_anthropic.output_parsers import extract_tool_calls
@ -994,7 +1000,7 @@ class ChatAnthropic(BaseChatModel):
tool_name = convert_to_anthropic_tool(schema)["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if isinstance(schema, type) and issubclass(schema, BaseModel):
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)

View File

@ -69,7 +69,12 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@ -81,6 +86,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@ -938,7 +944,7 @@ class ChatFireworks(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
return isinstance(obj, type) and is_basemodel_subclass(obj)
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:

View File

@ -66,7 +66,12 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
@ -78,6 +83,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
class ChatGroq(BaseChatModel):
@ -1053,7 +1059,7 @@ class ChatGroq(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
return isinstance(obj, type) and is_basemodel_subclass(obj)
class _FunctionCall(TypedDict):

View File

@ -388,7 +388,7 @@ def test_json_mode_structured_output() -> None:
result = chat.invoke(
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
)
assert type(result) == Joke
assert type(result) is Joke
assert len(result.setup) != 0
assert len(result.punchline) != 0

View File

@ -173,7 +173,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
assert type(res) == AIMessage
assert type(res) is AIMessage
assert completed
@ -195,7 +195,7 @@ async def test_groq_ainvoke(mock_completion: dict) -> None:
):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
assert type(res) == AIMessage
assert type(res) is AIMessage
assert completed

View File

@ -63,11 +63,17 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
logger = logging.getLogger(__name__)
@ -779,7 +785,7 @@ class ChatMistralAI(BaseChatModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(

View File

@ -36,6 +36,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai.chat_models.base import BaseChatOpenAI
@ -54,7 +55,7 @@ class _AllReturnType(TypedDict):
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
return isinstance(obj, type) and is_basemodel_subclass(obj)
class AzureChatOpenAI(BaseChatOpenAI):

View File

@ -86,6 +86,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@ -1765,7 +1766,7 @@ class ChatOpenAI(BaseChatOpenAI):
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
return isinstance(obj, type) and is_basemodel_subclass(obj)
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:

View File

@ -1,9 +1,14 @@
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Dict, List, Set, Tuple, Union, cast
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
create_model,
)
from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_instance
@dataclass(frozen=True)
@ -160,8 +165,8 @@ def get_param_fields(endpoint_spec: dict) -> dict:
def model_to_dict(
item: Union[BaseModel, List, Dict[str, Any]],
) -> Any:
if isinstance(item, BaseModel):
return item.dict()
if is_basemodel_instance(item):
return cast(BaseModel, item).dict()
elif isinstance(item, dict):
return {key: model_to_dict(value) for key, value in item.items()}
elif isinstance(item, list):

View File

@ -1,20 +1,58 @@
"""Unit tests for chat models."""
from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel):
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class PersonA(BaseModelProper):
"""Record attributes of a person."""
name: str = FieldProper(..., description="The name of the person.")
age: int = FieldProper(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
if not self.has_tool_calling:
return
tool_model = model.bind_tools(
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
)
tools = [my_adder_tool, my_adder]
for pydantic_model in TEST_PYDANTIC_MODELS:
tools.extend([pydantic_model, pydantic_model.schema()])
# Doing a mypy ignore here since some of the tools are from pydantic
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
# so type checking does not become annoying to users.
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", [Person, Person.schema()])
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output(
self,
model: BaseChatModel,
@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str

View File

@ -0,0 +1,14 @@
"""Utilities for working with pydantic models."""
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()