diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index a09dfd3e9c..d2993c13ba 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -13,6 +13,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, Type, Union, @@ -29,8 +30,9 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.base import create_base_retry_decorator -from langchain.pydantic_v1 import Field, root_validator +from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.schema import ChatGeneration, ChatResult +from langchain.schema.language_model import LanguageModelInput from langchain.schema.messages import ( AIMessageChunk, BaseMessage, @@ -41,11 +43,13 @@ from langchain.schema.messages import ( SystemMessageChunk, ) from langchain.schema.output import ChatGenerationChunk +from langchain.schema.runnable import Runnable from langchain.utils import get_from_dict_or_env, get_pydantic_field_names if TYPE_CHECKING: import tiktoken + logger = logging.getLogger(__name__) @@ -540,3 +544,45 @@ class ChatOpenAI(BaseChatModel): # every reply is primed with assistant num_tokens += 3 return num_tokens + + def bind_functions( + self, + functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], + function_call: Optional[str] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind functions (and other objects) to this chat model. + + Args: + functions: A list of function definitions to bind to this chat model. + Can be a dictionary, pydantic model, or callable. Pydantic + models and callables will be automatically converted to + their schema dictionary representation. + function_call: Which function to require the model to call. + Must be the name of the single provided function or + "auto" to automatically determine which function to call + (if any). 
+ kwargs: Any additional parameters to pass to the + :class:`~langchain.schema.runnable.Runnable` constructor. + """ + from langchain.chains.openai_functions.base import convert_to_openai_function + + formatted_functions = [convert_to_openai_function(fn) for fn in functions] + function_call_ = None + if function_call is not None: + if len(formatted_functions) != 1: + raise ValueError( + "When specifying `function_call`, you must provide exactly one " + "function." + ) + if formatted_functions[0]["name"] != function_call: + raise ValueError( + f"Function call {function_call} was specified, but the only " + f"provided function was {formatted_functions[0]['name']}." + ) + function_call_ = {"name": function_call} + kwargs = {**kwargs, "function_call": function_call_} + return super().bind( + functions=formatted_functions, + **kwargs, + ) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py index 5c8b0e43e6..e1da41c384 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py @@ -9,7 +9,9 @@ from langchain.chains.openai_functions import ( create_openai_fn_chain, ) from langchain.chat_models.openai import ChatOpenAI +from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain.pydantic_v1 import BaseModel, Field from langchain.schema import ( ChatGeneration, ChatResult, @@ -297,6 +299,46 @@ async def test_async_chat_openai_streaming_with_function() -> None: assert all([chunk is not None for chunk in callback_handler._captured_chunks]) +@pytest.mark.scheduled +@pytest.mark.asyncio +async def test_async_chat_openai_bind_functions() -> None: + """Test ChatOpenAI wrapper with bind_functions.""" + + class Person(BaseModel): + """Identifying information about a person.""" + + name: 
str = Field(..., title="Name", description="The person's name") + age: int = Field(..., title="Age", description="The person's age") + fav_food: Optional[str] = Field( + default=None, title="Fav Food", description="The person's favorite food" + ) + + chat = ChatOpenAI( + max_tokens=30, + n=1, + streaming=True, + ).bind_functions(functions=[Person], function_call="Person") + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "Use the provided Person function"), + ("user", "{input}"), + ] + ) + + chain = prompt | chat | JsonOutputFunctionsParser(args_only=True) + + message = HumanMessage(content="Sally is 13 years old") + response = await chain.abatch([{"input": message}]) + + assert isinstance(response, list) + assert len(response) == 1 + for generation in response: + assert isinstance(generation, dict) + assert "name" in generation + assert "age" in generation + + def test_chat_openai_extra_kwargs() -> None: """Test extra kwargs to chat openai.""" # Check that foo is saved in extra_kwargs.