bind_functions convenience method (#12518)

I always take 20-30 seconds to re-discover where the
`convert_to_openai_function` wrapper lives in our codebase. Chat
langchain [has no
clue](https://smith.langchain.com/public/3989d687-18c7-4108-958e-96e88803da86/r)
what to do either. There's the older `create_openai_fn_chain`, but we
haven't been recommending it in LCEL. The example we show in the
[cookbook](https://python.langchain.com/docs/expression_language/how_to/binding#attaching-openai-functions)
is really verbose.


General function calling should be as simple as possible to do, so this
seems a bit more ergonomic to me (feel free to disagree). Another option
would be to coerce the functions directly in the class's init (or when
calling invoke), if provided. I'm not 100% set against that. That
approach may be too easy but not simple. This PR feels like a decent
compromise between simple and easy.

```
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


# Closed set of labels an issue can be classified as. Mixing in ``str`` makes
# each member compare equal to its plain-string value (e.g. "improvement").
class Category(str, Enum):
    """The category of the issue."""

    # NOTE(review): the docstring above and the member values below presumably
    # end up in the schema sent to the model -- confirm before rewording them.
    bug = "bug"
    nit = "nit"
    improvement = "improvement"
    # Catch-all label; IssueClassification.other_description is meant to carry
    # the suggested category when this one is chosen.
    other = "other"


# Argument schema for the example: this model is handed to
# ``ChatOpenAI.bind_functions`` below.
class IssueClassification(BaseModel):
    """Classify an issue."""

    # NOTE(review): the docstring above presumably becomes the function
    # description seen by the model -- keep its wording stable.

    # Required: one of the Category labels.
    category: Category
    # Optional free-text label. The default is spelled out so the field is
    # optional regardless of how the installed pydantic version treats an
    # ``Optional[...]`` annotation that has no default.
    other_description: Optional[str] = Field(
        default=None,
        description="If classified as 'other', the suggested other category",
    )
    

from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI().bind_functions([IssueClassification])
llm.invoke("This PR adds a convenience wrapper to the bind argument")

# AIMessage(content='', additional_kwargs={'function_call': {'name': 'IssueClassification', 'arguments': '{\n  "category": "improvement"\n}'}})
```
pull/12634/head
William FH 11 months ago committed by GitHub
parent 3143324984
commit bfd719f9d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,6 +13,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
@ -29,8 +30,9 @@ from langchain.chat_models.base import (
_generate_from_stream,
)
from langchain.llms.base import create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import (
AIMessageChunk,
BaseMessage,
@ -41,11 +43,13 @@ from langchain.schema.messages import (
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import Runnable
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
if TYPE_CHECKING:
import tiktoken
logger = logging.getLogger(__name__)
@ -540,3 +544,45 @@ class ChatOpenAI(BaseChatModel):
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens
def bind_functions(
    self,
    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
    function_call: Optional[str] = None,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind functions (and other objects) to this chat model.

    Args:
        functions: A list of function definitions to bind to this chat model.
            Can be a dictionary, pydantic model, or callable. Pydantic
            models and callables will be automatically converted to
            their schema dictionary representation.
        function_call: Which function to require the model to call.
            Must be the name of the single provided function, or
            "auto" to automatically determine which function to call
            (if any), or "none" to disable function calling for the
            bound model. When omitted, no ``function_call`` argument
            is bound.
        kwargs: Any additional parameters to pass to the
            :class:`~langchain.runnable.Runnable` constructor.

    Returns:
        The runnable produced by ``bind`` with ``functions`` (and, when
        requested, ``function_call``) attached.

    Raises:
        ValueError: If ``function_call`` names a function but ``functions``
            does not contain exactly one entry, or that entry has a
            different name.
    """
    # Imported inside the method, as in the original implementation
    # (presumably to avoid a circular import -- confirm before hoisting).
    from langchain.chains.openai_functions.base import convert_to_openai_function

    formatted_functions = [convert_to_openai_function(fn) for fn in functions]
    if function_call is not None:
        function_call_: Union[str, Dict[str, str]]
        if (
            len(formatted_functions) == 1
            and formatted_functions[0]["name"] == function_call
        ):
            # Checked first so that every call that used to succeed keeps
            # forcing the named function, even if it is called "auto"/"none".
            function_call_ = {"name": function_call}
        elif function_call in ("auto", "none"):
            # Mode strings are forwarded verbatim instead of being treated
            # as a function name, which used to raise for "auto".
            function_call_ = function_call
        elif len(formatted_functions) != 1:
            raise ValueError(
                "When specifying `function_call`, you must provide exactly one "
                "function."
            )
        else:
            raise ValueError(
                f"Function call {function_call} was specified, but the only "
                f"provided function was {formatted_functions[0]['name']}."
            )
        kwargs = {**kwargs, "function_call": function_call_}
    return super().bind(
        functions=formatted_functions,
        **kwargs,
    )

@ -9,7 +9,9 @@ from langchain.chains.openai_functions import (
create_openai_fn_chain,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import (
ChatGeneration,
ChatResult,
@ -297,6 +299,46 @@ async def test_async_chat_openai_streaming_with_function() -> None:
assert all([chunk is not None for chunk in callback_handler._captured_chunks])
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_openai_bind_functions() -> None:
    """Test that ChatOpenAI.bind_functions forces a call to the bound function."""

    class Person(BaseModel):
        """Identifying information about a person."""

        name: str = Field(..., title="Name", description="The person's name")
        age: int = Field(..., title="Age", description="The person's age")
        fav_food: Optional[str] = Field(
            default=None, title="Fav Food", description="The person's favorite food"
        )

    # Naming the single bound function in `function_call` requires the model
    # to answer with a call to `Person`.
    chat = ChatOpenAI(
        max_tokens=30,
        n=1,
        streaming=True,
    ).bind_functions(functions=[Person], function_call="Person")

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Use the provided Person function"),
            ("user", "{input}"),
        ]
    )

    # args_only=True makes the parser yield just the parsed function arguments.
    chain = prompt | chat | JsonOutputFunctionsParser(args_only=True)

    # "{input}" is a string slot, so pass the text itself; a message object
    # would be interpolated via its string representation.
    response = await chain.abatch([{"input": "Sally is 13 years old"}])

    assert isinstance(response, list)
    assert len(response) == 1
    for generation in response:
        assert isinstance(generation, dict)
        assert "name" in generation
        assert "age" in generation
def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.

Loading…
Cancel
Save