diff --git a/libs/langchain/langchain/retrievers/web_research.py b/libs/langchain/langchain/retrievers/web_research.py index a5ad41007f..008df92e0b 100644 --- a/libs/langchain/langchain/retrievers/web_research.py +++ b/libs/langchain/langchain/retrievers/web_research.py @@ -35,7 +35,7 @@ class SearchQueries(BaseModel): DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate( input_variables=["question"], template="""<> \n You are an assistant tasked with improving Google search - results. \n <> \n\n [INST] Generate FIVE Google search queries that + results. \n <> \n\n [INST] Generate THREE Google search queries that are similar to this question. The output should be a numbered list of questions and each should have a question mark at the end: \n\n {question} [/INST]""", ) @@ -43,7 +43,7 @@ DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate( DEFAULT_SEARCH_PROMPT = PromptTemplate( input_variables=["question"], template="""You are an assistant tasked with improving Google search - results. Generate FIVE Google search queries that are similar to + results. Generate THREE Google search queries that are similar to this question. The output should be a numbered list of questions and each should have a question mark at the end: {question}""", ) @@ -73,7 +73,6 @@ class WebResearchRetriever(BaseRetriever): ) llm_chain: LLMChain search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper") - max_splits_per_doc: int = Field(100, description="Maximum splits per document") num_search_results: int = Field(1, description="Number of pages per Google search") text_splitter: RecursiveCharacterTextSplitter = Field( RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50), @@ -90,10 +89,9 @@ class WebResearchRetriever(BaseRetriever): llm: BaseLLM, search: GoogleSearchAPIWrapper, prompt: Optional[BasePromptTemplate] = None, - max_splits_per_doc: int = 100, num_search_results: int = 1, text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter( - chunk_size=1500, chunk_overlap=50 + chunk_size=1500, chunk_overlap=150 ), ) -> "WebResearchRetriever": """Initialize from llm using default template. @@ -103,7 +101,6 @@ class WebResearchRetriever(BaseRetriever): llm: llm for search question generation search: GoogleSearchAPIWrapper prompt: prompt to generating search questions - max_splits_per_doc: Maximum splits per document to keep num_search_results: Number of pages per Google search text_splitter: Text splitter for splitting web pages into chunks @@ -131,14 +128,30 @@ class WebResearchRetriever(BaseRetriever): vectorstore=vectorstore, llm_chain=llm_chain, search=search, - max_splits_per_doc=max_splits_per_doc, num_search_results=num_search_results, text_splitter=text_splitter, ) + def clean_search_query(self, query: str) -> str: + # Some search tools (e.g., Google) will + # fail to return results if query has a + # leading digit: 1. "LangCh..." + # Check if the first character is a digit + if query[0].isdigit(): + # Find the position of the first quote + first_quote_pos = query.find('"') + if first_quote_pos != -1: + # Extract the part of the string after the quote + query = query[first_quote_pos + 1 :] + # Remove the trailing quote if present + if query.endswith('"'): + query = query[:-1] + return query.strip() + def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]: """Returns num_serch_results pages per Google search.""" - result = self.search.results(query, num_search_results) + query_clean = self.clean_search_query(query) + result = self.search.results(query_clean, num_search_results) return result def _get_relevant_documents(