From e9baf9c134f6d05b0ce7b87461bf296572e24c65 Mon Sep 17 00:00:00 2001 From: Jim Salmons Date: Sun, 20 Nov 2022 16:22:53 -0700 Subject: [PATCH 1/4] Update llm.md (#164) Without the print on the `llm` call, the new user sees no visible effect when just getting started. The assumption here is the new user is running this in a new sandbox script file or repl via copy-paste. --- docs/getting_started/llm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/getting_started/llm.md b/docs/getting_started/llm.md index a4ac89cc..9133a4a1 100644 --- a/docs/getting_started/llm.md +++ b/docs/getting_started/llm.md @@ -21,5 +21,5 @@ We can now call it on some input! ```python text = "What would be a good company name a company that makes colorful socks?" -llm(text) +print(llm(text)) ``` From 315b0c09c614fa44daa61529d1f1da2fe827b16c Mon Sep 17 00:00:00 2001 From: Samantha Whitmore Date: Sun, 20 Nov 2022 16:23:58 -0800 Subject: [PATCH 2/4] wip: add method for both docstore and embeddings (#119) this will break atm but wanted to get thoughts on implementation. 1. should add() be on docstore interface? 2. should InMemoryDocstore change to take a list of documents as init? (makes this slightly easier to implement in FAISS -- if we think it is less clean then could expose a method to get the number of documents currently in the dict, and perform the logic of creating the necessary dictionary in the FAISS.add_texts method. 
Co-authored-by: Harrison Chase --- langchain/docstore/base.py | 10 +++- langchain/docstore/in_memory.py | 11 ++++- langchain/vectorstores/base.py | 6 ++- .../vectorstores/elastic_vector_search.py | 24 ++++++++- langchain/vectorstores/faiss.py | 49 ++++++++++++++++--- .../vectorstores/test_faiss.py | 26 ++++++++-- tests/unit_tests/docstore/test_inmemory.py | 35 +++++++++++++ 7 files changed, 146 insertions(+), 15 deletions(-) diff --git a/langchain/docstore/base.py b/langchain/docstore/base.py index 2849dd09..4a91680c 100644 --- a/langchain/docstore/base.py +++ b/langchain/docstore/base.py @@ -1,6 +1,6 @@ """Interface to access to place that stores documents.""" from abc import ABC, abstractmethod -from typing import Union +from typing import Dict, Union from langchain.docstore.document import Document @@ -15,3 +15,11 @@ class Docstore(ABC): If page exists, return the page summary, and a Document object. If page does not exist, return similar entries. """ + + +class AddableMixin(ABC): + """Mixin class that supports adding texts.""" + + @abstractmethod + def add(self, texts: Dict[str, Document]) -> None: + """Add more documents.""" diff --git a/langchain/docstore/in_memory.py b/langchain/docstore/in_memory.py index 5023d5ff..f1e36102 100644 --- a/langchain/docstore/in_memory.py +++ b/langchain/docstore/in_memory.py @@ -1,17 +1,24 @@ """Simple in memory docstore in the form of a dict.""" from typing import Dict, Union -from langchain.docstore.base import Docstore +from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document -class InMemoryDocstore(Docstore): +class InMemoryDocstore(Docstore, AddableMixin): """Simple in memory docstore in the form of a dict.""" def __init__(self, _dict: Dict[str, Document]): """Initialize with dict.""" self._dict = _dict + def add(self, texts: Dict[str, Document]) -> None: + """Add texts to in memory dictionary.""" + overlapping = set(texts).intersection(self._dict) + if overlapping: + 
raise ValueError(f"Tried to add ids that already exist: {overlapping}") + self._dict = dict(self._dict, **texts) + def search(self, search: str) -> Union[str, Document]: """Search via direct lookup.""" if search not in self._dict: diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index a7097893..8c9b171c 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -1,6 +1,6 @@ """Interface for vector stores.""" from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -9,6 +9,10 @@ from langchain.embeddings.base import Embeddings class VectorStore(ABC): """Interface for vector stores.""" + @abstractmethod + def add_texts(self, texts: Iterable[str]) -> None: + """Run more texts through the embeddings and add to the vectorstore.""" + @abstractmethod def similarity_search(self, query: str, k: int = 4) -> List[Document]: """Return docs most similar to query.""" diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index b186cfff..91946364 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -1,6 +1,6 @@ """Wrapper around Elasticsearch vector database.""" import uuid -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -65,6 +65,28 @@ class ElasticVectorSearch(VectorStore): ) self.client = es_client + def add_texts(self, texts: Iterable[str]) -> None: + """Run more texts through the embeddings and add to the vectorstore.""" + try: + from elasticsearch.helpers import bulk + except ImportError: + raise ValueError( + "Could not import elasticsearch python package. 
" + "Please install it with `pip install elasticearch`." + ) + requests = [] + for i, text in enumerate(texts): + request = { + "_op_type": "index", + "_index": self.index_name, + "vector": self.embedding_function(text), + "text": text, + } + requests.append(request) + bulk(self.client, requests) + # TODO: add option not to refresh + self.client.indices.refresh(index=self.index_name) + def similarity_search(self, query: str, k: int = 4) -> List[Document]: """Return docs most similar to query. diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 8ae2e3f0..2b3e4d61 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -1,9 +1,10 @@ """Wrapper around FAISS vector database.""" -from typing import Any, Callable, List, Optional +import uuid +from typing import Any, Callable, Dict, Iterable, List, Optional import numpy as np -from langchain.docstore.base import Docstore +from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore from langchain.embeddings.base import Embeddings @@ -23,11 +24,41 @@ class FAISS(VectorStore): """ - def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore): + def __init__( + self, + embedding_function: Callable, + index: Any, + docstore: Docstore, + index_to_docstore_id: Dict[int, str], + ): """Initialize with necessary components.""" self.embedding_function = embedding_function self.index = index self.docstore = docstore + self.index_to_docstore_id = index_to_docstore_id + + def add_texts(self, texts: Iterable[str]) -> None: + """Run more texts through the embeddings and add to the vectorstore.""" + if not isinstance(self.docstore, AddableMixin): + raise ValueError( + "If trying to add texts, the underlying docstore should support " + f"adding items, which {self.docstore} does not" + ) + # Embed and create the documents. 
+ embeddings = [self.embedding_function(text) for text in texts] + documents = [Document(page_content=text) for text in texts] + # Add to the index, the index_to_id mapping, and the docstore. + starting_len = len(self.index_to_docstore_id) + self.index.add(np.array(embeddings, dtype=np.float32)) + # Get list of index, id, and docs. + full_info = [ + (starting_len + i, str(uuid.uuid4()), doc) + for i, doc in enumerate(documents) + ] + # Add information to docstore and index. + self.docstore.add({_id: doc for _, _id, doc in full_info}) + index_to_id = {index: _id for index, _id, _ in full_info} + self.index_to_docstore_id.update(index_to_id) def similarity_search(self, query: str, k: int = 4) -> List[Document]: """Return docs most similar to query. @@ -46,9 +77,10 @@ class FAISS(VectorStore): if i == -1: # This happens when not enough docs are returned. continue - doc = self.docstore.search(str(i)) + _id = self.index_to_docstore_id[i] + doc = self.docstore.search(_id) if not isinstance(doc, Document): - raise ValueError(f"Could not find document for id {i}, got {doc}") + raise ValueError(f"Could not find document for id {_id}, got {doc}") docs.append(doc) return docs @@ -92,5 +124,8 @@ class FAISS(VectorStore): for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} documents.append(Document(page_content=text, metadata=metadata)) - docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)}) - return cls(embedding.embed_query, index, docstore) + index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))} + docstore = InMemoryDocstore( + {index_to_id[i]: doc for i, doc in enumerate(documents)} + ) + return cls(embedding.embed_query, index, docstore, index_to_id) diff --git a/tests/integration_tests/vectorstores/test_faiss.py b/tests/integration_tests/vectorstores/test_faiss.py index 2b3cbd1d..c3d2ba57 100644 --- a/tests/integration_tests/vectorstores/test_faiss.py +++ b/tests/integration_tests/vectorstores/test_faiss.py 
@@ -5,6 +5,7 @@ import pytest from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore +from langchain.docstore.wikipedia import Wikipedia from langchain.embeddings.base import Embeddings from langchain.vectorstores.faiss import FAISS @@ -25,11 +26,12 @@ def test_faiss() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id expected_docstore = InMemoryDocstore( { - "0": Document(page_content="foo"), - "1": Document(page_content="bar"), - "2": Document(page_content="baz"), + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), } ) assert docsearch.docstore.__dict__ == expected_docstore.__dict__ @@ -62,3 +64,21 @@ def test_faiss_search_not_found() -> None: docsearch.docstore = InMemoryDocstore({}) with pytest.raises(ValueError): docsearch.similarity_search("foo") + + +def test_faiss_add_texts() -> None: + """Test end to end adding of texts.""" + # Create initial doc store. + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + # Test adding a similar document as before. 
+ docsearch.add_texts(["foo"]) + output = docsearch.similarity_search("foo", k=2) + assert output == [Document(page_content="foo"), Document(page_content="foo")] + + +def test_faiss_add_texts_not_supported() -> None: + """Test adding of texts to a docstore that doesn't support it.""" + docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {}) + with pytest.raises(ValueError): + docsearch.add_texts(["foo"]) diff --git a/tests/unit_tests/docstore/test_inmemory.py b/tests/unit_tests/docstore/test_inmemory.py index 284f9224..4fe9104c 100644 --- a/tests/unit_tests/docstore/test_inmemory.py +++ b/tests/unit_tests/docstore/test_inmemory.py @@ -1,4 +1,5 @@ """Test in memory docstore.""" +import pytest from langchain.docstore.document import Document from langchain.docstore.in_memory import InMemoryDocstore @@ -19,3 +20,37 @@ def test_document_not_found() -> None: docstore = InMemoryDocstore(_dict) output = docstore.search("bar") assert output == "ID bar not found." + + +def test_adding_document() -> None: + """Test that documents are added correctly.""" + _dict = {"foo": Document(page_content="bar")} + docstore = InMemoryDocstore(_dict) + new_dict = {"bar": Document(page_content="foo")} + docstore.add(new_dict) + + # Test that you can find new document. + foo_output = docstore.search("bar") + assert isinstance(foo_output, Document) + assert foo_output.page_content == "foo" + + # Test that old document is the same. + bar_output = docstore.search("foo") + assert isinstance(bar_output, Document) + assert bar_output.page_content == "bar" + + +def test_adding_document_already_exists() -> None: + """Test that error is raised if document id already exists.""" + _dict = {"foo": Document(page_content="bar")} + docstore = InMemoryDocstore(_dict) + new_dict = {"foo": Document(page_content="foo")} + + # Test that error is raised. + with pytest.raises(ValueError): + docstore.add(new_dict) + + # Test that old document is the same. 
+ bar_output = docstore.search("foo") + assert isinstance(bar_output, Document) + assert bar_output.page_content == "bar" From 15c19fcc6038990c795aa69e645123a9e53283fe Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 21 Nov 2022 09:34:44 -0800 Subject: [PATCH 3/4] bump version to 0.0.18 (#167) --- langchain/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/VERSION b/langchain/VERSION index cd231804..32786aa4 100644 --- a/langchain/VERSION +++ b/langchain/VERSION @@ -1 +1 @@ -0.0.17 +0.0.18 From 4a4dfbfbed5ca271fc74f61a0b3387314dda8703 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 21 Nov 2022 13:08:53 -0800 Subject: [PATCH 4/4] Harrison/sequential chains (#168) add support for basic sequential chains --- docs/examples/demos/sequential_chains.ipynb | 265 ++++++++++++++++++++ langchain/chains/__init__.py | 3 + langchain/chains/base.py | 20 +- langchain/chains/sequential.py | 137 ++++++++++ tests/unit_tests/chains/test_sequential.py | 140 +++++++++++ 5 files changed, 562 insertions(+), 3 deletions(-) create mode 100644 docs/examples/demos/sequential_chains.ipynb create mode 100644 langchain/chains/sequential.py create mode 100644 tests/unit_tests/chains/test_sequential.py diff --git a/docs/examples/demos/sequential_chains.ipynb b/docs/examples/demos/sequential_chains.ipynb new file mode 100644 index 00000000..2c907f84 --- /dev/null +++ b/docs/examples/demos/sequential_chains.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4f73605d", + "metadata": {}, + "source": [ + "# Sequential Chains" + ] + }, + { + "cell_type": "markdown", + "id": "3b235f7a", + "metadata": {}, + "source": [ + "The next step after calling a language model is to make a series of calls to a language model.
This is particularly useful when you want to take the output from one call and use it as the input to another.\n", + "\n", + "In this notebook we will walk through some examples for how to do this, using sequential chains. Sequential chains are defined as a series of chains, called in deterministic order. There are two types of sequential chains:\n", + "\n", + "- `SimpleSequentialChain`: The simplest form of sequential chains, where each step has a singular input/output, and the output of one step is the input to the next.\n", + "- `SequentialChain`: A more general form of sequential chains, allowing for multiple inputs/outputs." + ] + }, + { + "cell_type": "markdown", + "id": "5162794e", + "metadata": {}, + "source": [ + "## SimpleSequentialChain\n", + "\n", + "In this series of chains, each individual chain has a single input and a single output, and the output of one step is used as input to the next.\n", + "\n", + "Let's walk through a toy example of doing this, where the first chain takes in the title of an imaginary play and then generates a synopsis for that title, and the second chain takes in the synopsis of that play and generates an imaginary review for that play." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3f2f9b8c", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.chains import LLMChain\n", + "from langchain.prompts import PromptTemplate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b8237d1a", + "metadata": {}, + "outputs": [], + "source": [ + "# This is an LLMChain to write a synopsis given a title of a play.\n", + "llm = OpenAI(temperature=.7)\n", + "template = \"\"\"You are a playwright. 
Given the title of play, it is your job to write a synopsis for that title.\n", + "\n", + "Title: {title}\n", + "Playwright: This is a synopsis for the above play:\"\"\"\n", + "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4a391730", + "metadata": {}, + "outputs": [], + "source": [ + "# This is an LLMChain to write a review of a play given a synopsis.\n", + "llm = OpenAI(temperature=.7)\n", + "template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n", + "\n", + "Play Synopsis:\n", + "{synopsis}\n", + "Review from a New York Times play critic of the above play:\"\"\"\n", + "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n", + "review_chain = LLMChain(llm=llm, prompt=prompt_template)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9368bd63", + "metadata": {}, + "outputs": [], + "source": [ + "# This is the overall chain where we run these two chains in sequence.\n", + "from langchain.chains import SimpleSequentialChain\n", + "overall_chain = SimpleSequentialChain(chains=[synopsis_chain, review_chain], verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d39e15f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[36;1m\u001b[1;3m\n", + "\n", + "A young couple, John and Mary, are enjoying a day at the beach. As the sun sets, they share a romantic moment. However, their happiness is short-lived, as a tragic accident claims John's life. 
Mary is left devastated by the loss of her husband.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3m\n", + "\n", + "\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "review = overall_chain.run(\"Tragedy at sunset on the beach\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c6649a01", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\"A young couple's happiness is cut short by tragedy in this moving play. Mary is left devastated by the loss of her husband, John, in a freak accident. The play captures the pain and grief of loss, as well as the strength of love. A must-see for fans of theater.\"\n" + ] + } + ], + "source": [ + "print(review)" + ] + }, + { + "cell_type": "markdown", + "id": "c3f1549a", + "metadata": {}, + "source": [ + "## Sequential Chain\n", + "Of course, not all sequential chains will be as simple as passing a single string as an argument and getting a single string as output for all steps in the chain. In this next example, we will experiment with more complex chains that involve multiple inputs, and where there are also multiple final outputs. \n", + "\n", + "Of particular importance is how we name the input/output variable names. In the above example we didn't have to think about that because we were just passing the output of one chain directly as input to the next, but here we do have to worry about that because we have multiple inputs."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "02016a51", + "metadata": {}, + "outputs": [], + "source": [ + "# This is an LLMChain to write a synopsis given a title of a play and the era it is set in.\n", + "llm = OpenAI(temperature=.7)\n", + "template = \"\"\"You are a playwright. Given the title of play and the era it is set in, it is your job to write a synopsis for that title.\n", + "\n", + "Title: {title}\n", + "Era: {era}\n", + "Playwright: This is a synopsis for the above play:\"\"\"\n", + "prompt_template = PromptTemplate(input_variables=[\"title\", 'era'], template=template)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"synopsis\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bd38cc2", + "metadata": {}, + "outputs": [], + "source": [ + "# This is an LLMChain to write a review of a play given a synopsis.\n", + "llm = OpenAI(temperature=.7)\n", + "template = \"\"\"You are a play critic from the New York Times. 
Given the synopsis of play, it is your job to write a review for that play.\n", + "\n", + "Play Synopsis:\n", + "{synopsis}\n", + "Review from a New York Times play critic of the above play:\"\"\"\n", + "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n", + "review_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"review\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "524523af", + "metadata": {}, + "outputs": [], + "source": [ + "# This is the overall chain where we run these two chains in sequence.\n", + "from langchain.chains import SequentialChain\n", + "overall_chain = SequentialChain(\n", + " chains=[synopsis_chain, review_chain],\n", + " input_variables=[\"era\", \"title\"],\n", + " # Here we return multiple variables\n", + " output_variables=[\"synopsis\", \"review\"],\n", + " verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3fd3a7be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[1mChain 0\u001b[0m:\n", + "{'synopsis': \"\\n\\nThe play is set in Victorian England and follows the tragic story of a young woman who drowns while swimming at sunset on the beach. Her body is found the next morning by a fisherman who raises the alarm. The young woman's family and friends are devastated by her death and the play ends with their mourning her loss.\"}\n", + "\n", + "\u001b[1mChain 1\u001b[0m:\n", + "{'review': '\\n\\n\"The play is a tragedy, pure and simple. It is the story of a young woman\\'s death, told through the eyes of those who loved her. It is a sad, beautiful play that will stay with you long after you\\'ve seen it. The acting is superb, and the writing is exquisite. 
If you are looking for a play that will touch your heart and make you think, this is it.\"'}\n", + "\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "review = overall_chain({\"title\":\"Tragedy at sunset on the beach\", \"era\": \"Victorian England\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6be70d27", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index ae27d37e..ec5ac1d5 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -5,6 +5,7 @@ from langchain.chains.mrkl.base import MRKLChain from langchain.chains.python import PythonChain from langchain.chains.react.base import ReActChain from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain +from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.serpapi import SerpAPIChain from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.vector_db_qa.base import VectorDBQA @@ -19,4 +20,6 @@ __all__ = [ "SQLDatabaseChain", "MRKLChain", "VectorDBQA", + "SequentialChain", + "SimpleSequentialChain", ] diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 0f9edecb..c3048293 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -38,8 +38,19 @@ class Chain(BaseModel, ABC): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: """Run the logic of this chain and return the output.""" - def __call__(self, 
inputs: Dict[str, Any]) -> Dict[str, str]: - """Run the logic of this chain and add to output.""" + def __call__( + self, inputs: Dict[str, Any], return_only_outputs: bool = False + ) -> Dict[str, str]: + """Run the logic of this chain and add to output if desired. + + Args: + inputs: Dictionary of inputs. + return_only_outputs: boolean for whether to return only outputs in the + response. If True, only new keys generated by this chain will be + returned. If False, both input keys and new keys generated by this + chain will be returned. Defaults to False. + + """ self._validate_inputs(inputs) if self.verbose: print("\n\n\033[1m> Entering new chain...\033[0m") @@ -47,7 +58,10 @@ class Chain(BaseModel, ABC): if self.verbose: print("\n\033[1m> Finished chain.\033[0m") self._validate_outputs(outputs) - return {**inputs, **outputs} + if return_only_outputs: + return outputs + else: + return {**inputs, **outputs} def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Call the chain on all inputs in the list.""" diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py new file mode 100644 index 00000000..1bcc723d --- /dev/null +++ b/langchain/chains/sequential.py @@ -0,0 +1,137 @@ +"""Chain pipeline where the outputs of one step feed directly into next.""" + +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.base import Chain +from langchain.input import get_color_mapping, print_text + + +class SequentialChain(Chain, BaseModel): + """Chain where the outputs of one step feed directly into next.""" + + chains: List[Chain] + input_variables: List[str] + output_variables: List[str] #: :meta private: + return_all: bool = False + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. 
+ + :meta private: + """ + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return self.output_variables + + @root_validator(pre=True) + def validate_chains(cls, values: Dict) -> Dict: + """Validate that the correct inputs exist for all chains.""" + chains = values["chains"] + input_variables = values["input_variables"] + known_variables = set(input_variables) + for chain in chains: + missing_vars = set(chain.input_keys).difference(known_variables) + if missing_vars: + raise ValueError(f"Missing required input keys: {missing_vars}") + overlapping_keys = known_variables.intersection(chain.output_keys) + if overlapping_keys: + raise ValueError( + f"Chain returned keys that already exist: {overlapping_keys}" + ) + known_variables |= set(chain.output_keys) + + if "output_variables" not in values: + if values.get("return_all", False): + output_keys = known_variables.difference(input_variables) + else: + output_keys = chains[-1].output_keys + values["output_variables"] = output_keys + else: + missing_vars = set(values["output_variables"]).difference(known_variables) + if missing_vars: + raise ValueError( + f"Expected output variables that were not found: {missing_vars}." 
+ ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + known_values = inputs.copy() + for i, chain in enumerate(self.chains): + outputs = chain(known_values, return_only_outputs=True) + if self.verbose: + print(f"\033[1mChain {i}\033[0m:\n{outputs}\n") + known_values.update(outputs) + return {k: known_values[k] for k in self.output_variables} + + +class SimpleSequentialChain(Chain, BaseModel): + """Simple chain where the outputs of one step feed directly into next.""" + + chains: List[Chain] + strip_outputs: bool = False + input_key: str = "input" #: :meta private: + output_key: str = "output" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key] + + @root_validator() + def validate_chains(cls, values: Dict) -> Dict: + """Validate that chains are all single input/output.""" + for chain in values["chains"]: + if len(chain.input_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one input, got " + f"{chain} with {len(chain.input_keys)} inputs." + ) + if len(chain.output_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one output, got " + f"{chain} with {len(chain.output_keys)} outputs." 
+ ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + _input = inputs[self.input_key] + color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) + for i, chain in enumerate(self.chains): + _input = chain.run(_input) + if self.strip_outputs: + _input = _input.strip() + if self.verbose: + print_text(_input, color=color_mapping[str(i)], end="\n") + return {self.output_key: _input} diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py new file mode 100644 index 00000000..aa83f2ac --- /dev/null +++ b/tests/unit_tests/chains/test_sequential.py @@ -0,0 +1,140 @@ +"""Test pipeline functionality.""" +from typing import Dict, List + +import pytest +from pydantic import BaseModel + +from langchain.chains.base import Chain +from langchain.chains.sequential import SequentialChain, SimpleSequentialChain + + +class FakeChain(Chain, BaseModel): + """Fake Chain for testing purposes.""" + + input_variables: List[str] + output_variables: List[str] + + @property + def input_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.output_variables + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + outputs = {} + for var in self.output_variables: + variables = [inputs[k] for k in self.input_variables] + outputs[var] = " ".join(variables) + "foo" + return outputs + + +def test_sequential_usage_single_inputs() -> None: + """Test sequential on single input chains.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) + output = chain({"foo": "123"}) + expected_output = {"baz": "123foofoo", "foo": "123"} + assert output == expected_output + + +def 
test_sequential_usage_multiple_inputs() -> None: + """Test sequential on multiple input chains.""" + chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"]) + output = chain({"foo": "123", "test": "456"}) + expected_output = { + "baz": "123 456foo 123foo", + "foo": "123", + "test": "456", + } + assert output == expected_output + + +def test_sequential_usage_multiple_outputs() -> None: + """Test sequential usage on multiple output chains.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) + output = chain({"foo": "123"}) + expected_output = { + "baz": "123foo 123foo", + "foo": "123", + } + assert output == expected_output + + +def test_sequential_missing_inputs() -> None: + """Test error is raised when input variables are missing.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"]) + with pytest.raises(ValueError): + # Also needs "test" as an input + SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) + + +def test_sequential_bad_outputs() -> None: + """Test error is raised when bad outputs are specified.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + # "test" is not present as an output variable. 
+ SequentialChain( + chains=[chain_1, chain_2], + input_variables=["foo"], + output_variables=["test"], + ) + + +def test_sequential_valid_outputs() -> None: + """Test chain runs when valid outputs are specified.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + chain = SequentialChain( + chains=[chain_1, chain_2], + input_variables=["foo"], + output_variables=["bar", "baz"], + ) + output = chain({"foo": "123"}, return_only_outputs=True) + expected_output = {"baz": "123foofoo", "bar": "123foo"} + assert output == expected_output + + +def test_sequential_overlapping_inputs() -> None: + """Test error is raised when input variables are overlapping.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + # "test" is specified as an input, but also is an output of one step + SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"]) + + +def test_simple_sequential_functionality() -> None: + """Test simple sequential functionality.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + chain = SimpleSequentialChain(chains=[chain_1, chain_2]) + output = chain({"input": "123"}) + expected_output = {"output": "123foofoo", "input": "123"} + assert output == expected_output + + +def test_multi_input_errors() -> None: + """Test simple sequential errors if multiple input variables are expected.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + with pytest.raises(ValueError): + SimpleSequentialChain(chains=[chain_1, chain_2]) + + +def test_multi_output_errors() -> None: + """Test simple sequential errors if multiple output 
variables are expected.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + SimpleSequentialChain(chains=[chain_1, chain_2])