Harrison/retrieval code (#1916)

Harrison Chase committed 1 year ago (via GitHub)
parent eb80d6e0e4
commit fab7994b74

@@ -1,9 +1,12 @@
 """Chains are easily reusable components which can be linked together."""
 from langchain.chains.api.base import APIChain
-from langchain.chains.chat_vector_db.base import ChatVectorDBChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
 from langchain.chains.constitutional_ai.base import ConstitutionalChain
 from langchain.chains.conversation.base import ConversationChain
+from langchain.chains.conversational_retrieval.base import (
+    ChatVectorDBChain,
+    ConversationalRetrievalChain,
+)
 from langchain.chains.graph_qa.base import GraphQAChain
 from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
 from langchain.chains.llm import LLMChain
@@ -18,14 +21,15 @@ from langchain.chains.moderation import OpenAIModerationChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.qa_generation.base import QAGenerationChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
 from langchain.chains.sql_database.base import (
     SQLDatabaseChain,
     SQLDatabaseSequentialChain,
 )
 from langchain.chains.transform import TransformChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
 
 __all__ = [
     "ConversationChain",
@@ -54,4 +58,7 @@ __all__ = [
     "GraphQAChain",
     "ConstitutionalChain",
     "QAGenerationChain",
+    "RetrievalQA",
+    "RetrievalQAWithSourcesChain",
+    "ConversationalRetrievalChain",
 ]

@@ -1,18 +1,19 @@
 """Chain for chatting with a vector database."""
 from __future__ import annotations
 
+from abc import abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Extra, Field
 
 from langchain.chains.base import Chain
-from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.schema import BaseLanguageModel, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore
@@ -25,21 +26,22 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
     return buffer
 
 
-class ChatVectorDBChain(Chain, BaseModel):
-    """Chain for chatting with a vector database."""
+class BaseConversationalRetrievalChain(Chain, BaseModel):
+    """Chain for chatting with an index."""
 
-    vectorstore: VectorStore
     combine_docs_chain: BaseCombineDocumentsChain
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
-    top_k_docs_for_context: int = 4
     get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
     """Return the source documents."""
 
-    @property
-    def _chain_type(self) -> str:
-        return "chat-vector-db"
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+        allow_population_by_field_name = True
 
     @property
     def input_keys(self) -> List[str]:
@@ -57,44 +59,22 @@ class ChatVectorDBChain(Chain, BaseModel):
         _output_keys = _output_keys + ["source_documents"]
         return _output_keys
 
-    @classmethod
-    def from_llm(
-        cls,
-        llm: BaseLanguageModel,
-        vectorstore: VectorStore,
-        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: Optional[BasePromptTemplate] = None,
-        chain_type: str = "stuff",
-        **kwargs: Any,
-    ) -> ChatVectorDBChain:
-        """Load chain from LLM."""
-        doc_chain = load_qa_chain(
-            llm,
-            chain_type=chain_type,
-            prompt=qa_prompt,
-        )
-        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
-        return cls(
-            vectorstore=vectorstore,
-            combine_docs_chain=doc_chain,
-            question_generator=condense_question_chain,
-            **kwargs,
-        )
+    @abstractmethod
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        """Get docs."""
 
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
         chat_history_str = get_chat_history(inputs["chat_history"])
-        vectordbkwargs = inputs.get("vectordbkwargs", {})
         if chat_history_str:
             new_question = self.question_generator.run(
                 question=question, chat_history=chat_history_str
             )
         else:
             new_question = question
-        docs = self.vectorstore.similarity_search(
-            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
-        )
+        docs = self._get_docs(new_question, inputs)
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
@ -108,7 +88,6 @@ class ChatVectorDBChain(Chain, BaseModel):
question = inputs["question"] question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"]) chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str: if chat_history_str:
new_question = await self.question_generator.arun( new_question = await self.question_generator.arun(
question=question, chat_history=chat_history_str question=question, chat_history=chat_history_str
@@ -116,9 +95,7 @@ class ChatVectorDBChain(Chain, BaseModel):
         else:
             new_question = question
         # TODO: This blocks the event loop, but it's not clear how to avoid it.
-        docs = self.vectorstore.similarity_search(
-            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
-        )
+        docs = self._get_docs(new_question, inputs)
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
@@ -132,3 +109,79 @@ class ChatVectorDBChain(Chain, BaseModel):
         if self.get_chat_history:
             raise ValueError("Chain not savable when `get_chat_history` is not None.")
         super().save(file_path)
+
+
+class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
+    """Chain for chatting with an index."""
+
+    retriever: BaseRetriever
+
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        return self.retriever.get_relevant_texts(question)
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        retriever: BaseRetriever,
+        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        chain_type: str = "stuff",
+        **kwargs: Any,
+    ) -> BaseConversationalRetrievalChain:
+        """Load chain from LLM."""
+        doc_chain = load_qa_chain(
+            llm,
+            chain_type=chain_type,
+            prompt=qa_prompt,
+        )
+        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        return cls(
+            retriever=retriever,
+            combine_docs_chain=doc_chain,
+            question_generator=condense_question_chain,
+            **kwargs,
+        )
+
+
+class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
+    """Chain for chatting with a vector database."""
+
+    vectorstore: VectorStore = Field(alias="vectorstore")
+    top_k_docs_for_context: int = 4
+    search_kwargs: dict = Field(default_factory=dict)
+
+    @property
+    def _chain_type(self) -> str:
+        return "chat-vector-db"
+
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        vectordbkwargs = inputs.get("vectordbkwargs", {})
+        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
+        return self.vectorstore.similarity_search(
+            question, k=self.top_k_docs_for_context, **full_kwargs
+        )
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        vectorstore: VectorStore,
+        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        chain_type: str = "stuff",
+        **kwargs: Any,
+    ) -> BaseConversationalRetrievalChain:
+        """Load chain from LLM."""
+        doc_chain = load_qa_chain(
+            llm,
+            chain_type=chain_type,
+            prompt=qa_prompt,
+        )
+        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        return cls(
+            vectorstore=vectorstore,
+            combine_docs_chain=doc_chain,
+            question_generator=condense_question_chain,
+            **kwargs,
+        )
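Usage note: a minimal sketch of driving the new `ConversationalRetrievalChain`, assuming an existing, pre-built `VectorStore` named `docsearch` and OpenAI credentials (both hypothetical here, not part of this diff):

```python
# Sketch only: `docsearch` is an assumed, already-populated VectorStore (e.g. FAISS).
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import OpenAI

chain = ConversationalRetrievalChain.from_llm(
    OpenAI(temperature=0), retriever=docsearch.as_retriever()
)

chat_history = []  # list of (human, ai) turns, fed back in on each call
result = chain({"question": "What is a retriever?", "chat_history": chat_history})
chat_history.append(("What is a retriever?", result["answer"]))
```

On the first turn the history string is empty, so the question passes through unchanged; on later turns the question generator condenses history plus follow-up into a standalone question before `_get_docs` runs.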

@@ -0,0 +1,20 @@
+# flake8: noqa
+from langchain.prompts.prompt import PromptTemplate
+
+_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
+
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}
+Helpful Answer:"""
+QA_PROMPT = PromptTemplate(
+    template=prompt_template, input_variables=["context", "question"]
+)
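For reference, a quick sketch of what the question generator sees once the condense template is filled in (the example strings are made up):

```python
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

print(
    CONDENSE_QUESTION_PROMPT.format(
        chat_history="Human: What is LangChain?\nAssistant: A framework for LLM apps.",
        question="Who maintains it?",
    )
)
```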

@@ -20,8 +20,8 @@ from langchain.chains.llm_requests import LLMRequestsChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import VectorDBQA
 from langchain.chains.sql_database.base import SQLDatabaseChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
 from langchain.llms.loading import load_llm, load_llm_from_config
 from langchain.prompts.loading import load_prompt, load_prompt_from_config
 from langchain.utilities.loading import try_load_from_hub

@@ -0,0 +1,46 @@
+"""Question-answering with sources over an index."""
+from typing import Any, Dict, List
+
+from pydantic import BaseModel, Field
+
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
+from langchain.docstore.document import Document
+from langchain.schema import BaseRetriever
+
+
+class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
+    """Question-answering with sources over an index."""
+
+    retriever: BaseRetriever = Field(exclude=True)
+    """Index to connect to."""
+    reduce_k_below_max_tokens: bool = False
+    """Reduce the number of results to return from store based on tokens limit."""
+    max_tokens_limit: int = 3375
+    """Restrict the docs to return from store based on tokens,
+    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is set to true."""
+
+    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+        num_docs = len(docs)
+
+        if self.reduce_k_below_max_tokens and isinstance(
+            self.combine_documents_chain, StuffDocumentsChain
+        ):
+            tokens = [
+                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
+                    doc.page_content
+                )
+                for doc in docs
+            ]
+            token_count = sum(tokens[:num_docs])
+            while token_count > self.max_tokens_limit:
+                num_docs -= 1
+                token_count -= tokens[num_docs]
+
+        return docs[:num_docs]
+
+    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
+        question = inputs[self.question_key]
+        docs = self.retriever.get_relevant_texts(question)
+        return self._reduce_tokens_below_limit(docs)
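A minimal usage sketch for the new chain, assuming a populated vectorstore (`docsearch`, hypothetical) whose documents carry a `source` key in their metadata:

```python
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.llms import OpenAI

chain = RetrievalQAWithSourcesChain.from_chain_type(
    OpenAI(temperature=0), retriever=docsearch.as_retriever()
)
result = chain({"question": "What did the report conclude?"})
print(result["answer"])
print(result["sources"])  # built from the `source` metadata of the docs used
```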

@@ -1,6 +1,7 @@
 """Chain for question-answering against a vector database."""
 from __future__ import annotations
 
+from abc import abstractmethod
 from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
@@ -12,43 +13,24 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.schema import BaseLanguageModel, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore
 
 
-class VectorDBQA(Chain, BaseModel):
-    """Chain for question-answering against a vector database.
-
-    Example:
-        .. code-block:: python
-
-            from langchain import OpenAI, VectorDBQA
-            from langchain.faiss import FAISS
-
-            vectordb = FAISS(...)
-            vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)
-
-    """
+class BaseRetrievalQA(Chain, BaseModel):
 
-    vectorstore: VectorStore = Field(exclude=True)
-    """Vector Database to connect to."""
-    k: int = 4
-    """Number of documents to query for."""
     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
     return_source_documents: bool = False
     """Return the source documents."""
-    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    """Extra search args."""
-    search_type: str = "similarity"
-    """Search type to use over vectorstore. `similarity` or `mmr`."""
 
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
         arbitrary_types_allowed = True
+        allow_population_by_field_name = True
 
     @property
     def input_keys(self) -> List[str]:
@@ -69,45 +51,13 @@ class VectorDBQA(Chain, BaseModel):
         _output_keys = _output_keys + ["source_documents"]
         return _output_keys
 
-    # TODO: deprecate this
-    @root_validator(pre=True)
-    def load_combine_documents_chain(cls, values: Dict) -> Dict:
-        """Validate question chain."""
-        if "combine_documents_chain" not in values:
-            if "llm" not in values:
-                raise ValueError(
-                    "If `combine_documents_chain` not provided, `llm` should be."
-                )
-            llm = values.pop("llm")
-            prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
-            llm_chain = LLMChain(llm=llm, prompt=prompt)
-            document_prompt = PromptTemplate(
-                input_variables=["page_content"], template="Context:\n{page_content}"
-            )
-            combine_documents_chain = StuffDocumentsChain(
-                llm_chain=llm_chain,
-                document_variable_name="context",
-                document_prompt=document_prompt,
-            )
-            values["combine_documents_chain"] = combine_documents_chain
-        return values
-
-    @root_validator()
-    def validate_search_type(cls, values: Dict) -> Dict:
-        """Validate search type."""
-        if "search_type" in values:
-            search_type = values["search_type"]
-            if search_type not in ("similarity", "mmr"):
-                raise ValueError(f"search_type of {search_type} not allowed.")
-        return values
-
     @classmethod
     def from_llm(
         cls,
         llm: BaseLanguageModel,
         prompt: Optional[PromptTemplate] = None,
         **kwargs: Any,
-    ) -> VectorDBQA:
+    ) -> BaseRetrievalQA:
         """Initialize from LLM."""
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
         llm_chain = LLMChain(llm=llm, prompt=_prompt)
@@ -129,7 +79,7 @@ class VectorDBQA(Chain, BaseModel):
         chain_type: str = "stuff",
         chain_type_kwargs: Optional[dict] = None,
         **kwargs: Any,
-    ) -> VectorDBQA:
+    ) -> BaseRetrievalQA:
         """Load chain from chain type."""
         _chain_type_kwargs = chain_type_kwargs or {}
         combine_documents_chain = load_qa_chain(
@@ -137,8 +87,12 @@ class VectorDBQA(Chain, BaseModel):
         )
         return cls(combine_documents_chain=combine_documents_chain, **kwargs)
 
+    @abstractmethod
+    def _get_docs(self, question: str) -> List[Document]:
+        """Get documents to do question answering over."""
+
     def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
-        """Run similarity search and llm on input query.
+        """Run get_relevant_text and llm on input query.
 
         If chain has 'return_source_documents' as 'True', returns
         the retrieved documents as well under the key 'source_documents'.
@@ -146,11 +100,62 @@ class VectorDBQA(Chain, BaseModel):
         Example:
         .. code-block:: python
 
-            res = vectordbqa({'query': 'This is my query'})
+            res = indexqa({'query': 'This is my query'})
             answer, docs = res['result'], res['source_documents']
         """
         question = inputs[self.input_key]
+        docs = self._get_docs(question)
+        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
+        if self.return_source_documents:
+            return {self.output_key: answer, "source_documents": docs}
+        else:
+            return {self.output_key: answer}
+
+
+class RetrievalQA(BaseRetrievalQA, BaseModel):
+    """Chain for question-answering against an index.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAI
+            from langchain.chains import RetrievalQA
+            from langchain.faiss import FAISS
+
+            vectordb = FAISS(...)
+            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vectordb)
+
+    """
+
+    retriever: BaseRetriever = Field(exclude=True)
+
+    def _get_docs(self, question: str) -> List[Document]:
+        return self.retriever.get_relevant_texts(question)
+
+
+class VectorDBQA(BaseRetrievalQA, BaseModel):
+    """Chain for question-answering against a vector database."""
+
+    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
+    """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to query for."""
+    search_type: str = "similarity"
+    """Search type to use over vectorstore. `similarity` or `mmr`."""
+    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Extra search args."""
+
+    @root_validator()
+    def validate_search_type(cls, values: Dict) -> Dict:
+        """Validate search type."""
+        if "search_type" in values:
+            search_type = values["search_type"]
+            if search_type not in ("similarity", "mmr"):
+                raise ValueError(f"search_type of {search_type} not allowed.")
+        return values
+
+    def _get_docs(self, question: str) -> List[Document]:
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(
                 question, k=self.k, **self.search_kwargs
@@ -161,12 +166,7 @@ class VectorDBQA(Chain, BaseModel):
             )
         else:
             raise ValueError(f"search_type of {self.search_type} not allowed.")
-        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
-        if self.return_source_documents:
-            return {self.output_key: answer, "source_documents": docs}
-        else:
-            return {self.output_key: answer}
+        return docs
 
     @property
     def _chain_type(self) -> str:
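A sketch of the new entry point next to the legacy one it generalizes; `docsearch` is again an assumed, pre-built `VectorStore`:

```python
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

qa = RetrievalQA.from_chain_type(
    llm=OpenAI(temperature=0),
    chain_type="stuff",  # same chain types load_qa_chain already supports
    retriever=docsearch.as_retriever(),
)
print(qa.run("What is the main topic of the document?"))
```

`VectorDBQA` keeps working unchanged: it now just implements `_get_docs` on top of the shared `BaseRetrievalQA`.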

@@ -1,39 +1,3 @@
-"""Interface for interacting with a document."""
-from typing import List
-
-from pydantic import BaseModel, Field
-
-
-class Document(BaseModel):
-    """Interface for interacting with a document."""
-
-    page_content: str
-    lookup_str: str = ""
-    lookup_index = 0
-    metadata: dict = Field(default_factory=dict)
-
-    @property
-    def paragraphs(self) -> List[str]:
-        """Paragraphs of the page."""
-        return self.page_content.split("\n\n")
-
-    @property
-    def summary(self) -> str:
-        """Summary of the page (the first paragraph)."""
-        return self.paragraphs[0]
-
-    def lookup(self, string: str) -> str:
-        """Lookup a term in the page, imitating cmd-F functionality."""
-        if string.lower() != self.lookup_str:
-            self.lookup_str = string.lower()
-            self.lookup_index = 0
-        else:
-            self.lookup_index += 1
-        lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
-        if len(lookups) == 0:
-            return "No Results"
-        elif self.lookup_index >= len(lookups):
-            return "No More Results"
-        else:
-            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
-            return f"{result_prefix} {lookups[self.lookup_index]}"
+from langchain.schema import Document
+
+__all__ = ["Document"]

@@ -2,8 +2,8 @@ from typing import Any, List, Optional, Type
 from pydantic import BaseModel, Extra, Field
 
-from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import RetrievalQA
 from langchain.document_loaders.base import BaseLoader
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -32,7 +32,9 @@ class VectorStoreIndexWrapper(BaseModel):
     def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
         """Query the vectorstore."""
         llm = llm or OpenAI(temperature=0)
-        chain = VectorDBQA.from_chain_type(llm, vectorstore=self.vectorstore, **kwargs)
+        chain = RetrievalQA.from_chain_type(
+            llm, retriever=self.vectorstore.as_retriever(), **kwargs
+        )
         return chain.run(question)
 
     def query_with_sources(
@@ -40,8 +42,8 @@ class VectorStoreIndexWrapper(BaseModel):
     ) -> dict:
         """Query the vectorstore and get back sources."""
         llm = llm or OpenAI(temperature=0)
-        chain = VectorDBQAWithSourcesChain.from_chain_type(
-            llm, vectorstore=self.vectorstore, **kwargs
+        chain = RetrievalQAWithSourcesChain.from_chain_type(
+            llm, retriever=self.vectorstore.as_retriever(), **kwargs
         )
         return chain({chain.question_key: question})
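After this change the wrapper routes through the retriever interface; caller-facing usage is unchanged (sketch, `docsearch` assumed):

```python
index = VectorStoreIndexWrapper(vectorstore=docsearch)
print(index.query("Summarize the document."))
print(index.query_with_sources("Summarize the document."))
```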

@@ -240,6 +240,54 @@ class BaseMemory(BaseModel, ABC):
         """Clear memory contents."""
 
 
+class Document(BaseModel):
+    """Interface for interacting with a document."""
+
+    page_content: str
+    lookup_str: str = ""
+    lookup_index = 0
+    metadata: dict = Field(default_factory=dict)
+
+    @property
+    def paragraphs(self) -> List[str]:
+        """Paragraphs of the page."""
+        return self.page_content.split("\n\n")
+
+    @property
+    def summary(self) -> str:
+        """Summary of the page (the first paragraph)."""
+        return self.paragraphs[0]
+
+    def lookup(self, string: str) -> str:
+        """Lookup a term in the page, imitating cmd-F functionality."""
+        if string.lower() != self.lookup_str:
+            self.lookup_str = string.lower()
+            self.lookup_index = 0
+        else:
+            self.lookup_index += 1
+        lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
+        if len(lookups) == 0:
+            return "No Results"
+        elif self.lookup_index >= len(lookups):
+            return "No More Results"
+        else:
+            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
+            return f"{result_prefix} {lookups[self.lookup_index]}"
+
+
+class BaseRetriever(ABC):
+    @abstractmethod
+    def get_relevant_texts(self, query: str) -> List[Document]:
+        """Get texts relevant for a query.
+
+        Args:
+            query: string to find relevant texts for
+
+        Returns:
+            List of relevant documents
+        """
+
+
 # For backwards compatibility
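Anything that can turn a query string into a list of `Document`s can now implement the interface; a toy sketch with an in-memory keyword store (entirely hypothetical, not part of this diff):

```python
from typing import Dict, List

from langchain.schema import BaseRetriever, Document


class KeywordRetriever(BaseRetriever):
    """Toy retriever: substring match over an in-memory {source: text} dict."""

    def __init__(self, corpus: Dict[str, str]):
        self.corpus = corpus

    def get_relevant_texts(self, query: str) -> List[Document]:
        return [
            Document(page_content=text, metadata={"source": source})
            for source, text in self.corpus.items()
            if query.lower() in text.lower()
        ]
```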

@@ -6,7 +6,7 @@ from typing import Any, Dict
 from pydantic import BaseModel, Field
 
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
+from langchain.chains.retrieval_qa.base import VectorDBQA
 from langchain.llms.base import BaseLLM
 from langchain.llms.openai import OpenAI
 from langchain.tools.base import BaseTool

@@ -4,8 +4,11 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Iterable, List, Optional
 
+from pydantic import BaseModel, Field
+
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
+from langchain.schema import BaseRetriever
 
 
 class VectorStore(ABC):
@@ -122,3 +125,19 @@ class VectorStore(ABC):
         **kwargs: Any,
     ) -> VectorStore:
         """Return VectorStore initialized from texts and embeddings."""
+
+    def as_retriever(self) -> VectorStoreRetriever:
+        return VectorStoreRetriever(vectorstore=self)
+
+
+class VectorStoreRetriever(BaseRetriever, BaseModel):
+    vectorstore: VectorStore
+    search_kwargs: dict = Field(default_factory=dict)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def get_relevant_texts(self, query: str) -> List[Document]:
+        return self.vectorstore.similarity_search(query, **self.search_kwargs)
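Since `as_retriever()` takes no arguments yet, search options go through the `search_kwargs` field; a short sketch (`docsearch` assumed):

```python
retriever = docsearch.as_retriever()
retriever.search_kwargs = {"k": 2}  # forwarded verbatim to similarity_search
docs = retriever.get_relevant_texts("what does the retriever interface replace?")
```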
