diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py
index c0996872..0307fcf3 100644
--- a/langchain/chains/__init__.py
+++ b/langchain/chains/__init__.py
@@ -1,9 +1,12 @@
 """Chains are easily reusable components which can be linked together."""
 from langchain.chains.api.base import APIChain
-from langchain.chains.chat_vector_db.base import ChatVectorDBChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
 from langchain.chains.constitutional_ai.base import ConstitutionalChain
 from langchain.chains.conversation.base import ConversationChain
+from langchain.chains.conversational_retrieval.base import (
+    ChatVectorDBChain,
+    ConversationalRetrievalChain,
+)
 from langchain.chains.graph_qa.base import GraphQAChain
 from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
 from langchain.chains.llm import LLMChain
@@ -18,14 +21,15 @@ from langchain.chains.moderation import OpenAIModerationChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.qa_generation.base import QAGenerationChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
 from langchain.chains.sql_database.base import (
     SQLDatabaseChain,
     SQLDatabaseSequentialChain,
 )
 from langchain.chains.transform import TransformChain
-from langchain.chains.vector_db_qa.base import VectorDBQA

 __all__ = [
     "ConversationChain",
@@ -54,4 +58,7 @@ __all__ = [
     "GraphQAChain",
     "ConstitutionalChain",
     "QAGenerationChain",
+    "RetrievalQA",
+    "RetrievalQAWithSourcesChain",
+    "ConversationalRetrievalChain",
 ]
diff --git a/langchain/chains/chat_vector_db/__init__.py b/langchain/chains/conversational_retrieval/__init__.py
similarity index 100%
rename from langchain/chains/chat_vector_db/__init__.py
rename to langchain/chains/conversational_retrieval/__init__.py
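With these exports in place, the retriever-backed chains and their legacy vector-store counterparts are importable side by side. A minimal sketch of the new import surface:

```python
# All of these resolve after this change; the legacy names remain available.
from langchain.chains import (
    ChatVectorDBChain,             # legacy, now defined in conversational_retrieval
    ConversationalRetrievalChain,  # new
    RetrievalQA,                   # new
    RetrievalQAWithSourcesChain,   # new
    VectorDBQA,                    # legacy, now defined in retrieval_qa
)
```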
diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/conversational_retrieval/base.py
similarity index 65%
rename from langchain/chains/chat_vector_db/base.py
rename to langchain/chains/conversational_retrieval/base.py
index 4dd18b68..838c0af5 100644
--- a/langchain/chains/chat_vector_db/base.py
+++ b/langchain/chains/conversational_retrieval/base.py
@@ -1,18 +1,19 @@
 """Chain for chatting with a vector database."""
 from __future__ import annotations

+from abc import abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, Extra, Field

 from langchain.chains.base import Chain
-from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.schema import BaseLanguageModel, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore

@@ -25,21 +26,22 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
     return buffer


-class ChatVectorDBChain(Chain, BaseModel):
-    """Chain for chatting with a vector database."""
+class BaseConversationalRetrievalChain(Chain, BaseModel):
+    """Chain for chatting with an index."""

-    vectorstore: VectorStore
     combine_docs_chain: BaseCombineDocumentsChain
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
-    top_k_docs_for_context: int = 4
     get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
     """Return the source documents."""

-    @property
-    def _chain_type(self) -> str:
-        return "chat-vector-db"
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+        allow_population_by_field_name = True

     @property
     def input_keys(self) -> List[str]:
@@ -57,44 +59,22 @@ class ChatVectorDBChain(Chain, BaseModel):
             _output_keys = _output_keys + ["source_documents"]
         return _output_keys

-    @classmethod
-    def from_llm(
-        cls,
-        llm: BaseLanguageModel,
-        vectorstore: VectorStore,
-        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: Optional[BasePromptTemplate] = None,
-        chain_type: str = "stuff",
-        **kwargs: Any,
-    ) -> ChatVectorDBChain:
-        """Load chain from LLM."""
-        doc_chain = load_qa_chain(
-            llm,
-            chain_type=chain_type,
-            prompt=qa_prompt,
-        )
-        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
-        return cls(
-            vectorstore=vectorstore,
-            combine_docs_chain=doc_chain,
-            question_generator=condense_question_chain,
-            **kwargs,
-        )
+    @abstractmethod
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        """Get docs."""

     def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
         chat_history_str = get_chat_history(inputs["chat_history"])
-        vectordbkwargs = inputs.get("vectordbkwargs", {})
+
         if chat_history_str:
             new_question = self.question_generator.run(
                 question=question, chat_history=chat_history_str
             )
         else:
             new_question = question
-        docs = self.vectorstore.similarity_search(
-            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
-        )
+        docs = self._get_docs(new_question, inputs)
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
@@ -108,7 +88,6 @@ class ChatVectorDBChain(Chain, BaseModel):
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
         chat_history_str = get_chat_history(inputs["chat_history"])
-        vectordbkwargs = inputs.get("vectordbkwargs", {})
         if chat_history_str:
             new_question = await self.question_generator.arun(
                 question=question, chat_history=chat_history_str
@@ -116,9 +95,7 @@
         else:
             new_question = question
         # TODO: This blocks the event loop, but it's not clear how to avoid it.
-        docs = self.vectorstore.similarity_search(
-            new_question, k=self.top_k_docs_for_context, **vectordbkwargs
-        )
+        docs = self._get_docs(new_question, inputs)
         new_inputs = inputs.copy()
         new_inputs["question"] = new_question
         new_inputs["chat_history"] = chat_history_str
@@ -132,3 +109,79 @@ class ChatVectorDBChain(Chain, BaseModel):
         if self.get_chat_history:
             raise ValueError("Chain not savable when `get_chat_history` is not None.")
         super().save(file_path)
+
+
+class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
+    """Chain for chatting with an index."""
+
+    retriever: BaseRetriever
+
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        return self.retriever.get_relevant_texts(question)
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        retriever: BaseRetriever,
+        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        chain_type: str = "stuff",
+        **kwargs: Any,
+    ) -> BaseConversationalRetrievalChain:
+        """Load chain from LLM."""
+        doc_chain = load_qa_chain(
+            llm,
+            chain_type=chain_type,
+            prompt=qa_prompt,
+        )
+        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        return cls(
+            retriever=retriever,
+            combine_docs_chain=doc_chain,
+            question_generator=condense_question_chain,
+            **kwargs,
+        )
+
+
+class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
+    """Chain for chatting with a vector database."""
+
+    vectorstore: VectorStore = Field(alias="vectorstore")
+    top_k_docs_for_context: int = 4
+    search_kwargs: dict = Field(default_factory=dict)
+
+    @property
+    def _chain_type(self) -> str:
+        return "chat-vector-db"
+
+    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
+        vectordbkwargs = inputs.get("vectordbkwargs", {})
+        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
+        return self.vectorstore.similarity_search(
+            question, k=self.top_k_docs_for_context, **full_kwargs
+        )
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        vectorstore: VectorStore,
+        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        chain_type: str = "stuff",
+        **kwargs: Any,
+    ) -> BaseConversationalRetrievalChain:
+        """Load chain from LLM."""
+        doc_chain = load_qa_chain(
+            llm,
+            chain_type=chain_type,
+            prompt=qa_prompt,
+        )
+        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        return cls(
+            vectorstore=vectorstore,
+            combine_docs_chain=doc_chain,
+            question_generator=condense_question_chain,
+            **kwargs,
+        )
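`BaseConversationalRetrievalChain` keeps the question-condensing flow and delegates document fetching to the abstract `_get_docs`; `ConversationalRetrievalChain` fills it in via a retriever. A usage sketch, assuming a small FAISS index and an OpenAI key (the texts are illustrative):

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

# Illustrative index; any VectorStore works once wrapped via as_retriever().
vectorstore = FAISS.from_texts(
    ["LangChain now ships a BaseRetriever abstraction."], OpenAIEmbeddings()
)
qa = ConversationalRetrievalChain.from_llm(
    llm=OpenAI(temperature=0),
    retriever=vectorstore.as_retriever(),
)
# chat_history is a list of (human, ai) tuples, per _get_chat_history above.
result = qa({"question": "What does LangChain ship?", "chat_history": []})
print(result["answer"])
```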
diff --git a/langchain/chains/conversational_retrieval/prompts.py b/langchain/chains/conversational_retrieval/prompts.py
new file mode 100644
index 00000000..b2a2df09
--- /dev/null
+++ b/langchain/chains/conversational_retrieval/prompts.py
@@ -0,0 +1,20 @@
+# flake8: noqa
+from langchain.prompts.prompt import PromptTemplate
+
+_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
+
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}
+Helpful Answer:"""
+QA_PROMPT = PromptTemplate(
+    template=prompt_template, input_variables=["context", "question"]
+)
diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py
index 026f47bc..5b9b78f9 100644
--- a/langchain/chains/loading.py
+++ b/langchain/chains/loading.py
@@ -20,8 +20,8 @@ from langchain.chains.llm_requests import LLMRequestsChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import VectorDBQA
 from langchain.chains.sql_database.base import SQLDatabaseChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
 from langchain.llms.loading import load_llm, load_llm_from_config
 from langchain.prompts.loading import load_prompt, load_prompt_from_config
 from langchain.utilities.loading import try_load_from_hub
diff --git a/langchain/chains/qa_with_sources/retrieval.py b/langchain/chains/qa_with_sources/retrieval.py
new file mode 100644
index 00000000..1f50b28a
--- /dev/null
+++ b/langchain/chains/qa_with_sources/retrieval.py
@@ -0,0 +1,46 @@
+"""Question-answering with sources over an index."""
+
+from typing import Any, Dict, List
+
+from pydantic import BaseModel, Field
+
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
+from langchain.docstore.document import Document
+from langchain.schema import BaseRetriever
+
+
+class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
+    """Question-answering with sources over an index."""
+
+    retriever: BaseRetriever = Field(exclude=True)
+    """Index to connect to."""
+    reduce_k_below_max_tokens: bool = False
+    """Reduce the number of results to return from store based on tokens limit"""
+    max_tokens_limit: int = 3375
+    """Restrict the docs to return from store based on tokens,
+    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is set to true"""
+
+    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+        num_docs = len(docs)
+
+        if self.reduce_k_below_max_tokens and isinstance(
+            self.combine_documents_chain, StuffDocumentsChain
+        ):
+            tokens = [
+                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
+                    doc.page_content
+                )
+                for doc in docs
+            ]
+            token_count = sum(tokens[:num_docs])
+            while token_count > self.max_tokens_limit:
+                num_docs -= 1
+                token_count -= tokens[num_docs]
+
+        return docs[:num_docs]
+
+    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
+        question = inputs[self.question_key]
+        docs = self.retriever.get_relevant_texts(question)
+        return self._reduce_tokens_below_limit(docs)
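`RetrievalQAWithSourcesChain` mirrors `VectorDBQAWithSourcesChain` but fetches documents through the retriever interface, inheriting `from_chain_type` from `BaseQAWithSourcesChain`. A sketch with illustrative index contents; sources are read from document metadata:

```python
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

vectorstore = FAISS.from_texts(
    ["Retrievers decouple chains from vector stores."],
    OpenAIEmbeddings(),
    metadatas=[{"source": "notes.txt"}],  # hypothetical source label
)
chain = RetrievalQAWithSourcesChain.from_chain_type(
    OpenAI(temperature=0), retriever=vectorstore.as_retriever()
)
# question_key defaults to "question"; the result carries the answer and sources.
print(chain({"question": "What do retrievers do?"}))
```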
diff --git a/langchain/chains/vector_db_qa/__init__.py b/langchain/chains/retrieval_qa/__init__.py
similarity index 100%
rename from langchain/chains/vector_db_qa/__init__.py
rename to langchain/chains/retrieval_qa/__init__.py
diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/retrieval_qa/base.py
similarity index 76%
rename from langchain/chains/vector_db_qa/base.py
rename to langchain/chains/retrieval_qa/base.py
index 16182b78..c1543da8 100644
--- a/langchain/chains/vector_db_qa/base.py
+++ b/langchain/chains/retrieval_qa/base.py
@@ -1,6 +1,7 @@
 """Chain for question-answering against a vector database."""
 from __future__ import annotations

+from abc import abstractmethod
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, Field, root_validator
@@ -12,43 +13,24 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
 from langchain.prompts import PromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.schema import BaseLanguageModel, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore


-class VectorDBQA(Chain, BaseModel):
-    """Chain for question-answering against a vector database.
-
-    Example:
-        .. code-block:: python
-
-            from langchain import OpenAI, VectorDBQA
-            from langchain.faiss import FAISS
-            vectordb = FAISS(...)
-            vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)
-
-    """
-
-    vectorstore: VectorStore = Field(exclude=True)
-    """Vector Database to connect to."""
-    k: int = 4
-    """Number of documents to query for."""
+class BaseRetrievalQA(Chain, BaseModel):
     combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
     return_source_documents: bool = False
     """Return the source documents."""
-    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    """Extra search args."""
-    search_type: str = "similarity"
-    """Search type to use over vectorstore. `similarity` or `mmr`."""

     class Config:
         """Configuration for this pydantic object."""

         extra = Extra.forbid
         arbitrary_types_allowed = True
+        allow_population_by_field_name = True

     @property
     def input_keys(self) -> List[str]:
@@ -69,45 +51,13 @@ class VectorDBQA(Chain, BaseModel):
             _output_keys = _output_keys + ["source_documents"]
         return _output_keys

-    # TODO: deprecate this
-    @root_validator(pre=True)
-    def load_combine_documents_chain(cls, values: Dict) -> Dict:
-        """Validate question chain."""
-        if "combine_documents_chain" not in values:
-            if "llm" not in values:
-                raise ValueError(
-                    "If `combine_documents_chain` not provided, `llm` should be."
-                )
-            llm = values.pop("llm")
-            prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
-            llm_chain = LLMChain(llm=llm, prompt=prompt)
-            document_prompt = PromptTemplate(
-                input_variables=["page_content"], template="Context:\n{page_content}"
-            )
-            combine_documents_chain = StuffDocumentsChain(
-                llm_chain=llm_chain,
-                document_variable_name="context",
-                document_prompt=document_prompt,
-            )
-            values["combine_documents_chain"] = combine_documents_chain
-        return values
-
-    @root_validator()
-    def validate_search_type(cls, values: Dict) -> Dict:
-        """Validate search type."""
-        if "search_type" in values:
-            search_type = values["search_type"]
-            if search_type not in ("similarity", "mmr"):
-                raise ValueError(f"search_type of {search_type} not allowed.")
-        return values
-
     @classmethod
     def from_llm(
         cls,
         llm: BaseLanguageModel,
         prompt: Optional[PromptTemplate] = None,
         **kwargs: Any,
-    ) -> VectorDBQA:
+    ) -> BaseRetrievalQA:
         """Initialize from LLM."""
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
         llm_chain = LLMChain(llm=llm, prompt=_prompt)
@@ -129,7 +79,7 @@ class VectorDBQA(Chain, BaseModel):
         chain_type: str = "stuff",
         chain_type_kwargs: Optional[dict] = None,
         **kwargs: Any,
-    ) -> VectorDBQA:
+    ) -> BaseRetrievalQA:
         """Load chain from chain type."""
         _chain_type_kwargs = chain_type_kwargs or {}
         combine_documents_chain = load_qa_chain(
@@ -137,8 +87,12 @@
         )
         return cls(combine_documents_chain=combine_documents_chain, **kwargs)

+    @abstractmethod
+    def _get_docs(self, question: str) -> List[Document]:
+        """Get documents to do question answering over."""
+
     def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
-        """Run similarity search and llm on input query.
+        """Run get_relevant_texts and llm on input query.

         If chain has 'return_source_documents' as 'True', returns
         the retrieved documents as well under the key 'source_documents'.
@@ -146,11 +100,62 @@
         Example:
         .. code-block:: python

-            res = vectordbqa({'query': 'This is my query'})
+            res = indexqa({'query': 'This is my query'})
             answer, docs = res['result'], res['source_documents']
         """
         question = inputs[self.input_key]
+        docs = self._get_docs(question)
+        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
+
+        if self.return_source_documents:
+            return {self.output_key: answer, "source_documents": docs}
+        else:
+            return {self.output_key: answer}
+
+
+class RetrievalQA(BaseRetrievalQA, BaseModel):
+    """Chain for question-answering against an index.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OpenAI
+            from langchain.chains import RetrievalQA
+            from langchain.faiss import FAISS
+            vectordb = FAISS(...)
+            retrievalQA = RetrievalQA.from_llm(
+                llm=OpenAI(), retriever=vectordb.as_retriever()
+            )
+
+    """
+
+    retriever: BaseRetriever = Field(exclude=True)
+
+    def _get_docs(self, question: str) -> List[Document]:
+        return self.retriever.get_relevant_texts(question)
+
+
+class VectorDBQA(BaseRetrievalQA, BaseModel):
+    """Chain for question-answering against a vector database."""
+
+    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
+    """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to query for."""
+    search_type: str = "similarity"
+    """Search type to use over vectorstore. `similarity` or `mmr`."""
+    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Extra search args."""
+
+    @root_validator()
+    def validate_search_type(cls, values: Dict) -> Dict:
+        """Validate search type."""
+        if "search_type" in values:
+            search_type = values["search_type"]
+            if search_type not in ("similarity", "mmr"):
+                raise ValueError(f"search_type of {search_type} not allowed.")
+        return values
+
+    def _get_docs(self, question: str) -> List[Document]:
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(
                 question, k=self.k, **self.search_kwargs
@@ -161,12 +166,7 @@ class VectorDBQA(Chain, BaseModel):
             )
         else:
             raise ValueError(f"search_type of {self.search_type} not allowed.")
-        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
-
-        if self.return_source_documents:
-            return {self.output_key: answer, "source_documents": docs}
-        else:
-            return {self.output_key: answer}
+        return docs

     @property
     def _chain_type(self) -> str:
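`BaseRetrievalQA` now owns the shared `_call` logic while subclasses only supply `_get_docs`. A usage sketch for the retriever-backed variant (index contents illustrative):

```python
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

vectorstore = FAISS.from_texts(
    ["The retriever interface has one method."], OpenAIEmbeddings()
)
qa = RetrievalQA.from_chain_type(
    OpenAI(temperature=0),
    chain_type="stuff",  # any load_qa_chain chain type works here
    retriever=vectorstore.as_retriever(),
    return_source_documents=True,
)
res = qa({"query": "How many methods does the retriever interface have?"})
answer, docs = res["result"], res["source_documents"]
```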
diff --git a/langchain/chains/vector_db_qa/prompt.py b/langchain/chains/retrieval_qa/prompt.py
similarity index 100%
rename from langchain/chains/vector_db_qa/prompt.py
rename to langchain/chains/retrieval_qa/prompt.py
diff --git a/langchain/docstore/document.py b/langchain/docstore/document.py
index cd6349d5..1c33318d 100644
--- a/langchain/docstore/document.py
+++ b/langchain/docstore/document.py
@@ -1,39 +1,3 @@
-"""Interface for interacting with a document."""
-from typing import List
+from langchain.schema import Document

-from pydantic import BaseModel, Field
-
-
-class Document(BaseModel):
-    """Interface for interacting with a document."""
-
-    page_content: str
-    lookup_str: str = ""
-    lookup_index = 0
-    metadata: dict = Field(default_factory=dict)
-
-    @property
-    def paragraphs(self) -> List[str]:
-        """Paragraphs of the page."""
-        return self.page_content.split("\n\n")
-
-    @property
-    def summary(self) -> str:
-        """Summary of the page (the first paragraph)."""
-        return self.paragraphs[0]
-
-    def lookup(self, string: str) -> str:
-        """Lookup a term in the page, imitating cmd-F functionality."""
-        if string.lower() != self.lookup_str:
-            self.lookup_str = string.lower()
-            self.lookup_index = 0
-        else:
-            self.lookup_index += 1
-        lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
-        if len(lookups) == 0:
-            return "No Results"
-        elif self.lookup_index >= len(lookups):
-            return "No More Results"
-        else:
-            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
-            return f"{result_prefix} {lookups[self.lookup_index]}"
+__all__ = ["Document"]
diff --git a/langchain/indexes/vectorstore.py b/langchain/indexes/vectorstore.py
index dc5807e4..87e4f4d9 100644
--- a/langchain/indexes/vectorstore.py
+++ b/langchain/indexes/vectorstore.py
@@ -2,8 +2,8 @@ from typing import Any, List, Optional, Type

 from pydantic import BaseModel, Extra, Field

-from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
+from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
+from langchain.chains.retrieval_qa.base import RetrievalQA
 from langchain.document_loaders.base import BaseLoader
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -32,7 +32,9 @@ class VectorStoreIndexWrapper(BaseModel):
     def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
         """Query the vectorstore."""
         llm = llm or OpenAI(temperature=0)
-        chain = VectorDBQA.from_chain_type(llm, vectorstore=self.vectorstore, **kwargs)
+        chain = RetrievalQA.from_chain_type(
+            llm, retriever=self.vectorstore.as_retriever(), **kwargs
+        )
         return chain.run(question)

     def query_with_sources(
@@ -40,8 +42,8 @@ class VectorStoreIndexWrapper(BaseModel):
     ) -> dict:
         """Query the vectorstore and get back sources."""
         llm = llm or OpenAI(temperature=0)
-        chain = VectorDBQAWithSourcesChain.from_chain_type(
-            llm, vectorstore=self.vectorstore, **kwargs
+        chain = RetrievalQAWithSourcesChain.from_chain_type(
+            llm, retriever=self.vectorstore.as_retriever(), **kwargs
         )
        return chain({chain.question_key: question})
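From a caller's perspective the index wrapper behaves as before; it just routes through a retriever internally. A sketch using the `VectorstoreIndexCreator` that lives in the same module (the file name is a placeholder):

```python
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator

# Builds a vectorstore from the loader and returns a VectorStoreIndexWrapper.
index = VectorstoreIndexCreator().from_loaders([TextLoader("my_notes.txt")])
print(index.query("What is this document about?"))
print(index.query_with_sources("What is this document about?"))
```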
diff --git a/langchain/schema.py b/langchain/schema.py
index b42a640b..087b81e0 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -240,6 +240,54 @@ class BaseMemory(BaseModel, ABC):
         """Clear memory contents."""


+class Document(BaseModel):
+    """Interface for interacting with a document."""
+
+    page_content: str
+    lookup_str: str = ""
+    lookup_index = 0
+    metadata: dict = Field(default_factory=dict)
+
+    @property
+    def paragraphs(self) -> List[str]:
+        """Paragraphs of the page."""
+        return self.page_content.split("\n\n")
+
+    @property
+    def summary(self) -> str:
+        """Summary of the page (the first paragraph)."""
+        return self.paragraphs[0]
+
+    def lookup(self, string: str) -> str:
+        """Lookup a term in the page, imitating cmd-F functionality."""
+        if string.lower() != self.lookup_str:
+            self.lookup_str = string.lower()
+            self.lookup_index = 0
+        else:
+            self.lookup_index += 1
+        lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
+        if len(lookups) == 0:
+            return "No Results"
+        elif self.lookup_index >= len(lookups):
+            return "No More Results"
+        else:
+            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
+            return f"{result_prefix} {lookups[self.lookup_index]}"
+
+
+class BaseRetriever(ABC):
+    @abstractmethod
+    def get_relevant_texts(self, query: str) -> List[Document]:
+        """Get texts relevant for a query.
+
+        Args:
+            query: string to find relevant texts for
+
+        Returns:
+            List of relevant documents
+        """
+
+
 # For backwards compatibility
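Since `BaseRetriever` is a plain ABC with a single method, anything that can map a query to documents can back these chains. A toy keyword retriever, for illustration only:

```python
from typing import List

from langchain.schema import BaseRetriever, Document


class KeywordRetriever(BaseRetriever):
    """Toy retriever: returns docs sharing at least one word with the query."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def get_relevant_texts(self, query: str) -> List[Document]:
        words = set(query.lower().split())
        return [
            d for d in self.docs
            if words & set(d.page_content.lower().split())
        ]


docs = [Document(page_content="retrievers fetch relevant documents")]
print(KeywordRetriever(docs).get_relevant_texts("what do retrievers do"))
```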
diff --git a/langchain/tools/vectorstore/tool.py b/langchain/tools/vectorstore/tool.py
index ee5a0695..adc35e7f 100644
--- a/langchain/tools/vectorstore/tool.py
+++ b/langchain/tools/vectorstore/tool.py
@@ -6,7 +6,7 @@ from typing import Any, Dict
 from pydantic import BaseModel, Field

 from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
-from langchain.chains.vector_db_qa.base import VectorDBQA
+from langchain.chains.retrieval_qa.base import VectorDBQA
 from langchain.llms.base import BaseLLM
 from langchain.llms.openai import OpenAI
 from langchain.tools.base import BaseTool
diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py
index e3b5241e..60f4d1e0 100644
--- a/langchain/vectorstores/base.py
+++ b/langchain/vectorstores/base.py
@@ -4,8 +4,11 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Iterable, List, Optional

+from pydantic import BaseModel, Field
+
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
+from langchain.schema import BaseRetriever


 class VectorStore(ABC):
@@ -122,3 +125,19 @@ class VectorStore(ABC):
         **kwargs: Any,
     ) -> VectorStore:
         """Return VectorStore initialized from texts and embeddings."""
+
+    def as_retriever(self) -> VectorStoreRetriever:
+        return VectorStoreRetriever(vectorstore=self)
+
+
+class VectorStoreRetriever(BaseRetriever, BaseModel):
+    vectorstore: VectorStore
+    search_kwargs: dict = Field(default_factory=dict)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def get_relevant_texts(self, query: str) -> List[Document]:
+        return self.vectorstore.similarity_search(query, **self.search_kwargs)
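`as_retriever` is the bridge from any existing `VectorStore` to the new interface; `search_kwargs` are forwarded verbatim to `similarity_search`. A closing sketch (index contents illustrative):

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores.faiss import FAISS

vectorstore = FAISS.from_texts(["alpha", "beta", "gamma"], OpenAIEmbeddings())

# Default search behavior:
retriever = vectorstore.as_retriever()

# Or construct the retriever directly to pin extra search kwargs,
# e.g. the number of documents returned per query:
retriever = VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2})
docs = retriever.get_relevant_texts("alpha")
```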