From 4b7a85887ea0718fbe16ad462e05aa7710960e86 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 13 Nov 2023 20:54:03 -0800 Subject: [PATCH] arxiv retrieval agent improvement (#13329) --- .../retrieval-agent/retrieval_agent/chain.py | 55 +++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/templates/retrieval-agent/retrieval_agent/chain.py b/templates/retrieval-agent/retrieval_agent/chain.py index fb3ab979c8..4c548fb064 100644 --- a/templates/retrieval-agent/retrieval_agent/chain.py +++ b/templates/retrieval-agent/retrieval_agent/chain.py @@ -4,21 +4,66 @@ from typing import List, Tuple from langchain.agents import AgentExecutor from langchain.agents.format_scratchpad import format_to_openai_function_messages from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser +from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chat_models import AzureChatOpenAI from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain.pydantic_v1 import BaseModel, Field +from langchain.schema import BaseRetriever, Document from langchain.schema.messages import AIMessage, HumanMessage -from langchain.tools import ArxivQueryRun from langchain.tools.render import format_tool_to_openai_function -from langchain.utilities import ArxivAPIWrapper +from langchain.tools.retriever import create_retriever_tool +from langchain.utilities.arxiv import ArxivAPIWrapper -class ArxivInput(BaseModel): - query: str = Field(description="search query to look up") +class ArxivRetriever(BaseRetriever, ArxivAPIWrapper): + """`Arxiv` retriever. + It wraps load() to get_relevant_documents(). + It uses all ArxivAPIWrapper arguments without any change. + """ + + get_full_documents: bool = False + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + try: + if self.is_arxiv_identifier(query): + results = self.arxiv_search( + id_list=query.split(), + max_results=self.top_k_results, + ).results() + else: + results = self.arxiv_search( # type: ignore + query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results + ).results() + except self.arxiv_exceptions as ex: + return [Document(page_content=f"Arxiv exception: {ex}")] + docs = [ + Document( + page_content=result.summary, + metadata={ + "Published": result.updated.date(), + "Title": result.title, + "Authors": ", ".join(a.name for a in result.authors), + }, + ) + for result in results + ] + return docs + + +description = ( + "A wrapper around Arxiv.org " + "Useful for when you need to answer questions about Physics, Mathematics, " + "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, " + "Electrical Engineering, and Economics " + "from scientific articles on arxiv.org. " + "Input should be a search query." +) # Create the tool -arxiv_tool = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(), args_schema=ArxivInput) +arxiv_tool = create_retriever_tool(ArxivRetriever(), "arxiv", description) tools = [arxiv_tool] llm = AzureChatOpenAI( temperature=0,