diff --git a/docs/modules/chains/generic/llm.json b/docs/modules/chains/generic/llm.json new file mode 100644 index 0000000000..f843c42d27 --- /dev/null +++ b/docs/modules/chains/generic/llm.json @@ -0,0 +1,13 @@ +{ + "model_name": "text-davinci-003", + "temperature": 0.0, + "max_tokens": 256, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "n": 1, + "best_of": 1, + "request_timeout": null, + "logit_bias": {}, + "_type": "openai" +} \ No newline at end of file diff --git a/docs/modules/chains/generic/llm_chain.json b/docs/modules/chains/generic/llm_chain.json new file mode 100644 index 0000000000..6c907bcd57 --- /dev/null +++ b/docs/modules/chains/generic/llm_chain.json @@ -0,0 +1,27 @@ +{ + "memory": null, + "verbose": true, + "prompt": { + "input_variables": [ + "question" + ], + "output_parser": null, + "template": "Question: {question}\n\nAnswer: Let's think step by step.", + "template_format": "f-string" + }, + "llm": { + "model_name": "text-davinci-003", + "temperature": 0.0, + "max_tokens": 256, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "n": 1, + "best_of": 1, + "request_timeout": null, + "logit_bias": {}, + "_type": "openai" + }, + "output_key": "text", + "_type": "llm_chain" +} \ No newline at end of file diff --git a/docs/modules/chains/generic/llm_chain_separate.json b/docs/modules/chains/generic/llm_chain_separate.json new file mode 100644 index 0000000000..340d813db2 --- /dev/null +++ b/docs/modules/chains/generic/llm_chain_separate.json @@ -0,0 +1,8 @@ +{ + "memory": null, + "verbose": true, + "prompt_path": "prompt.json", + "llm_path": "llm.json", + "output_key": "text", + "_type": "llm_chain" +} \ No newline at end of file diff --git a/docs/modules/chains/generic/prompt.json b/docs/modules/chains/generic/prompt.json new file mode 100644 index 0000000000..aceb330e2c --- /dev/null +++ b/docs/modules/chains/generic/prompt.json @@ -0,0 +1,8 @@ +{ + "input_variables": [ + "question" + ], + "output_parser": null, + "template": "Question: {question}\n\nAnswer: Let's think step by step.", + "template_format": "f-string" +} \ No newline at end of file diff --git a/docs/modules/chains/generic/serialization.ipynb b/docs/modules/chains/generic/serialization.ipynb new file mode 100644 index 0000000000..2483699d31 --- /dev/null +++ b/docs/modules/chains/generic/serialization.ipynb @@ -0,0 +1,376 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cbe47c3a", + "metadata": {}, + "source": [ + "# Serialization\n", + "This notebook covers how to serialize chains to and from disk. The serialization format we use is json or yaml. Currently, only some chains support this type of serialization. We will grow the number of supported chains over time.\n" + ] + }, + { + "cell_type": "markdown", + "id": "e4a8a447", + "metadata": {}, + "source": [ + "## Saving a chain to disk\n", + "First, let's go over how to save a chain to disk. This can be done with the `.save` method, and specifying a file path with a json or yaml extension." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "26e28451", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import PromptTemplate, OpenAI, LLMChain\n", + "template = \"\"\"Question: {question}\n", + "\n", + "Answer: Let's think step by step.\"\"\"\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", + "llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0), verbose=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "bfa18e1f", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain.save(\"llm_chain.json\")" + ] + }, + { + "cell_type": "markdown", + "id": "ea82665d", + "metadata": {}, + "source": [ + "Let's now take a look at what's inside this saved file" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0fd33328", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\r\n", + " \"memory\": null,\r\n", + " \"verbose\": true,\r\n", + " \"prompt\": {\r\n", + " \"input_variables\": [\r\n", + " \"question\"\r\n", + " ],\r\n", + " \"output_parser\": null,\r\n", + " \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n", + " \"template_format\": \"f-string\"\r\n", + " },\r\n", + " \"llm\": {\r\n", + " \"model_name\": \"text-davinci-003\",\r\n", + " \"temperature\": 0.0,\r\n", + " \"max_tokens\": 256,\r\n", + " \"top_p\": 1,\r\n", + " \"frequency_penalty\": 0,\r\n", + " \"presence_penalty\": 0,\r\n", + " \"n\": 1,\r\n", + " \"best_of\": 1,\r\n", + " \"request_timeout\": null,\r\n", + " \"logit_bias\": {},\r\n", + " \"_type\": \"openai\"\r\n", + " },\r\n", + " \"output_key\": \"text\",\r\n", + " \"_type\": \"llm_chain\"\r\n", + "}" + ] + } + ], + "source": [ + "!cat llm_chain.json" + ] + }, + { + "cell_type": "markdown", + "id": "2012c724", + "metadata": {}, + "source": [ + "## Loading a chain from disk\n", + "We can load a chain from disk by using the `load_chain` method." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "342a1974", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import load_chain" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "394b7da8", + "metadata": {}, + "outputs": [], + "source": [ + "chain = load_chain(\"llm_chain.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "20d99787", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n", + "\n", + "Answer: Let's think step by step.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "' 2 + 2 = 4'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"whats 2 + 2\")" + ] + }, + { + "cell_type": "markdown", + "id": "14449679", + "metadata": {}, + "source": [ + "## Saving components separately\n", + "In the above example, we can see that the prompt and llm configuration information is saved in the same json as the overall chain. Alternatively, we can split them up and save them separately. This is often useful to make the saved components more modular. In order to do this, we just need to specify `llm_path` instead of the `llm` component, and `prompt_path` instead of the `prompt` component." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "50ec35ab", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain.prompt.save(\"prompt.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c48b39aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\r\n", + " \"input_variables\": [\r\n", + " \"question\"\r\n", + " ],\r\n", + " \"output_parser\": null,\r\n", + " \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n", + " \"template_format\": \"f-string\"\r\n", + "}" + ] + } + ], + "source": [ + "!cat prompt.json" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "13c92944", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain.llm.save(\"llm.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1b815f89", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\r\n", + " \"model_name\": \"text-davinci-003\",\r\n", + " \"temperature\": 0.0,\r\n", + " \"max_tokens\": 256,\r\n", + " \"top_p\": 1,\r\n", + " \"frequency_penalty\": 0,\r\n", + " \"presence_penalty\": 0,\r\n", + " \"n\": 1,\r\n", + " \"best_of\": 1,\r\n", + " \"request_timeout\": null,\r\n", + " \"logit_bias\": {},\r\n", + " \"_type\": \"openai\"\r\n", + "}" + ] + } + ], + "source": [ + "!cat llm.json" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7e6aa9ab", + "metadata": {}, + "outputs": [], + "source": [ + "config = {\n", + " \"memory\": None,\n", + " \"verbose\": True,\n", + " \"prompt_path\": \"prompt.json\",\n", + " \"llm_path\": \"llm.json\",\n", + " \"output_key\": \"text\",\n", + " \"_type\": \"llm_chain\"\n", + "}\n", + "import json\n", + "with open(\"llm_chain_separate.json\", \"w\") as f:\n", + " json.dump(config, f, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e959ca6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\r\n", + " \"memory\": null,\r\n", + " \"verbose\": true,\r\n", + " \"prompt_path\": \"prompt.json\",\r\n", + " \"llm_path\": \"llm.json\",\r\n", + " \"output_key\": \"text\",\r\n", + " \"_type\": \"llm_chain\"\r\n", + "}" + ] + } + ], + "source": [ + "!cat llm_chain_separate.json" + ] + }, + { + "cell_type": "markdown", + "id": "662731c0", + "metadata": {}, + "source": [ + "We can then load it in the same way" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d69ceb93", + "metadata": {}, + "outputs": [], + "source": [ + "chain = load_chain(\"llm_chain_separate.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a99d61b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n", + "\n", + "Answer: Let's think step by step.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "' 2 + 2 = 4'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"whats 2 + 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "822b7c12", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 6879e4850b..0079d1e26d 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -6,6 +6,7 @@ from langchain.chains.llm_bash.base import LLMBashChain from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_requests import LLMRequestsChain +from langchain.chains.loading import load_chain from langchain.chains.mapreduce import MapReduceChain from langchain.chains.moderation import OpenAIModerationChain from langchain.chains.pal.base import PALChain @@ -39,4 +40,5 @@ __all__ = [ "MapReduceChain", "OpenAIModerationChain", "SQLDatabaseSequentialChain", + "load_chain", ] diff --git a/langchain/chains/base.py b/langchain/chains/base.py index b828b06435..e30cdb1599 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -1,7 +1,10 @@ """Base interface that all chains should implement.""" +import json from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional, Union +import yaml from pydantic import BaseModel, Extra, Field, validator import langchain @@ -44,7 +47,9 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[Memory] = None - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callback_manager: BaseCallbackManager = Field( + default_factory=get_callback_manager, exclude=True + ) verbose: bool = Field( default_factory=_get_verbosity ) # Whether to print the response text @@ -54,6 +59,10 @@ class Chain(BaseModel, ABC): arbitrary_types_allowed = True + @property + def _chain_type(self) -> str: + raise NotImplementedError("Saving not supported for this chain type.") + @validator("callback_manager", pre=True, always=True) def set_callback_manager( cls, callback_manager: Optional[BaseCallbackManager] @@ -177,3 +186,43 @@ class Chain(BaseModel, ABC): f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}." ) + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of chain.""" + if self.memory is not None: + raise ValueError("Saving of memory is not yet supported.") + _dict = super().dict() + _dict["_type"] = self._chain_type + return _dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the chain. + + Args: + file_path: Path to file to save the chain to. + + Example: + .. code-block:: python + + chain.save(file_path="path/chain.yaml") + """ + # Convert file to Path object. 
+        if isinstance(file_path, str):
+            save_path = Path(file_path)
+        else:
+            save_path = file_path
+
+        directory_path = save_path.parent
+        directory_path.mkdir(parents=True, exist_ok=True)
+
+        # Fetch dictionary to save
+        chain_dict = self.dict()
+
+        if save_path.suffix == ".json":
+            with open(file_path, "w") as f:
+                json.dump(chain_dict, f, indent=4)
+        elif save_path.suffix == ".yaml":
+            with open(file_path, "w") as f:
+                yaml.dump(chain_dict, f, default_flow_style=False)
+        else:
+            raise ValueError(f"{save_path} must be json or yaml")
diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py
index 9f713ed510..ce57631825 100644
--- a/langchain/chains/llm.py
+++ b/langchain/chains/llm.py
@@ -122,3 +122,7 @@ class LLMChain(Chain, BaseModel):
             return new_result
         else:
             return result
+
+    @property
+    def _chain_type(self) -> str:
+        return "llm_chain"
diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py
new file mode 100644
index 0000000000..f36321a3fb
--- /dev/null
+++ b/langchain/chains/loading.py
@@ -0,0 +1,68 @@
+"""Functionality for loading chains."""
+import json
+from pathlib import Path
+from typing import Union
+
+import yaml
+
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.llms.loading import load_llm, load_llm_from_config
+from langchain.prompts.loading import load_prompt, load_prompt_from_config
+
+
+def _load_llm_chain(config: dict) -> LLMChain:
+    """Load LLM chain from config dict."""
+    if "llm" in config:
+        llm_config = config.pop("llm")
+        llm = load_llm_from_config(llm_config)
+    elif "llm_path" in config:
+        llm = load_llm(config.pop("llm_path"))
+    else:
+        raise ValueError("One of `llm` or `llm_path` must be present.")
+
+    if "prompt" in config:
+        prompt_config = config.pop("prompt")
+        prompt = load_prompt_from_config(prompt_config)
+    elif "prompt_path" in config:
+        prompt = load_prompt(config.pop("prompt_path"))
+    else:
+        raise ValueError("One of `prompt` or `prompt_path` must be present.")
+
+    return LLMChain(llm=llm, prompt=prompt, **config)
+
+
+type_to_loader_dict = {"llm_chain": _load_llm_chain}
+
+
+def load_chain_from_config(config: dict) -> Chain:
+    """Load chain from config dict."""
+    if "_type" not in config:
+        raise ValueError("Must specify a chain type in config")
+    config_type = config.pop("_type")
+
+    if config_type not in type_to_loader_dict:
+        raise ValueError(f"Loading {config_type} chain not supported")
+
+    chain_loader = type_to_loader_dict[config_type]
+    return chain_loader(config)
+
+
+def load_chain(file: Union[str, Path]) -> Chain:
+    """Load chain from file."""
+    # Convert file to Path object.
+    if isinstance(file, str):
+        file_path = Path(file)
+    else:
+        file_path = file
+    # Load from either json or yaml.
+    if file_path.suffix == ".json":
+        with open(file_path) as f:
+            config = json.load(f)
+    elif file_path.suffix == ".yaml":
+        with open(file_path, "r") as f:
+            config = yaml.safe_load(f)
+    else:
+        raise ValueError("File type must be json or yaml")
+    # Load the chain from the config now.
+ return load_chain_from_config(config) diff --git a/langchain/llms/base.py b/langchain/llms/base.py index ebb1c58bef..59eccc345f 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -79,7 +79,7 @@ class BaseLLM(BaseModel, ABC): raise e self.callback_manager.on_llm_end(output, verbose=self.verbose) return output - params = self._llm_dict() + params = self.dict() params["stop"] = stop llm_string = str(sorted([(k, v) for k, v in params.items()])) missing_prompts = [] @@ -148,8 +148,8 @@ class BaseLLM(BaseModel, ABC): def _llm_type(self) -> str: """Return type of llm.""" - def _llm_dict(self) -> Dict: - """Return a dictionary of the prompt.""" + def dict(self, **kwargs: Any) -> Dict: + """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type return starter_dict @@ -175,7 +175,7 @@ class BaseLLM(BaseModel, ABC): directory_path.mkdir(parents=True, exist_ok=True) # Fetch dictionary to save - prompt_dict = self._llm_dict() + prompt_dict = self.dict() if save_path.suffix == ".json": with open(file_path, "w") as f: diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 21e5a6355f..cf99215e75 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -135,10 +135,6 @@ class BasePromptTemplate(BaseModel, ABC): prompt.format(variable1="foo") """ - def _prompt_dict(self) -> Dict: - """Return a dictionary of the prompt.""" - return self.dict() - def save(self, file_path: Union[Path, str]) -> None: """Save the prompt. @@ -160,7 +156,7 @@ class BasePromptTemplate(BaseModel, ABC): directory_path.mkdir(parents=True, exist_ok=True) # Fetch dictionary to save - prompt_dict = self._prompt_dict() + prompt_dict = self.dict() if save_path.suffix == ".json": with open(file_path, "w") as f: diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 1baab6fa5a..0119132d35 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -109,11 +109,11 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): # Format the template with the input variables. 
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs) - def _prompt_dict(self) -> Dict: + def dict(self, **kwargs: Any) -> Dict: """Return a dictionary of the prompt.""" if self.example_selector: raise ValueError("Saving an example selector is not currently supported") - prompt_dict = self.dict() + prompt_dict = super().dict() prompt_dict["_type"] = "few_shot" return prompt_dict diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py index 65b29ddf52..1dfe9bb54e 100644 --- a/tests/unit_tests/chains/test_llm.py +++ b/tests/unit_tests/chains/test_llm.py @@ -1,9 +1,12 @@ """Test LLM chain.""" +from tempfile import TemporaryDirectory from typing import Dict, List, Union +from unittest.mock import patch import pytest from langchain.chains.llm import LLMChain +from langchain.chains.loading import load_chain from langchain.prompts.base import BaseOutputParser from langchain.prompts.prompt import PromptTemplate from tests.unit_tests.llms.fake_llm import FakeLLM @@ -24,6 +27,16 @@ def fake_llm_chain() -> LLMChain: return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") +@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM}) +def test_serialization(fake_llm_chain: LLMChain) -> None: + """Test serialization.""" + with TemporaryDirectory() as temp_dir: + file = temp_dir + "/llm.json" + fake_llm_chain.save(file) + loaded_chain = load_chain(file) + assert loaded_chain == fake_llm_chain + + def test_missing_inputs(fake_llm_chain: LLMChain) -> None: """Test error is raised if inputs are missing.""" with pytest.raises(ValueError): diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py index a5c9cc00ba..9726b30414 100644 --- a/tests/unit_tests/llms/test_base.py +++ b/tests/unit_tests/llms/test_base.py @@ -12,7 +12,7 @@ def test_caching() -> None: """Test caching behavior.""" langchain.llm_cache = InMemoryCache() llm = FakeLLM() - params = llm._llm_dict() + params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) @@ -50,7 +50,7 @@ def test_custom_caching() -> None: engine = create_engine("sqlite://") langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache) llm = FakeLLM() - params = llm._llm_dict() + params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
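The notebook above only exercises the json path, but `Chain.save` and `load_chain` both branch on the file suffix and handle yaml as well. Below is a minimal sketch of the yaml round trip, mirroring the notebook example; it assumes an OpenAI API key is configured in the environment.

    from langchain import LLMChain, OpenAI, PromptTemplate
    from langchain.chains import load_chain

    prompt = PromptTemplate(
        template="Question: {question}\n\nAnswer: Let's think step by step.",
        input_variables=["question"],
    )
    llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0), verbose=True)

    # .save picks the writer from the extension: yaml.dump here, json.dump for .json.
    llm_chain.save("llm_chain.yaml")

    # load_chain branches on the suffix the same way, then rebuilds the chain
    # from the parsed config via load_chain_from_config.
    chain = load_chain("llm_chain.yaml")
    chain.run("whats 2 + 2")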
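`load_chain` itself is a thin wrapper: it parses the file into a dict and hands it to `load_chain_from_config`, which pops the `_type` key and looks up the matching loader in `type_to_loader_dict` (`"llm_chain"` maps to `_load_llm_chain`). That function can also be pointed at a config assembled in memory. A sketch, assuming the `prompt.json` and `llm.json` files written in the notebook are on disk and, as above, an OpenAI key is available:

    from langchain.chains.loading import load_chain_from_config

    config = {
        "_type": "llm_chain",          # selects _load_llm_chain in type_to_loader_dict
        "prompt_path": "prompt.json",  # resolved with load_prompt
        "llm_path": "llm.json",        # resolved with load_llm
        "output_key": "text",          # remaining keys are passed through to LLMChain
    }
    chain = load_chain_from_config(config)
    chain.run("whats 2 + 2")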
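Making another chain serializable follows the same pattern on both sides: the chain advertises its tag through `_chain_type` (which `Chain.dict` writes into the `_type` field on save), and a loader for that tag is registered in `type_to_loader_dict`. A hypothetical sketch follows; `EchoChain`, `"echo_chain"`, and `_load_echo_chain` are illustrative names, not part of this diff, and it assumes the `Chain` base class's abstract interface of `input_keys`, `output_keys`, and `_call`.

    from typing import Dict, List

    from langchain.chains import loading
    from langchain.chains.base import Chain


    class EchoChain(Chain):
        """Toy chain that returns its single input unchanged."""

        input_key: str = "text"

        @property
        def input_keys(self) -> List[str]:
            return [self.input_key]

        @property
        def output_keys(self) -> List[str]:
            return ["echo"]

        def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
            return {"echo": inputs[self.input_key]}

        @property
        def _chain_type(self) -> str:
            # Written into the saved config as "_type" by Chain.dict().
            return "echo_chain"


    def _load_echo_chain(config: dict) -> EchoChain:
        return EchoChain(**config)


    # Registering the loader lets load_chain dispatch on "_type": "echo_chain".
    loading.type_to_loader_dict["echo_chain"] = _load_echo_chain

    EchoChain().save("echo_chain.yaml")
    reloaded = loading.load_chain("echo_chain.yaml")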