From 76aff023d7aeaee9bbe6c3cf244dd6e8636f8bc2 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 1 Nov 2022 21:29:39 -0700 Subject: [PATCH] FAISS and embedding support (#48) also adds embeddings and an in memory docstore --- docs/index.rst | 1 + docs/modules/embeddings.rst | 5 + examples/embeddings.ipynb | 98 +++++++++++++++++++ langchain/__init__.py | 2 + langchain/docstore/in_memory.py | 20 ++++ langchain/embeddings/__init__.py | 4 + langchain/embeddings/base.py | 15 +++ langchain/embeddings/openai.py | 84 ++++++++++++++++ langchain/faiss.py | 86 ++++++++++++++++ requirements.txt | 1 + setup.py | 2 +- test_requirements.txt | 1 + .../integration_tests/embeddings/__init__.py | 1 + .../embeddings/test_openai.py | 19 ++++ tests/integration_tests/test_faiss.py | 47 +++++++++ tests/unit_tests/docstore/test_inmemory.py | 21 ++++ 16 files changed, 406 insertions(+), 1 deletion(-) create mode 100644 docs/modules/embeddings.rst create mode 100644 examples/embeddings.ipynb create mode 100644 langchain/docstore/in_memory.py create mode 100644 langchain/embeddings/__init__.py create mode 100644 langchain/embeddings/base.py create mode 100644 langchain/embeddings/openai.py create mode 100644 langchain/faiss.py create mode 100644 tests/integration_tests/embeddings/__init__.py create mode 100644 tests/integration_tests/embeddings/test_openai.py create mode 100644 tests/integration_tests/test_faiss.py create mode 100644 tests/unit_tests/docstore/test_inmemory.py diff --git a/docs/index.rst b/docs/index.rst index c6af00fd..418a0816 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,4 +7,5 @@ Welcome to LangChain modules/prompt modules/llms + modules/embeddings modules/chains diff --git a/docs/modules/embeddings.rst b/docs/modules/embeddings.rst new file mode 100644 index 00000000..ddba956f --- /dev/null +++ b/docs/modules/embeddings.rst @@ -0,0 +1,5 @@ +:mod:`langchain.embeddings` +=========================== + +.. automodule:: langchain.embeddings + :members: diff --git a/examples/embeddings.ipynb b/examples/embeddings.ipynb new file mode 100644 index 00000000..bd38758a --- /dev/null +++ b/examples/embeddings.ipynb @@ -0,0 +1,98 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "965eecee", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.faiss import FAISS\n", + "from langchain.text_splitter import CharacterTextSplitter" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "68481687", + "metadata": {}, + "outputs": [], + "source": [ + "with open('state_of_the_union.txt') as f:\n", + " state_of_the_union = f.read()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "texts = text_splitter.split_text(state_of_the_union)\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "docsearch = FAISS.from_texts(texts, embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "015f4ff5", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = docsearch.similarity_search(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "67baf32e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. \n", + "\n", + "A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n", + "\n", + "And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n" + ] + } + ], + "source": [ + "print(docs[0].page_content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25500fa6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index ef77f3ff..d7680041 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -15,6 +15,7 @@ from langchain.chains import ( SQLDatabaseChain, ) from langchain.docstore import Wikipedia +from langchain.faiss import FAISS from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.prompt import Prompt from langchain.sql_database import SQLDatabase @@ -33,4 +34,5 @@ __all__ = [ "HuggingFaceHub", "SQLDatabase", "SQLDatabaseChain", + "FAISS", ] diff --git a/langchain/docstore/in_memory.py b/langchain/docstore/in_memory.py new file mode 100644 index 00000000..5023d5ff --- /dev/null +++ b/langchain/docstore/in_memory.py @@ -0,0 +1,20 @@ +"""Simple in memory docstore in the form of a dict.""" +from typing import Dict, Union + +from langchain.docstore.base import Docstore +from langchain.docstore.document import Document + + +class InMemoryDocstore(Docstore): + """Simple in memory docstore in the form of a dict.""" + + def __init__(self, _dict: Dict[str, Document]): + """Initialize with dict.""" + self._dict = _dict + + def search(self, search: str) -> Union[str, Document]: + """Search via direct lookup.""" + if search not in self._dict: + return f"ID {search} not found." + else: + return self._dict[search] diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py new file mode 100644 index 00000000..b2d6091a --- /dev/null +++ b/langchain/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""Wrappers around embedding modules.""" +from langchain.embeddings.openai import OpenAIEmbeddings + +__all__ = ["OpenAIEmbeddings"] diff --git a/langchain/embeddings/base.py b/langchain/embeddings/base.py new file mode 100644 index 00000000..4a56cd6a --- /dev/null +++ b/langchain/embeddings/base.py @@ -0,0 +1,15 @@ +"""Interface for embedding models.""" +from abc import ABC, abstractmethod +from typing import List + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py new file mode 100644 index 00000000..48f1a520 --- /dev/null +++ b/langchain/embeddings/openai.py @@ -0,0 +1,84 @@ +"""Wrapper around OpenAI embedding models.""" +import os +from typing import Any, Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings + + +class OpenAIEmbeddings(BaseModel, Embeddings): + """Wrapper around OpenAI embedding models. + + To use, you should have the ``openai`` python package installed, and the + environment variable ``OPENAI_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from langchain.embeddings import OpenAIEmbeddings + openai = OpenAIEmbeddings(model_name="davinci") + """ + + client: Any #: :meta private: + model_name: str = "babbage" + """Model name to use.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + if "OPENAI_API_KEY" not in os.environ: + raise ValueError( + "Did not find OpenAI API key, please add an environment variable" + " `OPENAI_API_KEY` which contains it." + ) + try: + import openai + + values["client"] = openai.Embedding + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please it install it with `pip install openai`." + ) + return values + + def _embedding_func(self, text: str, *, engine: str) -> List[float]: + """Call out to OpenAI's embedding endpoint.""" + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + return self.client.create(input=[text], engine=engine)["data"][0]["embedding"] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to OpenAI's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + responses = [ + self._embedding_func(text, engine=f"text-search-{self.model_name}-doc-001") + for text in texts + ] + return responses + + def embed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + embedding = self._embedding_func( + text, engine=f"text-search-{self.model_name}-query-001" + ) + return embedding diff --git a/langchain/faiss.py b/langchain/faiss.py new file mode 100644 index 00000000..5474ad43 --- /dev/null +++ b/langchain/faiss.py @@ -0,0 +1,86 @@ +"""Wrapper around FAISS vector database.""" +from typing import Any, Callable, List + +import numpy as np + +from langchain.docstore.base import Docstore +from langchain.docstore.document import Document +from langchain.docstore.in_memory import InMemoryDocstore +from langchain.embeddings.base import Embeddings + + +class FAISS: + """Wrapper around FAISS vector database. + + To use, you should have the ``faiss`` python package installed. + + Example: + .. code-block:: python + + from langchain import FAISS + faiss = FAISS(embedding_function, index, docstore) + + """ + + def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore): + """Initialize with necessary components.""" + self.embedding_function = embedding_function + self.index = index + self.docstore = docstore + + def similarity_search(self, query: str, k: int = 4) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding_function(query) + _, indices = self.index.search(np.array([embedding], dtype=np.float32), k) + docs = [] + for i in indices[0]: + if i == -1: + # This happens when not enough docs are returned. + continue + doc = self.docstore.search(str(i)) + if not isinstance(doc, Document): + raise ValueError(f"Could not find document for id {i}, got {doc}") + docs.append(doc) + return docs + + @classmethod + def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS": + """Construct FAISS wrapper from raw documents. + + This is a user friendly interface that: + 1. Embeds documents. + 2. Creates an in memory docstore + 3. Initializes the FAISS database + + This is intended to be a quick way to get started. + + Example: + .. code-block:: python + + from langchain import FAISS + from langchain.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + faiss = FAISS.from_texts(texts, embeddings) + """ + try: + import faiss + except ImportError: + raise ValueError( + "Could not import faiss python package. " + "Please it install it with `pip install faiss` " + "or `pip install faiss-cpu` (depending on Python version)." + ) + embeddings = embedding.embed_documents(texts) + index = faiss.IndexFlatL2(len(embeddings[0])) + index.add(np.array(embeddings, dtype=np.float32)) + documents = [Document(page_content=text) for text in texts] + docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)}) + return cls(embedding.embed_query, index, docstore) diff --git a/requirements.txt b/requirements.txt index a6e1355c..1dd1e8a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,5 +12,6 @@ google-search-results playwright wikipedia huggingface_hub +faiss # For development jupyter diff --git a/setup.py b/setup.py index 4ac3f475..68972b59 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( version=__version__, packages=find_packages(), description="Building applications with LLMs through composability", - install_requires=["pydantic", "sqlalchemy"], + install_requires=["pydantic", "sqlalchemy", "numpy"], long_description=long_description, license="MIT", url="https://github.com/hwchase17/langchain", diff --git a/test_requirements.txt b/test_requirements.txt index aea9aec7..19f50453 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,4 @@ -e . +# For testing pytest pytest-dotenv diff --git a/tests/integration_tests/embeddings/__init__.py b/tests/integration_tests/embeddings/__init__.py new file mode 100644 index 00000000..f72c0f3b --- /dev/null +++ b/tests/integration_tests/embeddings/__init__.py @@ -0,0 +1 @@ +"""Test embedding integrations.""" diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py new file mode 100644 index 00000000..a721f78c --- /dev/null +++ b/tests/integration_tests/embeddings/test_openai.py @@ -0,0 +1,19 @@ +"""Test openai embeddings.""" +from langchain.embeddings.openai import OpenAIEmbeddings + + +def test_openai_embedding_documents() -> None: + """Test openai embeddings.""" + documents = ["foo bar"] + embedding = OpenAIEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 2048 + + +def test_openai_embedding_query() -> None: + """Test openai embeddings.""" + document = "foo bar" + embedding = OpenAIEmbeddings() + output = embedding.embed_query(document) + assert len(output) == 2048 diff --git a/tests/integration_tests/test_faiss.py b/tests/integration_tests/test_faiss.py new file mode 100644 index 00000000..0a42b0d7 --- /dev/null +++ b/tests/integration_tests/test_faiss.py @@ -0,0 +1,47 @@ +"""Test FAISS functionality.""" +from typing import List + +import pytest + +from langchain.docstore.document import Document +from langchain.docstore.in_memory import InMemoryDocstore +from langchain.embeddings.base import Embeddings +from langchain.faiss import FAISS + + +class FakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [[i] * 10 for i in range(len(texts))] + + def embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [0] * 10 + + +def test_faiss() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + expected_docstore = InMemoryDocstore( + { + "0": Document(page_content="foo"), + "1": Document(page_content="bar"), + "2": Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_faiss_search_not_found() -> None: + """Test what happens when document is not found.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + # Get rid of the docstore to purposefully induce errors. + docsearch.docstore = InMemoryDocstore({}) + with pytest.raises(ValueError): + docsearch.similarity_search("foo") diff --git a/tests/unit_tests/docstore/test_inmemory.py b/tests/unit_tests/docstore/test_inmemory.py new file mode 100644 index 00000000..284f9224 --- /dev/null +++ b/tests/unit_tests/docstore/test_inmemory.py @@ -0,0 +1,21 @@ +"""Test in memory docstore.""" + +from langchain.docstore.document import Document +from langchain.docstore.in_memory import InMemoryDocstore + + +def test_document_found() -> None: + """Test document found.""" + _dict = {"foo": Document(page_content="bar")} + docstore = InMemoryDocstore(_dict) + output = docstore.search("foo") + assert isinstance(output, Document) + assert output.page_content == "bar" + + +def test_document_not_found() -> None: + """Test when document is not found.""" + _dict = {"foo": Document(page_content="bar")} + docstore = InMemoryDocstore(_dict) + output = docstore.search("bar") + assert output == "ID bar not found."