Clean queries prior to search (#8309)

With some search tools, no results are returned if the query is
formatted as a numbered-list item.

E.g., if we pass:
```
'1. "LangChain vs LangSmith: How do they differ?"'
```

We see:
```
No good Google Search Result was found
```

Local testing w/ Streamlit:

![image](https://github.com/langchain-ai/langchain/assets/122662504/0a7e3dca-59e8-415e-8df6-bd9e4ea962ee)
This commit is contained in:
Lance Martin 2023-07-26 11:48:28 -07:00 committed by GitHub
parent 6b88fbd9bb
commit 77c0582243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,7 +35,7 @@ class SearchQueries(BaseModel):
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""<<SYS>> \n You are an assistant tasked with improving Google search
results. \n <</SYS>> \n\n [INST] Generate FIVE Google search queries that
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that
are similar to this question. The output should be a numbered list of questions
and each should have a question mark at the end: \n\n {question} [/INST]""",
)
@ -43,7 +43,7 @@ DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
DEFAULT_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an assistant tasked with improving Google search
results. Generate FIVE Google search queries that are similar to
results. Generate THREE Google search queries that are similar to
this question. The output should be a numbered list of questions and each
should have a question mark at the end: {question}""",
)
@ -73,7 +73,6 @@ class WebResearchRetriever(BaseRetriever):
)
llm_chain: LLMChain
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
max_splits_per_doc: int = Field(100, description="Maximum splits per document")
num_search_results: int = Field(1, description="Number of pages per Google search")
text_splitter: RecursiveCharacterTextSplitter = Field(
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
@ -90,10 +89,9 @@ class WebResearchRetriever(BaseRetriever):
llm: BaseLLM,
search: GoogleSearchAPIWrapper,
prompt: Optional[BasePromptTemplate] = None,
max_splits_per_doc: int = 100,
num_search_results: int = 1,
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=50
chunk_size=1500, chunk_overlap=150
),
) -> "WebResearchRetriever":
"""Initialize from llm using default template.
@ -103,7 +101,6 @@ class WebResearchRetriever(BaseRetriever):
llm: llm for search question generation
search: GoogleSearchAPIWrapper
prompt: prompt to generating search questions
max_splits_per_doc: Maximum splits per document to keep
num_search_results: Number of pages per Google search
text_splitter: Text splitter for splitting web pages into chunks
@ -131,14 +128,30 @@ class WebResearchRetriever(BaseRetriever):
vectorstore=vectorstore,
llm_chain=llm_chain,
search=search,
max_splits_per_doc=max_splits_per_doc,
num_search_results=num_search_results,
text_splitter=text_splitter,
)
def clean_search_query(self, query: str) -> str:
    """Strip list numbering and wrapping quotes from a search query.

    Some search tools (e.g., Google) fail to return results if the query
    has a leading digit, as in ``1. "LangCh..."``. When the query starts
    with a digit and contains a double quote, everything up to and
    including the first quote is dropped, along with one trailing quote
    if present. Queries that do not start with a digit, or that contain
    no double quote, are only whitespace-stripped.

    Args:
        query: Raw search query, possibly a numbered-list item.

    Returns:
        The cleaned query with surrounding whitespace removed. An empty
        query yields an empty string.
    """
    # Slice instead of indexing so an empty query does not raise IndexError.
    if query[:1].isdigit():
        # Find the position of the first quote
        first_quote_pos = query.find('"')
        if first_quote_pos != -1:
            # Keep only the part of the string after the quote
            query = query[first_quote_pos + 1 :]
            # Remove the trailing quote if present
            if query.endswith('"'):
                query = query[:-1]
    return query.strip()
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
"""Returns num_search_results pages per Google search."""
result = self.search.results(query, num_search_results)
# Clean the query first: some search tools return no results when the
# query has a leading digit (e.g. a numbered-list item).
query_clean = self.clean_search_query(query)
result = self.search.results(query_clean, num_search_results)
return result
def _get_relevant_documents(