mirror of https://github.com/hwchase17/langchain
core[minor], langchain[patch]: `tools` dependencies refactoring (#18759)
The `langchain.tools` [namespace](https://api.python.langchain.com/en/latest/langchain_api_reference.html#module-langchain.tools) can be completely eliminated by moving one class and 3 functions into `core`. It makes sense since the class and functions are very core.pull/19956/head^2
parent
77eba10f47
commit
45d045b2c5
@ -1,90 +1,15 @@
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
aformat_document,
|
||||
format_document,
|
||||
from langchain_core.tools import (
|
||||
RetrieverInput,
|
||||
ToolsRenderer,
|
||||
create_retriever_tool,
|
||||
render_text_description,
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain.tools import Tool
|
||||
|
||||
|
||||
class RetrieverInput(BaseModel):
|
||||
"""Input to the retriever."""
|
||||
|
||||
query: str = Field(description="query to look up in retriever")
|
||||
|
||||
|
||||
def _get_relevant_documents(
|
||||
query: str,
|
||||
retriever: BaseRetriever,
|
||||
document_prompt: BasePromptTemplate,
|
||||
document_separator: str,
|
||||
callbacks: Callbacks = None,
|
||||
) -> str:
|
||||
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
|
||||
return document_separator.join(
|
||||
format_document(doc, document_prompt) for doc in docs
|
||||
)
|
||||
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
query: str,
|
||||
retriever: BaseRetriever,
|
||||
document_prompt: BasePromptTemplate,
|
||||
document_separator: str,
|
||||
callbacks: Callbacks = None,
|
||||
) -> str:
|
||||
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
||||
return document_separator.join(
|
||||
[await aformat_document(doc, document_prompt) for doc in docs]
|
||||
)
|
||||
|
||||
|
||||
def create_retriever_tool(
|
||||
retriever: BaseRetriever,
|
||||
name: str,
|
||||
description: str,
|
||||
*,
|
||||
document_prompt: Optional[BasePromptTemplate] = None,
|
||||
document_separator: str = "\n\n",
|
||||
) -> Tool:
|
||||
"""Create a tool to do retrieval of documents.
|
||||
|
||||
Args:
|
||||
retriever: The retriever to use for the retrieval
|
||||
name: The name for the tool. This will be passed to the language model,
|
||||
so should be unique and somewhat descriptive.
|
||||
description: The description for the tool. This will be passed to the language
|
||||
model, so should be descriptive.
|
||||
|
||||
Returns:
|
||||
Tool class to pass to an agent
|
||||
"""
|
||||
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
||||
func = partial(
|
||||
_get_relevant_documents,
|
||||
retriever=retriever,
|
||||
document_prompt=document_prompt,
|
||||
document_separator=document_separator,
|
||||
)
|
||||
afunc = partial(
|
||||
_aget_relevant_documents,
|
||||
retriever=retriever,
|
||||
document_prompt=document_prompt,
|
||||
document_separator=document_separator,
|
||||
)
|
||||
return Tool(
|
||||
name=name,
|
||||
description=description,
|
||||
func=func,
|
||||
coroutine=afunc,
|
||||
args_schema=RetrieverInput,
|
||||
)
|
||||
__all__ = [
|
||||
"RetrieverInput",
|
||||
"ToolsRenderer",
|
||||
"create_retriever_tool",
|
||||
"render_text_description",
|
||||
"render_text_description_and_args",
|
||||
]
|
||||
|
Loading…
Reference in New Issue