core[minor], langchain[patch]: tools dependencies refactoring (#18759)

The `langchain.tools`
[namespace](https://api.python.langchain.com/en/latest/langchain_api_reference.html#module-langchain.tools)
can be completely eliminated by moving one class and 3 functions into
`core`. It makes sense since the class and functions are very core.
This commit is contained in:
Leonid Ganeline 2024-04-16 11:15:09 -07:00 committed by GitHub
parent 77eba10f47
commit 45d045b2c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 136 additions and 124 deletions

View File

@ -23,6 +23,7 @@ import inspect
import uuid import uuid
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from functools import partial
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
@ -32,9 +33,17 @@ from langchain_core.callbacks import (
BaseCallbackManager, BaseCallbackManager,
CallbackManager, CallbackManager,
CallbackManagerForToolRun, CallbackManagerForToolRun,
)
from langchain_core.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import ( from langchain_core.pydantic_v1 import (
BaseModel, BaseModel,
Extra, Extra,
@ -44,6 +53,7 @@ from langchain_core.pydantic_v1 import (
root_validator, root_validator,
validate_arguments, validate_arguments,
) )
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ( from langchain_core.runnables import (
Runnable, Runnable,
RunnableConfig, RunnableConfig,
@ -920,3 +930,111 @@ def tool(
return _partial return _partial
else: else:
raise ValueError("Too many arguments for tool decorator") raise ValueError("Too many arguments for tool decorator")
class RetrieverInput(BaseModel):
"""Input to the retriever."""
query: str = Field(description="query to look up in retriever")
def _get_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
async def _aget_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
[await aformat_document(doc, document_prompt) for doc in docs]
)
def create_retriever_tool(
retriever: BaseRetriever,
name: str,
description: str,
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
) -> Tool:
"""Create a tool to do retrieval of documents.
Args:
retriever: The retriever to use for the retrieval
name: The name for the tool. This will be passed to the language model,
so should be unique and somewhat descriptive.
description: The description for the tool. This will be passed to the language
model, so should be descriptive.
Returns:
Tool class to pass to an agent
"""
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
return Tool(
name=name,
description=description,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
)
ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
.. code-block:: markdown
search: This tool is used for search
calculator: This tool is used for math
"""
return "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
def render_text_description_and_args(tools: List[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text.
Output will be in the format of:
.. code-block:: markdown
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
"""
tool_strings = []
for tool in tools:
args_schema = str(tool.args)
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
return "\n".join(tool_strings)

View File

@ -4,10 +4,13 @@ Depending on the LLM you are using and the prompting strategy you are using,
you may want Tools to be rendered in a different way. you may want Tools to be rendered in a different way.
This module contains various ways to render tools. This module contains various ways to render tools.
""" """
from typing import Callable, List
# For backwards compatibility # For backwards compatibility
from langchain_core.tools import BaseTool from langchain_core.tools import (
ToolsRenderer,
render_text_description,
render_text_description_and_args,
)
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
format_tool_to_openai_function, format_tool_to_openai_function,
format_tool_to_openai_tool, format_tool_to_openai_tool,
@ -20,37 +23,3 @@ __all__ = [
"format_tool_to_openai_tool", "format_tool_to_openai_tool",
"format_tool_to_openai_function", "format_tool_to_openai_function",
] ]
ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
.. code-block:: markdown
search: This tool is used for search
calculator: This tool is used for math
"""
return "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
def render_text_description_and_args(tools: List[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text.
Output will be in the format of:
.. code-block:: markdown
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
"""
tool_strings = []
for tool in tools:
args_schema = str(tool.args)
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
return "\n".join(tool_strings)

View File

@ -1,90 +1,15 @@
from functools import partial from langchain_core.tools import (
from typing import Optional RetrieverInput,
ToolsRenderer,
from langchain_core.callbacks.manager import ( create_retriever_tool,
Callbacks, render_text_description,
) render_text_description_and_args,
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
aformat_document,
format_document,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain.tools import Tool
class RetrieverInput(BaseModel):
"""Input to the retriever."""
query: str = Field(description="query to look up in retriever")
def _get_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
) )
__all__ = [
async def _aget_relevant_documents( "RetrieverInput",
query: str, "ToolsRenderer",
retriever: BaseRetriever, "create_retriever_tool",
document_prompt: BasePromptTemplate, "render_text_description",
document_separator: str, "render_text_description_and_args",
callbacks: Callbacks = None, ]
) -> str:
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
return document_separator.join(
[await aformat_document(doc, document_prompt) for doc in docs]
)
def create_retriever_tool(
retriever: BaseRetriever,
name: str,
description: str,
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
) -> Tool:
"""Create a tool to do retrieval of documents.
Args:
retriever: The retriever to use for the retrieval
name: The name for the tool. This will be passed to the language model,
so should be unique and somewhat descriptive.
description: The description for the tool. This will be passed to the language
model, so should be descriptive.
Returns:
Tool class to pass to an agent
"""
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
return Tool(
name=name,
description=description,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
)