mirror of https://github.com/hwchase17/langchain
core[minor], langchain[patch]: `tools` dependencies refactoring (#18759)
The `langchain.tools` [namespace](https://api.python.langchain.com/en/latest/langchain_api_reference.html#module-langchain.tools) can be completely eliminated by moving one class and 3 functions into `core`. It makes sense since the class and functions are very core.pull/19956/head^2
parent
77eba10f47
commit
45d045b2c5
@ -1,90 +1,15 @@
|
|||||||
from functools import partial
|
from langchain_core.tools import (
|
||||||
from typing import Optional
|
RetrieverInput,
|
||||||
|
ToolsRenderer,
|
||||||
from langchain_core.callbacks.manager import (
|
create_retriever_tool,
|
||||||
Callbacks,
|
render_text_description,
|
||||||
)
|
render_text_description_and_args,
|
||||||
from langchain_core.prompts import (
|
|
||||||
BasePromptTemplate,
|
|
||||||
PromptTemplate,
|
|
||||||
aformat_document,
|
|
||||||
format_document,
|
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
||||||
from langchain_core.retrievers import BaseRetriever
|
|
||||||
|
|
||||||
from langchain.tools import Tool
|
|
||||||
|
|
||||||
|
|
||||||
class RetrieverInput(BaseModel):
|
|
||||||
"""Input to the retriever."""
|
|
||||||
|
|
||||||
query: str = Field(description="query to look up in retriever")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
|
||||||
query: str,
|
|
||||||
retriever: BaseRetriever,
|
|
||||||
document_prompt: BasePromptTemplate,
|
|
||||||
document_separator: str,
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
) -> str:
|
|
||||||
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
|
|
||||||
return document_separator.join(
|
|
||||||
format_document(doc, document_prompt) for doc in docs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
query: str,
|
|
||||||
retriever: BaseRetriever,
|
|
||||||
document_prompt: BasePromptTemplate,
|
|
||||||
document_separator: str,
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
) -> str:
|
|
||||||
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
|
||||||
return document_separator.join(
|
|
||||||
[await aformat_document(doc, document_prompt) for doc in docs]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_retriever_tool(
|
|
||||||
retriever: BaseRetriever,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
*,
|
|
||||||
document_prompt: Optional[BasePromptTemplate] = None,
|
|
||||||
document_separator: str = "\n\n",
|
|
||||||
) -> Tool:
|
|
||||||
"""Create a tool to do retrieval of documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retriever: The retriever to use for the retrieval
|
|
||||||
name: The name for the tool. This will be passed to the language model,
|
|
||||||
so should be unique and somewhat descriptive.
|
|
||||||
description: The description for the tool. This will be passed to the language
|
|
||||||
model, so should be descriptive.
|
|
||||||
|
|
||||||
Returns:
|
__all__ = [
|
||||||
Tool class to pass to an agent
|
"RetrieverInput",
|
||||||
"""
|
"ToolsRenderer",
|
||||||
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
"create_retriever_tool",
|
||||||
func = partial(
|
"render_text_description",
|
||||||
_get_relevant_documents,
|
"render_text_description_and_args",
|
||||||
retriever=retriever,
|
]
|
||||||
document_prompt=document_prompt,
|
|
||||||
document_separator=document_separator,
|
|
||||||
)
|
|
||||||
afunc = partial(
|
|
||||||
_aget_relevant_documents,
|
|
||||||
retriever=retriever,
|
|
||||||
document_prompt=document_prompt,
|
|
||||||
document_separator=document_separator,
|
|
||||||
)
|
|
||||||
return Tool(
|
|
||||||
name=name,
|
|
||||||
description=description,
|
|
||||||
func=func,
|
|
||||||
coroutine=afunc,
|
|
||||||
args_schema=RetrieverInput,
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue