Harrison/serialize llm chain (#671)

harrison/document-split
Harrison Chase 1 year ago committed by GitHub
parent 499e54edda
commit 0ffeabd14f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,13 @@
{
"model_name": "text-davinci-003",
"temperature": 0.0,
"max_tokens": 256,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"n": 1,
"best_of": 1,
"request_timeout": null,
"logit_bias": {},
"_type": "openai"
}

@ -0,0 +1,27 @@
{
"memory": null,
"verbose": true,
"prompt": {
"input_variables": [
"question"
],
"output_parser": null,
"template": "Question: {question}\n\nAnswer: Let's think step by step.",
"template_format": "f-string"
},
"llm": {
"model_name": "text-davinci-003",
"temperature": 0.0,
"max_tokens": 256,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"n": 1,
"best_of": 1,
"request_timeout": null,
"logit_bias": {},
"_type": "openai"
},
"output_key": "text",
"_type": "llm_chain"
}

@ -0,0 +1,8 @@
{
"memory": null,
"verbose": true,
"prompt_path": "prompt.json",
"llm_path": "llm.json",
"output_key": "text",
"_type": "llm_chain"
}

@ -0,0 +1,8 @@
{
"input_variables": [
"question"
],
"output_parser": null,
"template": "Question: {question}\n\nAnswer: Let's think step by step.",
"template_format": "f-string"
}

@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cbe47c3a",
"metadata": {},
"source": [
"# Serialization\n",
"This notebook covers how to serialize chains to and from disk. The serialization format we use is json or yaml. Currently, only some chains support this type of serialization. We will grow the number of supported chains over time.\n"
]
},
{
"cell_type": "markdown",
"id": "e4a8a447",
"metadata": {},
"source": [
"## Saving a chain to disk\n",
"First, let's go over how to save a chain to disk. This can be done by calling the `.save` method with a file path that has a json or yaml extension."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "26e28451",
"metadata": {},
"outputs": [],
"source": [
"from langchain import PromptTemplate, OpenAI, LLMChain\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0), verbose=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bfa18e1f",
"metadata": {},
"outputs": [],
"source": [
"llm_chain.save(\"llm_chain.json\")"
]
},
{
"cell_type": "markdown",
"id": "ea82665d",
"metadata": {},
"source": [
"Let's now take a look at what's inside this saved file"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0fd33328",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\r\n",
" \"memory\": null,\r\n",
" \"verbose\": true,\r\n",
" \"prompt\": {\r\n",
" \"input_variables\": [\r\n",
" \"question\"\r\n",
" ],\r\n",
" \"output_parser\": null,\r\n",
" \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n",
" \"template_format\": \"f-string\"\r\n",
" },\r\n",
" \"llm\": {\r\n",
" \"model_name\": \"text-davinci-003\",\r\n",
" \"temperature\": 0.0,\r\n",
" \"max_tokens\": 256,\r\n",
" \"top_p\": 1,\r\n",
" \"frequency_penalty\": 0,\r\n",
" \"presence_penalty\": 0,\r\n",
" \"n\": 1,\r\n",
" \"best_of\": 1,\r\n",
" \"request_timeout\": null,\r\n",
" \"logit_bias\": {},\r\n",
" \"_type\": \"openai\"\r\n",
" },\r\n",
" \"output_key\": \"text\",\r\n",
" \"_type\": \"llm_chain\"\r\n",
"}"
]
}
],
"source": [
"!cat llm_chain.json"
]
},
{
"cell_type": "markdown",
"id": "2012c724",
"metadata": {},
"source": [
"## Loading a chain from disk\n",
"We can load a chain from disk by using the `load_chain` function."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "342a1974",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import load_chain"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "394b7da8",
"metadata": {},
"outputs": [],
"source": [
"chain = load_chain(\"llm_chain.json\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "20d99787",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n",
"\n",
"Answer: Let's think step by step.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' 2 + 2 = 4'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"whats 2 + 2\")"
]
},
{
"cell_type": "markdown",
"id": "14449679",
"metadata": {},
"source": [
"## Saving components separately\n",
"In the above example, we can see that the prompt and llm configuration information is saved in the same json as the overall chain. Alternatively, we can split them up and save them separately. This is often useful to make the saved components more modular. In order to do this, we just need to specify `llm_path` instead of the `llm` component, and `prompt_path` instead of the `prompt` component."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "50ec35ab",
"metadata": {},
"outputs": [],
"source": [
"llm_chain.prompt.save(\"prompt.json\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c48b39aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\r\n",
" \"input_variables\": [\r\n",
" \"question\"\r\n",
" ],\r\n",
" \"output_parser\": null,\r\n",
" \"template\": \"Question: {question}\\n\\nAnswer: Let's think step by step.\",\r\n",
" \"template_format\": \"f-string\"\r\n",
"}"
]
}
],
"source": [
"!cat prompt.json"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "13c92944",
"metadata": {},
"outputs": [],
"source": [
"llm_chain.llm.save(\"llm.json\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1b815f89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\r\n",
" \"model_name\": \"text-davinci-003\",\r\n",
" \"temperature\": 0.0,\r\n",
" \"max_tokens\": 256,\r\n",
" \"top_p\": 1,\r\n",
" \"frequency_penalty\": 0,\r\n",
" \"presence_penalty\": 0,\r\n",
" \"n\": 1,\r\n",
" \"best_of\": 1,\r\n",
" \"request_timeout\": null,\r\n",
" \"logit_bias\": {},\r\n",
" \"_type\": \"openai\"\r\n",
"}"
]
}
],
"source": [
"!cat llm.json"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7e6aa9ab",
"metadata": {},
"outputs": [],
"source": [
"config = {\n",
" \"memory\": None,\n",
" \"verbose\": True,\n",
" \"prompt_path\": \"prompt.json\",\n",
" \"llm_path\": \"llm.json\",\n",
" \"output_key\": \"text\",\n",
" \"_type\": \"llm_chain\"\n",
"}\n",
"import json\n",
"with open(\"llm_chain_separate.json\", \"w\") as f:\n",
" json.dump(config, f, indent=2)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8e959ca6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\r\n",
" \"memory\": null,\r\n",
" \"verbose\": true,\r\n",
" \"prompt_path\": \"prompt.json\",\r\n",
" \"llm_path\": \"llm.json\",\r\n",
" \"output_key\": \"text\",\r\n",
" \"_type\": \"llm_chain\"\r\n",
"}"
]
}
],
"source": [
"!cat llm_chain_separate.json"
]
},
{
"cell_type": "markdown",
"id": "662731c0",
"metadata": {},
"source": [
"We can then load it in the same way"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d69ceb93",
"metadata": {},
"outputs": [],
"source": [
"chain = load_chain(\"llm_chain_separate.json\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a99d61b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mQuestion: whats 2 + 2\n",
"\n",
"Answer: Let's think step by step.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' 2 + 2 = 4'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"whats 2 + 2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "822b7c12",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -6,6 +6,7 @@ from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.loading import load_chain
from langchain.chains.mapreduce import MapReduceChain
from langchain.chains.moderation import OpenAIModerationChain
from langchain.chains.pal.base import PALChain
@ -39,4 +40,5 @@ __all__ = [
"MapReduceChain",
"OpenAIModerationChain",
"SQLDatabaseSequentialChain",
"load_chain",
]

@ -1,7 +1,10 @@
"""Base interface that all chains should implement."""
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Extra, Field, validator
import langchain
@ -44,7 +47,9 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[Memory] = None
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
callback_manager: BaseCallbackManager = Field(
default_factory=get_callback_manager, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
@ -54,6 +59,10 @@ class Chain(BaseModel, ABC):
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
    """Type identifier stored under ``_type`` when the chain is serialized.

    This base implementation always raises; a chain class that supports
    saving overrides the property and returns its identifier.

    Raises:
        NotImplementedError: Always.
    """
    message = "Saving not supported for this chain type."
    raise NotImplementedError(message)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
@ -177,3 +186,43 @@ class Chain(BaseModel, ABC):
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
    """Return dictionary representation of chain.

    Args:
        **kwargs: Forwarded unchanged to the parent class's ``dict``.

    Returns:
        The parent's dictionary with an extra ``_type`` entry holding
        ``self._chain_type``.

    Raises:
        ValueError: If the chain has a memory attached.
    """
    if self.memory is not None:
        raise ValueError("Saving of memory is not yet supported.")
    # Pass the caller's keyword arguments through; previously they were
    # accepted by the signature but silently dropped.
    _dict = super().dict(**kwargs)
    _dict["_type"] = self._chain_type
    return _dict
def save(self, file_path: Union[Path, str]) -> None:
    """Save the chain.

    Args:
        file_path: Path to file to save the chain to. The suffix must be
            ``.json`` or ``.yaml``. Missing parent directories are created.

    Raises:
        ValueError: If the suffix is neither ``.json`` nor ``.yaml``.
            Anything raised by ``self.dict()`` also propagates.

    Example:
    .. code-block:: python

        chain.save(file_path="path/chain.yaml")
    """
    # Convert file to Path object.
    save_path = Path(file_path) if isinstance(file_path, str) else file_path

    # Fetch the dictionary first so a chain that cannot be serialized fails
    # before anything is touched on disk.
    chain_dict = self.dict()

    # Validate the extension before creating directories, so a rejected path
    # leaves no stray directory behind.
    suffix = save_path.suffix
    if suffix not in (".json", ".yaml"):
        raise ValueError(f"{save_path} must be json or yaml")

    save_path.parent.mkdir(parents=True, exist_ok=True)

    with open(save_path, "w") as f:
        if suffix == ".json":
            json.dump(chain_dict, f, indent=4)
        else:
            yaml.dump(chain_dict, f, default_flow_style=False)

@ -122,3 +122,7 @@ class LLMChain(Chain, BaseModel):
return new_result
else:
return result
@property
def _chain_type(self) -> str:
    """Type identifier for this chain class: ``"llm_chain"``."""
    chain_type = "llm_chain"
    return chain_type

@ -0,0 +1,68 @@
"""Functionality for loading chains."""
import json
from pathlib import Path
from typing import Union
import yaml
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config
def _load_llm_chain(config: dict) -> LLMChain:
    """Load LLM chain from config dict.

    The LLM comes from an inline ``llm`` config or, failing that, from the
    file named by ``llm_path``. The prompt comes from ``prompt`` or, failing
    that, ``prompt_path``. Every key left over after those lookups is passed
    to the ``LLMChain`` constructor as a keyword argument.

    Args:
        config: Chain configuration. Its top-level keys are not modified.

    Returns:
        The constructed ``LLMChain``.

    Raises:
        ValueError: If neither ``llm`` nor ``llm_path`` is present, or if
            neither ``prompt`` nor ``prompt_path`` is present.
    """
    # Pop from a shallow copy so the caller's dict keeps its keys and can be
    # reused.
    config = dict(config)

    if "llm" in config:
        llm = load_llm_from_config(config.pop("llm"))
    elif "llm_path" in config:
        llm = load_llm(config.pop("llm_path"))
    else:
        # The key checked above is `llm_path`; the message used to name a
        # non-existent `llm_config` key.
        raise ValueError("One of `llm` or `llm_path` must be present.")

    if "prompt" in config:
        prompt = load_prompt_from_config(config.pop("prompt"))
    elif "prompt_path" in config:
        prompt = load_prompt(config.pop("prompt_path"))
    else:
        raise ValueError("One of `prompt` or `prompt_path` must be present.")

    return LLMChain(llm=llm, prompt=prompt, **config)
# Maps the `_type` value of a chain config to the function that builds that chain.
type_to_loader_dict = {"llm_chain": _load_llm_chain}
def load_chain_from_config(config: dict) -> Chain:
    """Load chain from Config Dict.

    Args:
        config: Chain configuration. It must contain a ``_type`` key naming
            a loader registered in ``type_to_loader_dict``. The caller's
            dict is left with its top-level keys intact.

    Returns:
        The chain built by the loader registered for ``_type``.

    Raises:
        ValueError: If ``_type`` is missing or names an unsupported chain.
    """
    if "_type" not in config:
        raise ValueError("Must specify an chain Type in config")
    # Pop from a shallow copy: removing `_type` from the caller's dict would
    # make a second load of the same config fail the check above.
    config = dict(config)
    config_type = config.pop("_type")

    if config_type not in type_to_loader_dict:
        raise ValueError(f"Loading {config_type} chain not supported")

    chain_loader = type_to_loader_dict[config_type]
    return chain_loader(config)
def load_chain(file: Union[str, Path]) -> Chain:
    """Read a chain config from a json or yaml file and build the chain.

    Args:
        file: Location of the config file; the suffix must be ``.json`` or
            ``.yaml``.

    Returns:
        Whatever ``load_chain_from_config`` builds from the parsed config.

    Raises:
        ValueError: If the file suffix is neither ``.json`` nor ``.yaml``.
    """
    file_path = Path(file) if isinstance(file, str) else file
    suffix = file_path.suffix

    # Reject unknown extensions before touching the file.
    if suffix not in (".json", ".yaml"):
        raise ValueError("File type must be json or yaml")

    with open(file_path) as handle:
        if suffix == ".json":
            config = json.load(handle)
        else:
            config = yaml.safe_load(handle)

    return load_chain_from_config(config)

@ -79,7 +79,7 @@ class BaseLLM(BaseModel, ABC):
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self._llm_dict()
params = self.dict()
params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
@ -148,8 +148,8 @@ class BaseLLM(BaseModel, ABC):
def _llm_type(self) -> str:
"""Return type of llm."""
def _llm_dict(self) -> Dict:
"""Return a dictionary of the prompt."""
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict
@ -175,7 +175,7 @@ class BaseLLM(BaseModel, ABC):
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self._llm_dict()
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:

@ -135,10 +135,6 @@ class BasePromptTemplate(BaseModel, ABC):
prompt.format(variable1="foo")
"""
def _prompt_dict(self) -> Dict:
"""Return a dictionary of the prompt."""
return self.dict()
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
@ -160,7 +156,7 @@ class BasePromptTemplate(BaseModel, ABC):
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self._prompt_dict()
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:

@ -109,11 +109,11 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
def _prompt_dict(self) -> Dict:
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the prompt."""
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
prompt_dict = self.dict()
prompt_dict = super().dict()
prompt_dict["_type"] = "few_shot"
return prompt_dict

@ -1,9 +1,12 @@
"""Test LLM chain."""
from tempfile import TemporaryDirectory
from typing import Dict, List, Union
from unittest.mock import patch
import pytest
from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain
from langchain.prompts.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@ -24,6 +27,16 @@ def fake_llm_chain() -> LLMChain:
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
def test_serialization(fake_llm_chain: LLMChain) -> None:
    """A chain saved to json and loaded back compares equal to the original."""
    with TemporaryDirectory() as tmp_dir:
        target = f"{tmp_dir}/llm.json"
        fake_llm_chain.save(target)
        restored = load_chain(target)
        assert restored == fake_llm_chain
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
"""Test error is raised if inputs are missing."""
with pytest.raises(ValueError):

@ -12,7 +12,7 @@ def test_caching() -> None:
"""Test caching behavior."""
langchain.llm_cache = InMemoryCache()
llm = FakeLLM()
params = llm._llm_dict()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
@ -50,7 +50,7 @@ def test_custom_caching() -> None:
engine = create_engine("sqlite://")
langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
llm = FakeLLM()
params = llm._llm_dict()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])

Loading…
Cancel
Save