Harrison/retrieval code (#1916)

Harrison Chase authored 1 year ago, committed by GitHub
parent eb80d6e0e4
commit fab7994b74

@@ -1,9 +1,12 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain
from langchain.chains.chat_vector_db.base import ChatVectorDBChain
from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain.chains.constitutional_ai.base import ConstitutionalChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversational_retrieval.base import (
ChatVectorDBChain,
ConversationalRetrievalChain,
)
from langchain.chains.graph_qa.base import GraphQAChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
@@ -18,14 +21,15 @@ from langchain.chains.moderation import OpenAIModerationChain
from langchain.chains.pal.base import PALChain
from langchain.chains.qa_generation.base import QAGenerationChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.chains.sql_database.base import (
SQLDatabaseChain,
SQLDatabaseSequentialChain,
)
from langchain.chains.transform import TransformChain
from langchain.chains.vector_db_qa.base import VectorDBQA
__all__ = [
"ConversationChain",
@@ -54,4 +58,7 @@ __all__ = [
"GraphQAChain",
"ConstitutionalChain",
"QAGenerationChain",
"RetrievalQA",
"RetrievalQAWithSourcesChain",
"ConversationalRetrievalChain",
]
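
The package root now exposes the retriever-agnostic chains alongside the deprecated vector-store ones. A quick import check, as a minimal sketch assuming a langchain install at this commit:

```python
# The new retriever-based chains and the deprecated vector-store chains
# are both importable from the package root after this commit.
from langchain.chains import (
    ChatVectorDBChain,             # deprecated, kept for compatibility
    ConversationalRetrievalChain,  # new
    RetrievalQA,                   # new
    RetrievalQAWithSourcesChain,   # new
    VectorDBQA,                    # deprecated, kept for compatibility
)
```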

@@ -1,18 +1,19 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from pydantic import BaseModel, Extra, Field
from langchain.chains.base import Chain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
@@ -25,21 +26,22 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
return buffer
class ChatVectorDBChain(Chain, BaseModel):
"""Chain for chatting with a vector database."""
class BaseConversationalRetrievalChain(Chain, BaseModel):
"""Chain for chatting with an index."""
vectorstore: VectorStore
combine_docs_chain: BaseCombineDocumentsChain
question_generator: LLMChain
output_key: str = "answer"
return_source_documents: bool = False
top_k_docs_for_context: int = 4
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
"""Return the source documents."""
@property
def _chain_type(self) -> str:
return "chat-vector-db"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
allow_population_by_field_name = True
@property
def input_keys(self) -> List[str]:
@@ -57,44 +59,22 @@ class ChatVectorDBChain(Chain, BaseModel):
_output_keys = _output_keys + ["source_documents"]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
qa_prompt: Optional[BasePromptTemplate] = None,
chain_type: str = "stuff",
**kwargs: Any,
) -> ChatVectorDBChain:
"""Load chain from LLM."""
doc_chain = load_qa_chain(
llm,
chain_type=chain_type,
prompt=qa_prompt,
)
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
return cls(
vectorstore=vectorstore,
combine_docs_chain=doc_chain,
question_generator=condense_question_chain,
**kwargs,
)
@abstractmethod
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str:
new_question = self.question_generator.run(
question=question, chat_history=chat_history_str
)
else:
new_question = question
docs = self.vectorstore.similarity_search(
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
)
docs = self._get_docs(new_question, inputs)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -108,7 +88,6 @@ class ChatVectorDBChain(Chain, BaseModel):
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str:
new_question = await self.question_generator.arun(
question=question, chat_history=chat_history_str
@@ -116,9 +95,7 @@ class ChatVectorDBChain(Chain, BaseModel):
else:
new_question = question
# TODO: This blocks the event loop, but it's not clear how to avoid it.
docs = self.vectorstore.similarity_search(
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
)
docs = self._get_docs(new_question, inputs)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@@ -132,3 +109,79 @@ class ChatVectorDBChain(Chain, BaseModel):
if self.get_chat_history:
raise ValueError("Chain not savable when `get_chat_history` is not None.")
super().save(file_path)
class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
"""Chain for chatting with an index."""
retriever: BaseRetriever
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
return self.retriever.get_relevant_texts(question)
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
retriever: BaseRetriever,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
qa_prompt: Optional[BasePromptTemplate] = None,
chain_type: str = "stuff",
**kwargs: Any,
) -> BaseConversationalRetrievalChain:
"""Load chain from LLM."""
doc_chain = load_qa_chain(
llm,
chain_type=chain_type,
prompt=qa_prompt,
)
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
return cls(
retriever=retriever,
combine_docs_chain=doc_chain,
question_generator=condense_question_chain,
**kwargs,
)
class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
"""Chain for chatting with a vector database."""
vectorstore: VectorStore = Field(alias="vectorstore")
top_k_docs_for_context: int = 4
search_kwargs: dict = Field(default_factory=dict)
@property
def _chain_type(self) -> str:
return "chat-vector-db"
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
return self.vectorstore.similarity_search(
question, k=self.top_k_docs_for_context, **full_kwargs
)
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
qa_prompt: Optional[BasePromptTemplate] = None,
chain_type: str = "stuff",
**kwargs: Any,
) -> BaseConversationalRetrievalChain:
"""Load chain from LLM."""
doc_chain = load_qa_chain(
llm,
chain_type=chain_type,
prompt=qa_prompt,
)
condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
return cls(
vectorstore=vectorstore,
combine_docs_chain=doc_chain,
question_generator=condense_question_chain,
**kwargs,
)
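
The refactor splits the old `ChatVectorDBChain` into an abstract `BaseConversationalRetrievalChain`, which delegates document fetching to `_get_docs`, plus two concrete subclasses: `ConversationalRetrievalChain` (any `BaseRetriever`) and the legacy `ChatVectorDBChain` (a `VectorStore` plus `top_k_docs_for_context`). A minimal sketch of the new entry point, assuming an OPENAI_API_KEY and an illustrative FAISS index:

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

# Illustrative index; any VectorStore works, since as_retriever() is
# defined on the VectorStore base class in this commit.
vectorstore = FAISS.from_texts(
    ["LangChain gained a retriever abstraction in this commit."],
    OpenAIEmbeddings(),
)

chain = ConversationalRetrievalChain.from_llm(
    llm=OpenAI(temperature=0),
    retriever=vectorstore.as_retriever(),
)

# The chain expects the question plus the accumulated (human, ai) history.
result = chain({"question": "What did LangChain gain?", "chat_history": []})
print(result["answer"])
```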

@@ -0,0 +1,20 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
QA_PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
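
The condense prompt rewrites a follow-up into a standalone question before retrieval, so the retriever sees a self-contained query. A quick look at what the question generator receives (runnable as-is, no model call involved):

```python
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

# The template takes exactly two variables: chat_history and question.
print(
    CONDENSE_QUESTION_PROMPT.format(
        chat_history="Human: Who created LangChain?\nAssistant: Harrison Chase.",
        question="When was it released?",
    )
)
```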

@@ -20,8 +20,8 @@ from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.pal.base import PALChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config
from langchain.utilities.loading import try_load_from_hub

@@ -0,0 +1,46 @@
"""Question-answering with sources over an index."""
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
"""Question-answering with sources over an index."""
retriever: BaseRetriever = Field(exclude=True)
"""Index to connect to."""
reduce_k_below_max_tokens: bool = False
"""Reduce the number of results to return from store based on tokens limit"""
max_tokens_limit: int = 3375
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentsChain and only if reduce_k_below_max_tokens is set to true"""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
num_docs = len(docs)
if self.reduce_k_below_max_tokens and isinstance(
self.combine_documents_chain, StuffDocumentsChain
):
tokens = [
self.combine_documents_chain.llm_chain.llm.get_num_tokens(
doc.page_content
)
for doc in docs
]
token_count = sum(tokens[:num_docs])
while token_count > self.max_tokens_limit:
num_docs -= 1
token_count -= tokens[num_docs]
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
question = inputs[self.question_key]
docs = self.retriever.get_relevant_texts(question)
return self._reduce_tokens_below_limit(docs)
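
`_reduce_tokens_below_limit` drops documents from the tail until the total token count fits under `max_tokens_limit`; after the decrement, `tokens[num_docs]` indexes exactly the document being cut. A standalone sketch of the same trimming logic with `len()` standing in for the LLM's token counter (names hypothetical):

```python
from typing import List

def reduce_below_limit(texts: List[str], max_tokens: int) -> List[str]:
    """Toy version of _reduce_tokens_below_limit."""
    tokens = [len(t) for t in texts]  # stand-in for llm.get_num_tokens
    num_docs = len(texts)
    token_count = sum(tokens[:num_docs])
    while token_count > max_tokens:
        num_docs -= 1
        token_count -= tokens[num_docs]  # subtract the doc just excluded
    return texts[:num_docs]

assert reduce_below_limit(["aaaa", "bbbb", "cccc"], max_tokens=9) == ["aaaa", "bbbb"]
```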

@@ -1,6 +1,7 @@
"""Chain for question-answering against a vector database."""
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, Field, root_validator
@@ -12,43 +13,24 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
class VectorDBQA(Chain, BaseModel):
"""Chain for question-answering against a vector database.
Example:
.. code-block:: python
from langchain import OpenAI, VectorDBQA
from langchain.faiss import FAISS
vectordb = FAISS(...)
vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)
"""
vectorstore: VectorStore = Field(exclude=True)
"""Vector Database to connect to."""
k: int = 4
"""Number of documents to query for."""
class BaseRetrievalQA(Chain, BaseModel):
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
allow_population_by_field_name = True
@property
def input_keys(self) -> List[str]:
@@ -69,45 +51,13 @@ class VectorDBQA(Chain, BaseModel):
_output_keys = _output_keys + ["source_documents"]
return _output_keys
# TODO: deprecate this
@root_validator(pre=True)
def load_combine_documents_chain(cls, values: Dict) -> Dict:
"""Validate question chain."""
if "combine_documents_chain" not in values:
if "llm" not in values:
raise ValueError(
"If `combine_documents_chain` not provided, `llm` should be."
)
llm = values.pop("llm")
prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
llm_chain = LLMChain(llm=llm, prompt=prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
)
values["combine_documents_chain"] = combine_documents_chain
return values
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
**kwargs: Any,
) -> VectorDBQA:
) -> BaseRetrievalQA:
"""Initialize from LLM."""
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt)
@@ -129,7 +79,7 @@ class VectorDBQA(Chain, BaseModel):
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> VectorDBQA:
) -> BaseRetrievalQA:
"""Load chain from chain type."""
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(
@@ -137,8 +87,12 @@ class VectorDBQA(Chain, BaseModel):
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@abstractmethod
def _get_docs(self, question: str) -> List[Document]:
"""Get documents to do question answering over."""
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run similarity search and llm on input query.
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
@@ -146,11 +100,62 @@ class VectorDBQA(Chain, BaseModel):
Example:
.. code-block:: python
res = vectordbqa({'query': 'This is my query'})
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
question = inputs[self.input_key]
docs = self._get_docs(question)
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
class RetrievalQA(BaseRetrievalQA, BaseModel):
"""Chain for question-answering against an index.
Example:
.. code-block:: python
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.vectorstores.faiss import FAISS
vectordb = FAISS(...)
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vectordb.as_retriever())
"""
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(self, question: str) -> List[Document]:
return self.retriever.get_relevant_texts(question)
class VectorDBQA(BaseRetrievalQA, BaseModel):
"""Chain for question-answering against a vector database."""
vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
"""Vector Database to connect to."""
k: int = 4
"""Number of documents to query for."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_docs(self, question: str) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@@ -161,12 +166,7 @@ class VectorDBQA(Chain, BaseModel):
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
return docs
@property
def _chain_type(self) -> str:
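
The same split applies to plain question answering: `RetrievalQA` accepts any `BaseRetriever`, while `VectorDBQA` keeps the vector-store-specific `k`, `search_type`, and `search_kwargs` knobs as a subclass. A sketch of both, assuming an OPENAI_API_KEY and an illustrative index:

```python
from langchain.chains import RetrievalQA, VectorDBQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

vectordb = FAISS.from_texts(["Paris is the capital of France."], OpenAIEmbeddings())

# New style: document fetching goes through a retriever.
qa = RetrievalQA.from_chain_type(
    llm=OpenAI(temperature=0),
    retriever=vectordb.as_retriever(),
    return_source_documents=True,
)
res = qa({"query": "What is the capital of France?"})
print(res["result"], len(res["source_documents"]))

# Old style still works, now as a thin subclass of BaseRetrievalQA.
legacy = VectorDBQA.from_chain_type(llm=OpenAI(temperature=0), vectorstore=vectordb, k=1)
```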

@@ -1,39 +1,3 @@
"""Interface for interacting with a document."""
from typing import List
from langchain.schema import Document
from pydantic import BaseModel, Field
class Document(BaseModel):
"""Interface for interacting with a document."""
page_content: str
lookup_str: str = ""
lookup_index = 0
metadata: dict = Field(default_factory=dict)
@property
def paragraphs(self) -> List[str]:
"""Paragraphs of the page."""
return self.page_content.split("\n\n")
@property
def summary(self) -> str:
"""Summary of the page (the first paragraph)."""
return self.paragraphs[0]
def lookup(self, string: str) -> str:
"""Lookup a term in the page, imitating cmd-F functionality."""
if string.lower() != self.lookup_str:
self.lookup_str = string.lower()
self.lookup_index = 0
else:
self.lookup_index += 1
lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
if len(lookups) == 0:
return "No Results"
elif self.lookup_index >= len(lookups):
return "No More Results"
else:
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
return f"{result_prefix} {lookups[self.lookup_index]}"
__all__ = ["Document"]
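
`Document` moves into `langchain.schema`, and `langchain.docstore.document` shrinks to a re-export so existing imports keep resolving. A quick compatibility check (a sketch, runnable without any API keys):

```python
from langchain.docstore.document import Document as OldPathDocument
from langchain.schema import Document

# Both import paths resolve to the same class after this commit.
assert OldPathDocument is Document

# The moved class keeps its behavior: summary is the first paragraph.
print(Document(page_content="hello\n\nworld").summary)  # -> "hello"
```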

@@ -2,8 +2,8 @@ from typing import Any, List, Optional, Type
from pydantic import BaseModel, Extra, Field
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
@@ -32,7 +32,9 @@ class VectorStoreIndexWrapper(BaseModel):
def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
"""Query the vectorstore."""
llm = llm or OpenAI(temperature=0)
chain = VectorDBQA.from_chain_type(llm, vectorstore=self.vectorstore, **kwargs)
chain = RetrievalQA.from_chain_type(
llm, retriever=self.vectorstore.as_retriever(), **kwargs
)
return chain.run(question)
def query_with_sources(
@@ -40,8 +42,8 @@ class VectorStoreIndexWrapper(BaseModel):
) -> dict:
"""Query the vectorstore and get back sources."""
llm = llm or OpenAI(temperature=0)
chain = VectorDBQAWithSourcesChain.from_chain_type(
llm, vectorstore=self.vectorstore, **kwargs
chain = RetrievalQAWithSourcesChain.from_chain_type(
llm, retriever=self.vectorstore.as_retriever(), **kwargs
)
return chain({chain.question_key: question})
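
`VectorStoreIndexWrapper` now builds its chains from `self.vectorstore.as_retriever()` instead of passing the store directly. A minimal usage sketch, assuming an OPENAI_API_KEY and an illustrative FAISS store:

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain.vectorstores.faiss import FAISS

store = FAISS.from_texts(["The sky is blue."], OpenAIEmbeddings())
index = VectorStoreIndexWrapper(vectorstore=store)

print(index.query("What color is the sky?"))               # uses RetrievalQA
print(index.query_with_sources("What color is the sky?"))  # uses RetrievalQAWithSourcesChain
```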

@@ -240,6 +240,54 @@ class BaseMemory(BaseModel, ABC):
"""Clear memory contents."""
class Document(BaseModel):
"""Interface for interacting with a document."""
page_content: str
lookup_str: str = ""
lookup_index = 0
metadata: dict = Field(default_factory=dict)
@property
def paragraphs(self) -> List[str]:
"""Paragraphs of the page."""
return self.page_content.split("\n\n")
@property
def summary(self) -> str:
"""Summary of the page (the first paragraph)."""
return self.paragraphs[0]
def lookup(self, string: str) -> str:
"""Lookup a term in the page, imitating cmd-F functionality."""
if string.lower() != self.lookup_str:
self.lookup_str = string.lower()
self.lookup_index = 0
else:
self.lookup_index += 1
lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
if len(lookups) == 0:
return "No Results"
elif self.lookup_index >= len(lookups):
return "No More Results"
else:
result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
return f"{result_prefix} {lookups[self.lookup_index]}"
class BaseRetriever(ABC):
@abstractmethod
def get_relevant_texts(self, query: str) -> List[Document]:
"""Get texts relevant for a query.
Args:
query: string to find relevant texts for
Returns:
List of relevant documents
"""
# For backwards compatibility
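
`BaseRetriever` is the new one-method interface: anything that maps a query string to a list of `Document`s can now back the retrieval chains, not just vector stores. A toy keyword-matching retriever as a sketch (the class and its data are hypothetical):

```python
from typing import List

from langchain.schema import BaseRetriever, Document

class KeywordRetriever(BaseRetriever):
    """Toy retriever: returns documents whose text contains the query."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def get_relevant_texts(self, query: str) -> List[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

docs = [Document(page_content="Retrievers decouple chains from vector stores.")]
print(KeywordRetriever(docs).get_relevant_texts("vector stores"))
```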

@@ -6,7 +6,7 @@ from typing import Any, Dict
from pydantic import BaseModel, Field
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.tools.base import BaseTool

@@ -4,8 +4,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
class VectorStore(ABC):
@@ -122,3 +125,19 @@ class VectorStore(ABC):
**kwargs: Any,
) -> VectorStore:
"""Return VectorStore initialized from texts and embeddings."""
def as_retriever(self) -> VectorStoreRetriever:
return VectorStoreRetriever(vectorstore=self)
class VectorStoreRetriever(BaseRetriever, BaseModel):
vectorstore: VectorStore
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_relevant_texts(self, query: str) -> List[Document]:
return self.vectorstore.similarity_search(query, **self.search_kwargs)
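
`as_retriever()` wraps any `VectorStore` in the default `VectorStoreRetriever`, whose `search_kwargs` are forwarded straight to `similarity_search`. A sketch of configuring it directly, assuming an OPENAI_API_KEY and an illustrative FAISS store:

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores.faiss import FAISS

store = FAISS.from_texts(["alpha", "beta", "gamma"], OpenAIEmbeddings())

# as_retriever() takes no arguments in this commit, so to customize the
# search, construct the retriever directly with search_kwargs.
retriever = VectorStoreRetriever(vectorstore=store, search_kwargs={"k": 2})
print([d.page_content for d in retriever.get_relevant_texts("alpha")])
```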
