forked from Archives/langchain
cr
commit 178e8217a4

docs/examples/demos/sequential_chains.ipynb (new file, 265 lines)
@@ -0,0 +1,265 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f73605d",
   "metadata": {},
   "source": [
    "# Sequential Chains"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b235f7a",
   "metadata": {},
   "source": [
    "The next step after calling a language model is to make a series of calls to a language model. This is particularly useful when you want to take the output from one call and use it as the input to another.\n",
    "\n",
    "In this notebook we will walk through some examples of how to do this using sequential chains. Sequential chains are defined as a series of chains, called in deterministic order. There are two types of sequential chains:\n",
    "\n",
    "- `SimpleSequentialChain`: The simplest form of sequential chains, where each step has a singular input/output, and the output of one step is the input to the next.\n",
    "- `SequentialChain`: A more general form of sequential chains, allowing for multiple inputs/outputs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5162794e",
   "metadata": {},
   "source": [
    "## SimpleSequentialChain\n",
    "\n",
    "In this series of chains, each individual chain has a single input and a single output, and the output of one step is used as the input to the next.\n",
    "\n",
    "Let's walk through a toy example: the first chain takes in the title of an imaginary play and generates a synopsis for that title, and the second chain takes in that synopsis and generates an imaginary review of the play."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f2f9b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b8237d1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an LLMChain to write a synopsis given a title of a play.\n",
    "llm = OpenAI(temperature=.7)\n",
    "template = \"\"\"You are a playwright. Given the title of a play, it is your job to write a synopsis for that title.\n",
    "\n",
    "Title: {title}\n",
    "Playwright: This is a synopsis for the above play:\"\"\"\n",
    "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
    "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4a391730",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an LLMChain to write a review of a play given a synopsis.\n",
    "llm = OpenAI(temperature=.7)\n",
    "template = \"\"\"You are a play critic from the New York Times. Given the synopsis of a play, it is your job to write a review for that play.\n",
    "\n",
    "Play Synopsis:\n",
    "{synopsis}\n",
    "Review from a New York Times play critic of the above play:\"\"\"\n",
    "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
    "review_chain = LLMChain(llm=llm, prompt=prompt_template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9368bd63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the overall chain where we run these two chains in sequence.\n",
    "from langchain.chains import SimpleSequentialChain\n",
    "overall_chain = SimpleSequentialChain(chains=[synopsis_chain, review_chain], verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d39e15f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "\u001b[36;1m\u001b[1;3m\n",
      "\n",
      "A young couple, John and Mary, are enjoying a day at the beach. As the sun sets, they share a romantic moment. However, their happiness is short-lived, as a tragic accident claims John's life. Mary is left devastated by the loss of her husband.\u001b[0m\n",
      "\u001b[33;1m\u001b[1;3m\n",
      "\n",
      "\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "review = overall_chain.run(\"Tragedy at sunset on the beach\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c6649a01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\n"
     ]
    }
   ],
   "source": [
    "print(review)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3f1549a",
   "metadata": {},
   "source": [
    "## SequentialChain\n",
    "Of course, not all sequential chains will be as simple as passing a single string as an argument and getting a single string as output for all steps in the chain. In this next example, we will experiment with more complex chains that involve multiple inputs, and where there are also multiple final outputs.\n",
    "\n",
    "Of particular importance is how we name the input/output variables. In the above example we didn't have to think about that because we were just passing the output of one chain directly in as input to the next, but here we do have to worry about it because we have multiple inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "02016a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an LLMChain to write a synopsis given a title of a play and the era it is set in.\n",
    "llm = OpenAI(temperature=.7)\n",
    "template = \"\"\"You are a playwright. Given the title of a play and the era it is set in, it is your job to write a synopsis for that title.\n",
    "\n",
    "Title: {title}\n",
    "Era: {era}\n",
    "Playwright: This is a synopsis for the above play:\"\"\"\n",
    "prompt_template = PromptTemplate(input_variables=[\"title\", \"era\"], template=template)\n",
    "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"synopsis\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8bd38cc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an LLMChain to write a review of a play given a synopsis.\n",
    "llm = OpenAI(temperature=.7)\n",
    "template = \"\"\"You are a play critic from the New York Times. Given the synopsis of a play, it is your job to write a review for that play.\n",
    "\n",
    "Play Synopsis:\n",
    "{synopsis}\n",
    "Review from a New York Times play critic of the above play:\"\"\"\n",
    "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
    "review_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"review\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "524523af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the overall chain where we run these two chains in sequence.\n",
    "from langchain.chains import SequentialChain\n",
    "overall_chain = SequentialChain(\n",
    "    chains=[synopsis_chain, review_chain],\n",
    "    input_variables=[\"era\", \"title\"],\n",
    "    # Here we return multiple variables\n",
    "    output_variables=[\"synopsis\", \"review\"],\n",
    "    verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3fd3a7be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "\u001b[1mChain 0\u001b[0m:\n",
      "{'synopsis': \"\\n\\nThe play is set in Victorian England and follows the tragic story of a young woman who drowns while swimming at sunset on the beach. Her body is found the next morning by a fisherman who raises the alarm. The young woman's family and friends are devastated by her death and the play ends with their mourning her loss.\"}\n",
      "\n",
      "\u001b[1mChain 1\u001b[0m:\n",
      "{'review': '\\n\\n\"The play is a tragedy, pure and simple. It is the story of a young woman\\'s death, told through the eyes of those who loved her. It is a sad, beautiful play that will stay with you long after you\\'ve seen it. The acting is superb, and the writing is exquisite. If you are looking for a play that will touch your heart and make you think, this is it.\"'}\n",
      "\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "review = overall_chain({\"title\": \"Tragedy at sunset on the beach\", \"era\": \"Victorian England\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6be70d27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
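
Reviewer note: the notebook above threads a single string through each step. The sketch below is a toy distillation of that pattern, with plain functions standing in for the `LLMChain`s; it is illustrative only, not library code.

```python
# Toy distillation of the SimpleSequentialChain pattern: each step is a
# one-in/one-out callable, and the output of one step becomes the input
# to the next. The two lambdas are hypothetical stand-ins for chains.
def run_simple_sequence(steps, text: str) -> str:
    for step in steps:
        text = step(text)
    return text

fake_synopsis = lambda title: f"Synopsis of '{title}'"
fake_review = lambda synopsis: f"Review based on: {synopsis}"
print(run_simple_sequence([fake_synopsis, fake_review], "Tragedy at sunset on the beach"))
```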

@@ -21,5 +21,5 @@ We can now call it on some input!
 
 ```python
 text = "What would be a good company name for a company that makes colorful socks?"
-llm(text)
+print(llm(text))
 ```

@@ -1 +1 @@
-0.0.17
+0.0.18

@@ -2,6 +2,7 @@
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.python import PythonChain
+from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
 from langchain.chains.serpapi import SerpAPIChain
 from langchain.chains.sql_database.base import SQLDatabaseChain
 from langchain.chains.vector_db_qa.base import VectorDBQA
@@ -13,4 +14,6 @@ __all__ = [
     "SerpAPIChain",
     "SQLDatabaseChain",
     "VectorDBQA",
+    "SequentialChain",
+    "SimpleSequentialChain",
 ]

@@ -38,8 +38,19 @@ class Chain(BaseModel, ABC):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         """Run the logic of this chain and return the output."""
 
-    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
-        """Run the logic of this chain and add to output."""
+    def __call__(
+        self, inputs: Dict[str, Any], return_only_outputs: bool = False
+    ) -> Dict[str, str]:
+        """Run the logic of this chain and add to output if desired.
+
+        Args:
+            inputs: Dictionary of inputs.
+            return_only_outputs: boolean for whether to return only outputs in the
+                response. If True, only new keys generated by this chain will be
+                returned. If False, both input keys and new keys generated by this
+                chain will be returned. Defaults to False.
+
+        """
         self._validate_inputs(inputs)
         if self.verbose:
             print("\n\n\033[1m> Entering new chain...\033[0m")
@@ -47,7 +58,10 @@ class Chain(BaseModel, ABC):
         if self.verbose:
             print("\n\033[1m> Finished chain.\033[0m")
         self._validate_outputs(outputs)
-        return {**inputs, **outputs}
+        if return_only_outputs:
+            return outputs
+        else:
+            return {**inputs, **outputs}
 
     def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         """Call the chain on all inputs in the list."""
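
A toy illustration (plain dicts, not the library code) of what the new `return_only_outputs` flag changes about the return value: with the flag set, callers get back only the keys the chain produced; otherwise inputs and outputs are merged.

```python
# Merge semantics of Chain.__call__ after this change, in miniature.
def call(inputs: dict, outputs: dict, return_only_outputs: bool = False) -> dict:
    # True  -> only the chain's new keys
    # False -> inputs and new keys merged into one dict
    return outputs if return_only_outputs else {**inputs, **outputs}

assert call({"title": "t"}, {"synopsis": "s"}) == {"title": "t", "synopsis": "s"}
assert call({"title": "t"}, {"synopsis": "s"}, return_only_outputs=True) == {"synopsis": "s"}
```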

langchain/chains/sequential.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""Chain pipeline where the outputs of one step feed directly into next."""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
|
||||
|
||||
class SequentialChain(Chain, BaseModel):
|
||||
"""Chain where the outputs of one step feed directly into next."""
|
||||
|
||||
chains: List[Chain]
|
||||
input_variables: List[str]
|
||||
output_variables: List[str] #: :meta private:
|
||||
return_all: bool = False
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.output_variables
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_chains(cls, values: Dict) -> Dict:
|
||||
"""Validate that the correct inputs exist for all chains."""
|
||||
chains = values["chains"]
|
||||
input_variables = values["input_variables"]
|
||||
known_variables = set(input_variables)
|
||||
for chain in chains:
|
||||
missing_vars = set(chain.input_keys).difference(known_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(f"Missing required input keys: {missing_vars}")
|
||||
overlapping_keys = known_variables.intersection(chain.output_keys)
|
||||
if overlapping_keys:
|
||||
raise ValueError(
|
||||
f"Chain returned keys that already exist: {overlapping_keys}"
|
||||
)
|
||||
known_variables |= set(chain.output_keys)
|
||||
|
||||
if "output_variables" not in values:
|
||||
if values.get("return_all", False):
|
||||
output_keys = known_variables.difference(input_variables)
|
||||
else:
|
||||
output_keys = chains[-1].output_keys
|
||||
values["output_variables"] = output_keys
|
||||
else:
|
||||
missing_vars = set(values["output_variables"]).difference(known_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Expected output variables that were not found: {missing_vars}."
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
known_values = inputs.copy()
|
||||
for i, chain in enumerate(self.chains):
|
||||
outputs = chain(known_values, return_only_outputs=True)
|
||||
if self.verbose:
|
||||
print(f"\033[1mChain {i}\033[0m:\n{outputs}\n")
|
||||
known_values.update(outputs)
|
||||
return {k: known_values[k] for k in self.output_variables}
|
||||
|
||||
|
||||
class SimpleSequentialChain(Chain, BaseModel):
|
||||
"""Simple chain where the outputs of one step feed directly into next."""
|
||||
|
||||
chains: List[Chain]
|
||||
strip_outputs: bool = False
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@root_validator()
|
||||
def validate_chains(cls, values: Dict) -> Dict:
|
||||
"""Validate that chains are all single input/output."""
|
||||
for chain in values["chains"]:
|
||||
if len(chain.input_keys) != 1:
|
||||
raise ValueError(
|
||||
"Chains used in SimplePipeline should all have one input, got "
|
||||
f"{chain} with {len(chain.input_keys)} inputs."
|
||||
)
|
||||
if len(chain.output_keys) != 1:
|
||||
raise ValueError(
|
||||
"Chains used in SimplePipeline should all have one output, got "
|
||||
f"{chain} with {len(chain.output_keys)} outputs."
|
||||
)
|
||||
return values
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
_input = inputs[self.input_key]
|
||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||
for i, chain in enumerate(self.chains):
|
||||
_input = chain.run(_input)
|
||||
if self.strip_outputs:
|
||||
_input = _input.strip()
|
||||
if self.verbose:
|
||||
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||
return {self.output_key: _input}
|
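
A toy walkthrough of the variable bookkeeping `SequentialChain.validate_chains` performs, using plain sets rather than the library code; the input/output key sets mirror the notebook's title/era example.

```python
# How known variables accumulate across the sequence, in miniature.
known = {"era", "title"}                       # input_variables
chain_io = [
    ({"title", "era"}, {"synopsis"}),          # (input_keys, output_keys) per chain
    ({"synopsis"}, {"review"}),
]
for inputs, outputs in chain_io:
    assert inputs <= known, f"Missing required input keys: {inputs - known}"
    assert not (known & outputs), f"Chain returned keys that already exist: {known & outputs}"
    known |= outputs
print(sorted(known))  # ['era', 'review', 'synopsis', 'title']
```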

@@ -1,6 +1,6 @@
 """Interface for accessing a place that stores documents."""
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Dict, Union
 
 from langchain.docstore.document import Document
 
@@ -15,3 +15,11 @@ class Docstore(ABC):
         If page exists, return the page summary, and a Document object.
         If page does not exist, return similar entries.
         """
+
+
+class AddableMixin(ABC):
+    """Mixin class that supports adding texts."""
+
+    @abstractmethod
+    def add(self, texts: Dict[str, Document]) -> None:
+        """Add more documents."""
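
A minimal sketch of how the mixin is meant to be consumed; the `try_add` helper is hypothetical, but the `isinstance` gate is the same pattern the FAISS change below uses.

```python
from typing import Dict

from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document

def try_add(docstore: Docstore, texts: Dict[str, Document]) -> None:
    # Only docstores that also inherit AddableMixin expose `add`.
    if not isinstance(docstore, AddableMixin):
        raise ValueError(f"{docstore} does not support adding items")
    docstore.add(texts)
```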

@@ -1,17 +1,24 @@
 """Simple in memory docstore in the form of a dict."""
 from typing import Dict, Union
 
-from langchain.docstore.base import Docstore
+from langchain.docstore.base import AddableMixin, Docstore
 from langchain.docstore.document import Document
 
 
-class InMemoryDocstore(Docstore):
+class InMemoryDocstore(Docstore, AddableMixin):
     """Simple in memory docstore in the form of a dict."""
 
     def __init__(self, _dict: Dict[str, Document]):
         """Initialize with dict."""
         self._dict = _dict
 
+    def add(self, texts: Dict[str, Document]) -> None:
+        """Add texts to in memory dictionary."""
+        overlapping = set(texts).intersection(self._dict)
+        if overlapping:
+            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
+        self._dict = dict(self._dict, **texts)
+
     def search(self, search: str) -> Union[str, Document]:
         """Search via direct lookup."""
         if search not in self._dict:
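
Usage sketch for the new `add` method (it mirrors the unit tests added at the bottom of this commit):

```python
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore

store = InMemoryDocstore({"foo": Document(page_content="bar")})
store.add({"baz": Document(page_content="qux")})  # fine: "baz" is a new id
try:
    store.add({"foo": Document(page_content="other")})  # "foo" already exists
except ValueError as err:
    print(err)  # Tried to add ids that already exist: {'foo'}
```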

@@ -1,6 +1,6 @@
 """Interface for vector stores."""
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
 
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
@@ -9,6 +9,10 @@ from langchain.embeddings.base import Embeddings
 class VectorStore(ABC):
     """Interface for vector stores."""
 
+    @abstractmethod
+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+
     @abstractmethod
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query."""
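
To show what the widened interface asks of a store, here is a minimal in-memory sketch. It is written as a standalone duck-typed class rather than a `VectorStore` subclass (to avoid assumptions about the other abstract methods on the base class); `ToyVectorStore` and its brute-force search are illustrative only.

```python
from typing import Callable, Iterable, List

import numpy as np

from langchain.docstore.document import Document

class ToyVectorStore:
    """Illustrative in-memory store with the same two methods."""

    def __init__(self, embed: Callable[[str], List[float]]):
        self.embed = embed
        self.vectors: List[np.ndarray] = []
        self.docs: List[Document] = []

    def add_texts(self, texts: Iterable[str]) -> None:
        # Embed each text and keep vector/document rows aligned by position.
        for text in texts:
            self.vectors.append(np.asarray(self.embed(text), dtype=np.float32))
            self.docs.append(Document(page_content=text))

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        # Brute-force nearest neighbours by Euclidean distance.
        q = np.asarray(self.embed(query), dtype=np.float32)
        order = np.argsort([float(np.linalg.norm(v - q)) for v in self.vectors])
        return [self.docs[i] for i in order[:k]]
```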

@@ -1,6 +1,6 @@
 """Wrapper around Elasticsearch vector database."""
 import uuid
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
@@ -65,6 +65,28 @@ class ElasticVectorSearch(VectorStore):
         )
         self.client = es_client
 
+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+        try:
+            from elasticsearch.helpers import bulk
+        except ImportError:
+            raise ValueError(
+                "Could not import elasticsearch python package. "
+                "Please install it with `pip install elasticsearch`."
+            )
+        requests = []
+        for i, text in enumerate(texts):
+            request = {
+                "_op_type": "index",
+                "_index": self.index_name,
+                "vector": self.embedding_function(text),
+                "text": text,
+            }
+            requests.append(request)
+        bulk(self.client, requests)
+        # TODO: add option not to refresh
+        self.client.indices.refresh(index=self.index_name)
+
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query.
 
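
For reference, each entry handed to `elasticsearch.helpers.bulk` above is a plain dict. A toy example of the request shape, with a stand-in embedding function and a hypothetical index name:

```python
fake_embed = lambda text: [0.0, 1.0, 2.0]  # stand-in for self.embedding_function
request = {
    "_op_type": "index",         # index (insert) a new document
    "_index": "langchain-demo",  # hypothetical index name
    "vector": fake_embed("hello world"),
    "text": "hello world",
}
```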

@@ -1,9 +1,10 @@
 """Wrapper around FAISS vector database."""
-from typing import Any, Callable, List, Optional
+import uuid
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import numpy as np
 
-from langchain.docstore.base import Docstore
+from langchain.docstore.base import AddableMixin, Docstore
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
 from langchain.embeddings.base import Embeddings
@@ -23,11 +24,41 @@ class FAISS(VectorStore):
 
     """
 
-    def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
+    def __init__(
+        self,
+        embedding_function: Callable,
+        index: Any,
+        docstore: Docstore,
+        index_to_docstore_id: Dict[int, str],
+    ):
         """Initialize with necessary components."""
         self.embedding_function = embedding_function
         self.index = index
         self.docstore = docstore
+        self.index_to_docstore_id = index_to_docstore_id
 
+    def add_texts(self, texts: Iterable[str]) -> None:
+        """Run more texts through the embeddings and add to the vectorstore."""
+        if not isinstance(self.docstore, AddableMixin):
+            raise ValueError(
+                "If trying to add texts, the underlying docstore should support "
+                f"adding items, which {self.docstore} does not"
+            )
+        # Embed and create the documents.
+        embeddings = [self.embedding_function(text) for text in texts]
+        documents = [Document(page_content=text) for text in texts]
+        # Add to the index, the index_to_id mapping, and the docstore.
+        starting_len = len(self.index_to_docstore_id)
+        self.index.add(np.array(embeddings, dtype=np.float32))
+        # Get list of index, id, and docs.
+        full_info = [
+            (starting_len + i, str(uuid.uuid4()), doc)
+            for i, doc in enumerate(documents)
+        ]
+        # Add information to docstore and index.
+        self.docstore.add({_id: doc for _, _id, doc in full_info})
+        index_to_id = {index: _id for index, _id, _ in full_info}
+        self.index_to_docstore_id.update(index_to_id)
+
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query.
 
@@ -46,9 +77,10 @@ class FAISS(VectorStore):
             if i == -1:
                 # This happens when not enough docs are returned.
                 continue
-            doc = self.docstore.search(str(i))
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
             if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {i}, got {doc}")
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
             docs.append(doc)
         return docs
 
@@ -92,5 +124,8 @@ class FAISS(VectorStore):
         for i, text in enumerate(texts):
             metadata = metadatas[i] if metadatas else {}
             documents.append(Document(page_content=text, metadata=metadata))
-        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
-        return cls(embedding.embed_query, index, docstore)
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
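
The core of the FAISS change is the new `index_to_docstore_id` mapping: FAISS rows are positional, so each appended row gets a fresh uuid that keys the docstore entry. A toy walkthrough with plain dicts (no FAISS involved):

```python
import uuid

index_to_docstore_id = {0: "id-a", 1: "id-b"}  # two existing rows
new_texts = ["baz", "qux"]
starting_len = len(index_to_docstore_id)
for i, _text in enumerate(new_texts):
    # Row positions 2 and 3 map to fresh docstore ids.
    index_to_docstore_id[starting_len + i] = str(uuid.uuid4())
assert sorted(index_to_docstore_id) == [0, 1, 2, 3]
```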

@@ -5,6 +5,7 @@ import pytest
 
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
+from langchain.docstore.wikipedia import Wikipedia
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.faiss import FAISS
 
@@ -25,11 +26,12 @@ def test_faiss() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
     docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+    index_to_id = docsearch.index_to_docstore_id
     expected_docstore = InMemoryDocstore(
         {
-            "0": Document(page_content="foo"),
-            "1": Document(page_content="bar"),
-            "2": Document(page_content="baz"),
+            index_to_id[0]: Document(page_content="foo"),
+            index_to_id[1]: Document(page_content="bar"),
+            index_to_id[2]: Document(page_content="baz"),
         }
     )
     assert docsearch.docstore.__dict__ == expected_docstore.__dict__
@@ -62,3 +64,21 @@ def test_faiss_search_not_found() -> None:
     docsearch.docstore = InMemoryDocstore({})
     with pytest.raises(ValueError):
         docsearch.similarity_search("foo")
+
+
+def test_faiss_add_texts() -> None:
+    """Test end to end adding of texts."""
+    # Create initial doc store.
+    texts = ["foo", "bar", "baz"]
+    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+    # Test adding a similar document as before.
+    docsearch.add_texts(["foo"])
+    output = docsearch.similarity_search("foo", k=2)
+    assert output == [Document(page_content="foo"), Document(page_content="foo")]
+
+
+def test_faiss_add_texts_not_supported() -> None:
+    """Test adding of texts to a docstore that doesn't support it."""
+    docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
+    with pytest.raises(ValueError):
+        docsearch.add_texts(["foo"])

tests/unit_tests/chains/test_sequential.py (new file, 140 lines)
@@ -0,0 +1,140 @@
"""Test pipeline functionality."""
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
|
||||
"""Fake Chain for testing purposes."""
|
||||
|
||||
input_variables: List[str]
|
||||
output_variables: List[str]
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain returns."""
|
||||
return self.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Input keys this chain returns."""
|
||||
return self.output_variables
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
outputs = {}
|
||||
for var in self.output_variables:
|
||||
variables = [inputs[k] for k in self.input_variables]
|
||||
outputs[var] = " ".join(variables) + "foo"
|
||||
return outputs
|
||||
|
||||
|
||||
def test_sequential_usage_single_inputs() -> None:
|
||||
"""Test sequential on single input chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
output = chain({"foo": "123"})
|
||||
expected_output = {"baz": "123foofoo", "foo": "123"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_usage_multiple_inputs() -> None:
|
||||
"""Test sequential on multiple input chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||
output = chain({"foo": "123", "test": "456"})
|
||||
expected_output = {
|
||||
"baz": "123 456foo 123foo",
|
||||
"foo": "123",
|
||||
"test": "456",
|
||||
}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_usage_multiple_outputs() -> None:
|
||||
"""Test sequential usage on multiple output chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
output = chain({"foo": "123"})
|
||||
expected_output = {
|
||||
"baz": "123foo 123foo",
|
||||
"foo": "123",
|
||||
}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_missing_inputs() -> None:
|
||||
"""Test error is raised when input variables are missing."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# Also needs "test" as an input
|
||||
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
|
||||
|
||||
def test_sequential_bad_outputs() -> None:
|
||||
"""Test error is raised when bad outputs are specified."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# "test" is not present as an output variable.
|
||||
SequentialChain(
|
||||
chains=[chain_1, chain_2],
|
||||
input_variables=["foo"],
|
||||
output_variables=["test"],
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_valid_outputs() -> None:
|
||||
"""Test chain runs when valid outputs are specified."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SequentialChain(
|
||||
chains=[chain_1, chain_2],
|
||||
input_variables=["foo"],
|
||||
output_variables=["bar", "baz"],
|
||||
)
|
||||
output = chain({"foo": "123"}, return_only_outputs=True)
|
||||
expected_output = {"baz": "123foofoo", "bar": "123foo"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_overlapping_inputs() -> None:
|
||||
"""Test error is raised when input variables are overlapping."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# "test" is specified as an input, but also is an output of one step
|
||||
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||
|
||||
|
||||
def test_simple_sequential_functionality() -> None:
|
||||
"""Test simple sequential functionality."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||
output = chain({"input": "123"})
|
||||
expected_output = {"output": "123foofoo", "input": "123"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_multi_input_errors() -> None:
|
||||
"""Test simple sequential errors if multiple input variables are expected."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||
|
||||
|
||||
def test_multi_output_errors() -> None:
|
||||
"""Test simple sequential errors if multiple output variables are expected."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SimpleSequentialChain(chains=[chain_1, chain_2])
|

@@ -1,4 +1,5 @@
 """Test in memory docstore."""
+import pytest
 
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
@@ -19,3 +20,37 @@ def test_document_not_found() -> None:
     docstore = InMemoryDocstore(_dict)
     output = docstore.search("bar")
     assert output == "ID bar not found."
+
+
+def test_adding_document() -> None:
+    """Test that documents are added correctly."""
+    _dict = {"foo": Document(page_content="bar")}
+    docstore = InMemoryDocstore(_dict)
+    new_dict = {"bar": Document(page_content="foo")}
+    docstore.add(new_dict)
+
+    # Test that you can find the new document.
+    foo_output = docstore.search("bar")
+    assert isinstance(foo_output, Document)
+    assert foo_output.page_content == "foo"
+
+    # Test that the old document is unchanged.
+    bar_output = docstore.search("foo")
+    assert isinstance(bar_output, Document)
+    assert bar_output.page_content == "bar"
+
+
+def test_adding_document_already_exists() -> None:
+    """Test that error is raised if document id already exists."""
+    _dict = {"foo": Document(page_content="bar")}
+    docstore = InMemoryDocstore(_dict)
+    new_dict = {"foo": Document(page_content="foo")}
+
+    # Test that error is raised.
+    with pytest.raises(ValueError):
+        docstore.add(new_dict)
+
+    # Test that the old document is unchanged.
+    bar_output = docstore.search("foo")
+    assert isinstance(bar_output, Document)
+    assert bar_output.page_content == "bar"