mirror of
synced 2024-11-06 03:20:49 +00:00
0.2rc migrations - [x] Move memory - [x] Move remaining retrievers - [x] graph_qa chains - [x] some dependency from evaluation code potentially on math utils - [x] Move openapi chain from `langchain.chains.api.openapi` to `langchain_community.chains.openapi` - [x] Migrate `langchain.chains.ernie_functions` to `langchain_community.chains.ernie_functions` - [x] migrate `langchain/chains/llm_requests.py` to `langchain_community.chains.llm_requests` - [x] Moving `langchain_community.cross_encoders.base:BaseCrossEncoder` -> `langchain_community.retrievers.document_compressors.cross_encoder:BaseCrossEncoder` (namespace not ideal, but it needs to be moved to `langchain` to avoid circular deps) - [x] unit tests langchain -- add pytest.mark.community to some unit tests that will stay in langchain - [x] unit tests community -- move unit tests that depend on community to community - [x] mv integration tests that depend on community to community - [x] mypy checks Other todo - [x] Make deprecation warnings not noisy (need to use warn deprecated and check that things are implemented properly) - [x] Update deprecation messages with timeline for code removal (likely we actually won't be removing things until 0.4 release) -- will give people more time to transition their code. - [ ] Add information to deprecation warning to show users how to migrate their code base using langchain-cli - [ ] Remove any unnecessary requirements in langchain (e.g., is SQLAlchemy required?) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
224 lines
8.0 KiB
224 lines
8.0 KiB
import logging
import re
from typing import List, Optional
from langchain.chains import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.document_transformers import Html2TextTransformer
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import GoogleSearchAPIWrapper
logger = logging.getLogger(__name__)
class SearchQueries(BaseModel):
"""Search queries to research for the user's goal."""
queries: List[str] = Field(
..., description="List of search queries to look up on Google"
template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n {question} [/INST]""",
template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
class QuestionListOutputParser(BaseOutputParser[List[str]]):
    """Parser that pulls the numbered entries out of a block of text."""

    def parse(self, text: str) -> List[str]:
        """Return every numbered item found in ``text``, in order.

        An item begins at a run of digits followed by a period and extends
        (lazily) up to the next newline, which is included in the item, or
        to the end of the text. The leading number is kept. Text without
        any such item yields an empty list.
        """
        numbered_item = re.compile(r"\d+\..*?(?:\n|$)")
        # The pattern has no capturing groups, so each full match is one item.
        return [match.group(0) for match in numbered_item.finditer(text)]
class WebResearchRetriever(BaseRetriever):
"""`Google Search API` retriever."""
# Inputs
vectorstore: VectorStore = Field(
..., description="Vector store for storing web pages"
llm_chain: LLMChain
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
num_search_results: int = Field(1, description="Number of pages per Google search")
text_splitter: TextSplitter = Field(
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
description="Text splitter for splitting web pages into chunks",
url_database: List[str] = Field(
default_factory=list, description="List of processed URLs"
def from_llm(
vectorstore: VectorStore,
llm: BaseLLM,
search: GoogleSearchAPIWrapper,
prompt: Optional[BasePromptTemplate] = None,
num_search_results: int = 1,
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=150
) -> "WebResearchRetriever":
"""Initialize from llm using default template.
vectorstore: Vector store for storing web pages
llm: llm for search question generation
search: GoogleSearchAPIWrapper
prompt: prompt to generating search questions
num_search_results: Number of pages per Google search
text_splitter: Text splitter for splitting web pages into chunks
if not prompt:
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
(lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)
# Use chat model prompt
llm_chain = LLMChain(
return cls(
def clean_search_query(self, query: str) -> str:
    """Normalize a generated search query before it is sent to the search tool.

    Some search tools (e.g., Google) fail to return results when the query
    has a leading digit, as in ``1. "LangCh..."``. When the query starts with
    a digit and contains a double quote, everything up to and including the
    first quote is dropped, and a single trailing quote is removed if present.

    Args:
        query: Raw query text; may be empty or carry surrounding whitespace
            (for example a trailing newline left over from line-based parsing).

    Returns:
        The cleaned query with surrounding whitespace removed. Queries that
        are empty, do not start with a digit, or contain no double quote are
        returned stripped but otherwise unchanged.
    """
    # Strip first: trailing whitespace would otherwise hide the closing quote
    # from the endswith() check, and leading whitespace would hide the digit.
    cleaned = query.strip()

    # Guard against empty input before indexing the first character.
    if not cleaned or not cleaned[0].isdigit():
        return cleaned

    first_quote_pos = cleaned.find('"')
    if first_quote_pos == -1:
        # Numbered but unquoted: nothing to unwrap.
        return cleaned

    # Keep only the text after the opening quote.
    cleaned = cleaned[first_quote_pos + 1 :]
    # Remove the trailing quote if present.
    if cleaned.endswith('"'):
        cleaned = cleaned[:-1]
    return cleaned.strip()
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
    """Clean ``query`` and run it through the configured search wrapper.

    The query is first passed through ``self.clean_search_query``; the
    cleaned text and ``num_search_results`` are then handed to
    ``self.search.results`` and its return value is passed back unchanged.
    """
    return self.search.results(self.clean_search_query(query), num_search_results)
def _get_relevant_documents(
query: str,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Search Google for documents related to the query input.
query: user query
Relevant documents from all various urls.
# Get search questions
logger.info("Generating questions for Google Search ...")
result = self.llm_chain({"question": query})
logger.info(f"Questions for Google Search (raw): {result}")
questions = result["text"]
logger.info(f"Questions for Google Search: {questions}")
# Get urls
logger.info("Searching for relevant urls...")
urls_to_look = []
for query in questions:
# Google search
search_results = self.search_tool(query, self.num_search_results)
logger.info("Searching for relevant urls...")
logger.info(f"Search results: {search_results}")
for res in search_results:
if res.get("link", None):
# Relevant urls
urls = set(urls_to_look)
# Check for any new urls that we have not processed
new_urls = list(urls.difference(self.url_database))
logger.info(f"New URLs to load: {new_urls}")
# Load, split, and add new urls to vectorstore
if new_urls:
loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True)
html2text = Html2TextTransformer()
logger.info("Indexing new urls...")
docs = loader.load()
docs = list(html2text.transform_documents(docs))
docs = self.text_splitter.split_documents(docs)
# Search for relevant splits
# TODO: make this async
logger.info("Grabbing most relevant splits from urls...")
docs = []
for query in questions:
# Get unique docs
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs
unique_documents = list(unique_documents_dict.values())
return unique_documents
# Not implemented: awaiting this coroutine always raises NotImplementedError.
# NOTE(review): the parameter list as shown here has no `self` and no
# keyword-only marker -- confirm against the full class definition.
async def _aget_relevant_documents(
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError