mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
FAISS and embedding support (#48)
also adds embeddings and an in memory docstore
This commit is contained in:
parent
798deaec2b
commit
76aff023d7
@ -7,4 +7,5 @@ Welcome to LangChain
|
|||||||
|
|
||||||
modules/prompt
|
modules/prompt
|
||||||
modules/llms
|
modules/llms
|
||||||
|
modules/embeddings
|
||||||
modules/chains
|
modules/chains
|
||||||
|
5
docs/modules/embeddings.rst
Normal file
5
docs/modules/embeddings.rst
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
:mod:`langchain.embeddings`
|
||||||
|
===========================
|
||||||
|
|
||||||
|
.. automodule:: langchain.embeddings
|
||||||
|
:members:
|
98
examples/embeddings.ipynb
Normal file
98
examples/embeddings.ipynb
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "965eecee",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||||
|
"from langchain.faiss import FAISS\n",
|
||||||
|
"from langchain.text_splitter import CharacterTextSplitter"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "68481687",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open('state_of_the_union.txt') as f:\n",
|
||||||
|
" state_of_the_union = f.read()\n",
|
||||||
|
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||||
|
"texts = text_splitter.split_text(state_of_the_union)\n",
|
||||||
|
"\n",
|
||||||
|
"embeddings = OpenAIEmbeddings()\n",
|
||||||
|
"docsearch = FAISS.from_texts(texts, embeddings)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "015f4ff5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
|
"docs = docsearch.similarity_search(query)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "67baf32e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||||
|
"\n",
|
||||||
|
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||||
|
"\n",
|
||||||
|
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. \n",
|
||||||
|
"\n",
|
||||||
|
"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
|
||||||
|
"\n",
|
||||||
|
"And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(docs[0].page_content)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "25500fa6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -15,6 +15,7 @@ from langchain.chains import (
|
|||||||
SQLDatabaseChain,
|
SQLDatabaseChain,
|
||||||
)
|
)
|
||||||
from langchain.docstore import Wikipedia
|
from langchain.docstore import Wikipedia
|
||||||
|
from langchain.faiss import FAISS
|
||||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||||
from langchain.prompt import Prompt
|
from langchain.prompt import Prompt
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
@ -33,4 +34,5 @@ __all__ = [
|
|||||||
"HuggingFaceHub",
|
"HuggingFaceHub",
|
||||||
"SQLDatabase",
|
"SQLDatabase",
|
||||||
"SQLDatabaseChain",
|
"SQLDatabaseChain",
|
||||||
|
"FAISS",
|
||||||
]
|
]
|
||||||
|
20
langchain/docstore/in_memory.py
Normal file
20
langchain/docstore/in_memory.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
"""Simple in memory docstore in the form of a dict."""
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
from langchain.docstore.base import Docstore
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryDocstore(Docstore):
    """Docstore that keeps its documents in a plain in-memory dictionary."""

    def __init__(self, _dict: Dict[str, Document]):
        """Keep a reference to the mapping from document id to Document.

        Args:
            _dict: Mapping of string ids to the documents stored under them.
        """
        self._dict = _dict

    def search(self, search: str) -> Union[str, Document]:
        """Look up a document by its exact id.

        Args:
            search: The id to look up in the underlying dictionary.

        Returns:
            The Document stored under ``search`` when the id is present,
            otherwise a human-readable "not found" message string.
        """
        if search in self._dict:
            return self._dict[search]
        # Absence is reported as a message string rather than an exception.
        return f"ID {search} not found."
|
4
langchain/embeddings/__init__.py
Normal file
4
langchain/embeddings/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
"""Wrappers around embedding modules."""
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
__all__ = ["OpenAIEmbeddings"]
|
15
langchain/embeddings/base.py
Normal file
15
langchain/embeddings/base.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""Interface for embedding models."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings(ABC):
    """Abstract interface that every embedding model wrapper implements."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents that will be searched over.

        Args:
            texts: The document texts to embed.

        Returns:
            A list of embedding vectors.
        """

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed a single piece of query text.

        Args:
            text: The query text to embed.

        Returns:
            The embedding vector for the query.
        """
|
84
langchain/embeddings/openai.py
Normal file
84
langchain/embeddings/openai.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""Wrapper around OpenAI embedding models."""
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(model_name="davinci")
    """

    client: Any  #: :meta private:
    model_name: str = "babbage"
    """Model name to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same ``values`` dict with ``client`` set to ``openai.Embedding``.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment or
                the ``openai`` package cannot be imported.
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        # Keep the try body to the single statement that can raise ImportError.
        try:
            import openai
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # stays visible in the traceback.
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err
        values["client"] = openai.Embedding
        return values

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint.

        Args:
            text: The text to embed.
            engine: Name of the OpenAI engine to request the embedding from.

        Returns:
            The embedding taken from the first entry of the response data.
        """
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        response = self.client.create(input=[text], engine=engine)
        return response["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Documents use the "-doc-" engine; queries use the "-query-" engine.
        engine = f"text-search-{self.model_name}-doc-001"
        return [self._embedding_func(text, engine=engine) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding_func(
            text, engine=f"text-search-{self.model_name}-query-001"
        )
|
86
langchain/faiss.py
Normal file
86
langchain/faiss.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""Wrapper around FAISS vector database."""
|
||||||
|
from typing import Any, Callable, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.docstore.base import Docstore
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class FAISS:
    """Wrapper around FAISS vector database.

    To use, you should have the ``faiss`` python package installed.

    Example:
        .. code-block:: python

            from langchain import FAISS
            faiss = FAISS(embedding_function, index, docstore)

    """

    def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
        """Initialize with necessary components.

        Args:
            embedding_function: Callable used to embed a query string.
            index: The index object that is searched with the query embedding.
            docstore: Docstore searched with the stringified index positions.
        """
        self.embedding_function = embedding_function
        self.index = index
        self.docstore = docstore

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: If the docstore does not return a Document for an
                index position returned by the search.
        """
        embedding = self.embedding_function(query)
        # The index is searched with a single-row float32 batch.
        _, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
        docs = []
        for i in indices[0]:
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            doc = self.docstore.search(str(i))
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {i}, got {doc}")
            docs.append(doc)
        return docs

    @classmethod
    def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
        """Construct FAISS wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the FAISS database

        This is intended to be a quick way to get started.

        Args:
            texts: The raw texts to embed and index.
            embedding: Embeddings implementation used for both the documents
                and, later, the queries.

        Returns:
            A FAISS wrapper whose docstore ids are the stringified positions
            of the texts.

        Raises:
            ValueError: If the ``faiss`` package cannot be imported, or if no
                embeddings were produced (e.g. ``texts`` is empty).

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)
        """
        try:
            import faiss
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # stays visible in the traceback.
            raise ValueError(
                "Could not import faiss python package. "
                "Please install it with `pip install faiss` "
                "or `pip install faiss-cpu` (depending on Python version)."
            ) from err
        embeddings = embedding.embed_documents(texts)
        if len(embeddings) == 0:
            # The index dimension is read from the first embedding, so an
            # empty result would otherwise fail with a bare IndexError.
            raise ValueError(
                "Cannot build a FAISS index without at least one text to embed."
            )
        index = faiss.IndexFlatL2(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = [Document(page_content=text) for text in texts]
        # Docstore ids are the stringified positions, matching the
        # str(i) lookups done in similarity_search.
        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
        return cls(embedding.embed_query, index, docstore)
|
@ -12,5 +12,6 @@ google-search-results
|
|||||||
playwright
|
playwright
|
||||||
wikipedia
|
wikipedia
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
|
faiss
|
||||||
# For development
|
# For development
|
||||||
jupyter
|
jupyter
|
||||||
|
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ setup(
|
|||||||
version=__version__,
|
version=__version__,
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
description="Building applications with LLMs through composability",
|
description="Building applications with LLMs through composability",
|
||||||
install_requires=["pydantic", "sqlalchemy"],
|
install_requires=["pydantic", "sqlalchemy", "numpy"],
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
url="https://github.com/hwchase17/langchain",
|
url="https://github.com/hwchase17/langchain",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
-e .
|
-e .
|
||||||
|
# For testing
|
||||||
pytest
|
pytest
|
||||||
pytest-dotenv
|
pytest-dotenv
|
||||||
|
1
tests/integration_tests/embeddings/__init__.py
Normal file
1
tests/integration_tests/embeddings/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Test embedding integrations."""
|
19
tests/integration_tests/embeddings/test_openai.py
Normal file
19
tests/integration_tests/embeddings/test_openai.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""Test openai embeddings."""
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_embedding_documents() -> None:
    """Embedding one document yields one vector of length 2048."""
    embedder = OpenAIEmbeddings()
    vectors = embedder.embed_documents(["foo bar"])
    assert len(vectors) == 1
    assert len(vectors[0]) == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_embedding_query() -> None:
    """Embedding a query string yields a vector of length 2048."""
    embedder = OpenAIEmbeddings()
    vector = embedder.embed_query("foo bar")
    assert len(vector) == 2048
|
47
tests/integration_tests/test_faiss.py
Normal file
47
tests/integration_tests/test_faiss.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Test FAISS functionality."""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.faiss import FAISS
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEmbeddings(Embeddings):
    """Deterministic stand-in for a real embedding model, used in tests."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one 10-dimensional vector per text, filled with its position."""
        return [[position] * 10 for position, _ in enumerate(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Return an all-zero 10-dimensional vector, whatever the text."""
        # Identical to the vector embed_documents produces for position 0.
        return [0] * 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss() -> None:
    """Build the wrapper from texts, then check the docstore and a search."""
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())

    # Each text is expected under the string form of its position.
    expected_entries = {
        "0": Document(page_content="foo"),
        "1": Document(page_content="bar"),
        "2": Document(page_content="baz"),
    }
    expected_docstore = InMemoryDocstore(expected_entries)
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__

    hits = docsearch.similarity_search("foo", k=1)
    assert hits == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_search_not_found() -> None:
    """A search must raise ValueError when the docstore has no documents."""
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    # Swap in an empty docstore so every id lookup fails on purpose.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")
|
21
tests/unit_tests/docstore/test_inmemory.py
Normal file
21
tests/unit_tests/docstore/test_inmemory.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
"""Test in memory docstore."""
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_found() -> None:
    """Searching for a stored id returns the matching Document."""
    docstore = InMemoryDocstore({"foo": Document(page_content="bar")})
    result = docstore.search("foo")
    assert isinstance(result, Document)
    assert result.page_content == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_not_found() -> None:
    """Searching for an unknown id returns the "not found" message string."""
    docstore = InMemoryDocstore({"foo": Document(page_content="bar")})
    result = docstore.search("bar")
    assert result == "ID bar not found."
|
Loading…
Reference in New Issue
Block a user