Harrison/cache (#343)

Harrison Chase 1 year ago committed by GitHub
parent 8cf62ce06e
commit 78b31e5966

@ -11,7 +11,7 @@
"\n",
"There is only one required thing that a custom LLM needs to implement:\n",
"\n",
"1. A `__call__` method that takes in a string, some optional stop words, and returns a string\n",
"1. A `_call` method that takes in a string, some optional stop words, and returns a string\n",
"\n",
"There is a second optional thing it can implement:\n",
"\n",
@ -33,17 +33,20 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"id": "d5ceff02",
"metadata": {},
"outputs": [],
"source": [
"class CustomLLM(LLM):\n",
" \n",
" def __init__(self, n: int):\n",
" self.n = n\n",
" n: int\n",
" \n",
" @property\n",
" def _llm_type(self) -> str:\n",
" return \"custom\"\n",
" \n",
" def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n",
" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n",
" if stop is not None:\n",
" raise ValueError(\"stop kwargs are not permitted.\")\n",
" return prompt[:self.n]\n",
@ -64,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"id": "10e5ece6",
"metadata": {},
"outputs": [],
@ -74,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"id": "8cd49199",
"metadata": {},
"outputs": [
@ -84,7 +87,7 @@
"'This is a '"
]
},
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -103,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 10,
"id": "9c33fa19",
"metadata": {},
"outputs": [
@ -145,7 +148,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.8"
}
},
"nbformat": 4,

@ -81,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "740392f6",
"metadata": {},
"outputs": [
@ -91,7 +91,7 @@
"30"
]
},
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -102,18 +102,18 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"id": "ab6cdcf1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'),\n",
" Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!')]"
"[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'),\n",
" Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.')]"
]
},
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -124,18 +124,18 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"id": "4946a778",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Generation(text='\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode'),\n",
" Generation(text=\"\\n\\nWhen I was younger\\nI thought that love\\nI was something like a fairytale\\nI would find my prince\\nAnd we would be together\\nForever\\nI was naïve\\nAnd I was wrong\\nLove is not a fairytale\\nIt's something else entirely\\nSomething that should be cherished\\nAnd loved\\nAnd never taken for granted\\nLove is something that you have to work for\\nIt doesn't come easy\\nYou have to sacrifice\\nYour time, your effort\\nAnd sometimes you have to give up \\nYou have to do what's best for yourself\\nAnd sometimes that means giving love up\")]"
"[Generation(text=\"\\n\\nA rose by the side of the road\\n\\nIs all I need to find my way\\n\\nTo the place I've been searching for\\n\\nAnd my heart is singing with joy\\n\\nWhen I look at this rose\\n\\nIt reminds me of the love I've found\\n\\nAnd I know that wherever I go\\n\\nI'll always find my rose by the side of the road.\"),\n",
" Generation(text=\"\\n\\nA rose by the side of the road\\n\\nIs all I need to find my way\\n\\nTo the place I've been searching for\\n\\nAnd my heart is singing with joy\\n\\nWhen I look at this rose\\n\\nIt tells me that true love is nigh\\n\\nAnd I know that this is the day\\n\\nWhen I look at this rose\\n\\nI am sure of what I am doing\\n\\nWhen I look at this rose\\n\\nI am confident in my love for you\\n\\nAnd I know that I am in love with you\\n\\nSo let it be, the rose by the side of the road\\n\\nAnd let it be what you do, what you are\\n\\nAnd you do it well, for this is what we want\\n\\nAnd we want to be with you\\n\\nAnd we want to be with you\\n\\nAnd we want to be with you\\n\\nWhen we find our way home\")]"
]
},
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -153,9 +153,9 @@
{
"data": {
"text/plain": [
"{'token_usage': {'completion_tokens': 3721,\n",
"{'token_usage': {'completion_tokens': 4108,\n",
" 'prompt_tokens': 120,\n",
" 'total_tokens': 3841}}"
" 'total_tokens': 4228}}"
]
},
"execution_count": 8,
@ -180,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "b623c774",
"metadata": {},
"outputs": [
@ -197,7 +197,7 @@
"3"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -206,10 +206,178 @@
"llm.get_num_tokens(\"what a joke\")"
]
},
{
"cell_type": "markdown",
"id": "ee6fcf8d",
"metadata": {},
"source": [
"### Caching\n",
"With LangChain, you can also enable caching of LLM calls. Note that currently this only applies for individual LLM calls."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2626ca48",
"metadata": {},
"outputs": [],
"source": [
"import langchain\n",
"from langchain.cache import InMemoryCache\n",
"langchain.llm_cache = InMemoryCache()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97762272",
"metadata": {},
"outputs": [],
"source": [
"# To make the caching really obvious, lets use a slower model.\n",
"llm = OpenAI(model_name=\"text-davinci-002\", n=2, best_of=2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e80c65e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 31.2 ms, sys: 11.8 ms, total: 43.1 ms\n",
"Wall time: 1.75 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "678408ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 51 µs, sys: 1 µs, total: 52 µs\n",
"Wall time: 67.2 µs\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3f0ac8d2",
"metadata": {},
"outputs": [],
"source": [
"# We can do the same thing with a SQLite cache\n",
"from langchain.cache import SQLiteCache\n",
"langchain.llm_cache = SQLiteCache(database_path=\".langchain.db\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0e1dcce3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 26.6 ms, sys: 11.2 ms, total: 37.7 ms\n",
"Wall time: 1.89 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "efadd750",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.69 ms, sys: 1.57 ms, total: 4.27 ms\n",
"Wall time: 2.73 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4196efd9",
"id": "6053408b",
"metadata": {},
"outputs": [],
"source": []

@ -1,6 +1,9 @@
"""Main entrypoint into package."""
from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.chains import (
    ConversationChain,
    LLMBashChain,
@ -28,6 +31,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
logger: BaseLogger = StdOutLogger()
verbose: bool = False
llm_cache: Optional[BaseCache] = None
__all__ = [
"LLMChain",

@ -0,0 +1,108 @@
"""Beta Feature: base interface for cache."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from langchain.schema import Generation
RETURN_VAL_TYPE = Union[List[Generation], str]
class BaseCache(ABC):
"""Base interface for cache."""
@abstractmethod
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@abstractmethod
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
class InMemoryCache(BaseCache):
"""Cache that stores things in memory."""
def __init__(self) -> None:
"""Initialize with empty cache."""
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
return self._cache.get((prompt, llm_string), None)
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
Base = declarative_base()
class LLMCache(Base): # type: ignore
"""SQLite table for simple LLM cache (string only)."""
__tablename__ = "llm_cache"
prompt = Column(String, primary_key=True)
llm = Column(String, primary_key=True)
response = Column(String)
class FullLLMCache(Base): # type: ignore
"""SQLite table for full LLM Cache (all generations)."""
__tablename__ = "full_llm_cache"
prompt = Column(String, primary_key=True)
llm = Column(String, primary_key=True)
idx = Column(Integer, primary_key=True)
response = Column(String)
class SQLiteCache(BaseCache):
"""Cache that uses SQLite as a backend."""
def __init__(self, database_path: str = ".langchain.db"):
"""Initialize by creating the engine and all tables."""
self.engine = create_engine(f"sqlite:///{database_path}")
Base.metadata.create_all(self.engine)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
stmt = (
select(FullLLMCache.response)
.where(FullLLMCache.prompt == prompt)
.where(FullLLMCache.llm == llm_string)
.order_by(FullLLMCache.idx)
)
with Session(self.engine) as session:
generations = []
for row in session.execute(stmt):
generations.append(Generation(text=row[0]))
if len(generations) > 0:
return generations
stmt = (
select(LLMCache.response)
.where(LLMCache.prompt == prompt)
.where(LLMCache.llm == llm_string)
)
with Session(self.engine) as session:
for row in session.execute(stmt):
return row[0]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Look up based on prompt and llm_string."""
if isinstance(return_val, str):
item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
with Session(self.engine) as session, session.begin():
session.add(item)
else:
for i, generation in enumerate(return_val):
item = FullLLMCache(
prompt=prompt, llm=llm_string, response=generation.text, idx=i
)
with Session(self.engine) as session, session.begin():
session.add(item)
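
Since the LLM base class only talks to `langchain.llm_cache` through the `BaseCache` interface above, other backends can be swapped in without touching the LLM wrappers. A minimal sketch of such a backend; the class and its size-capped eviction rule are illustrative and not part of this commit:

# A minimal sketch of a custom backend against the BaseCache interface above.
# The class name and its size-capped eviction rule are illustrative, not part
# of this commit.
from collections import OrderedDict
from typing import Optional, Tuple

from langchain.cache import RETURN_VAL_TYPE, BaseCache


class BoundedInMemoryCache(BaseCache):
    """In-memory cache that evicts the oldest entry once max_size is exceeded."""

    def __init__(self, max_size: int = 128) -> None:
        self.max_size = max_size
        self._cache: "OrderedDict[Tuple[str, str], RETURN_VAL_TYPE]" = OrderedDict()

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return self._cache.get((prompt, llm_string), None)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store the value, dropping the oldest entry if over capacity."""
        self._cache[(prompt, llm_string)] = return_val
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)  # drop the oldest insertion


# Wired in exactly like the notebook does for InMemoryCache / SQLiteCache:
# import langchain
# langchain.llm_cache = BoundedInMemoryCache(max_size=256)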

@ -101,7 +101,7 @@ class AI21(LLM, BaseModel):
"""Return type of llm."""
return "ai21"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to AI21's complete endpoint.
Args:

@ -7,13 +7,8 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
import yaml
from pydantic import BaseModel, Extra


class Generation(NamedTuple):
    """Output of a single generation."""

    text: str
    """Generated text output."""
    # TODO: add log probs


import langchain
from langchain.schema import Generation


class LLMResult(NamedTuple):
@ -34,16 +29,44 @@ class LLM(BaseModel, ABC):
        extra = Extra.forbid

    def generate(
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        # TODO: add caching here.
        generations = []
        for prompt in prompts:
            text = self(prompt, stop=stop)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)

    def generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        if langchain.llm_cache is None:
            return self._generate(prompts, stop=stop)
        params = self._llm_dict()
        params["stop"] = stop
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        missing_prompts = []
        missing_prompt_idxs = []
        existing_prompts = {}
        for i, prompt in enumerate(prompts):
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
        new_results = self._generate(missing_prompts, stop=stop)
        for i, result in enumerate(new_results.generations):
            existing_prompts[missing_prompt_idxs[i]] = result
            prompt = missing_prompts[i]
            langchain.llm_cache.update(prompt, llm_string, result)
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=new_results.llm_output)

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        # TODO: this method may not be exact.
@ -66,9 +89,28 @@ class LLM(BaseModel, ABC):
        return len(tokenized_text)

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input."""

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Check cache and run the LLM on the given prompt and input."""
        if langchain.llm_cache is None:
            return self._call(prompt, stop=stop)
        params = self._llm_dict()
        params["stop"] = stop
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        if langchain.llm_cache is not None:
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if cache_val is not None:
                if isinstance(cache_val, str):
                    return cache_val
                else:
                    return cache_val[0].text
        return_val = self._call(prompt, stop=stop)
        if langchain.llm_cache is not None:
            langchain.llm_cache.update(prompt, llm_string, return_val)
        return return_val

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""

@ -90,7 +90,7 @@ class Cohere(LLM, BaseModel):
"""Return type of llm."""
return "cohere"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to Cohere's generate endpoint.
Args:

@ -84,7 +84,7 @@ class HuggingFaceHub(LLM, BaseModel):
"""Return type of llm."""
return "huggingface_hub"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:

@ -42,7 +42,7 @@ class ManifestWrapper(LLM, BaseModel):
"""Return type of llm."""
return "manifest"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to LLM through Manifest."""
if stop is not None and len(stop) != 1:
raise NotImplementedError(

@ -111,7 +111,7 @@ class NLPCloud(LLM, BaseModel):
"""Return type of llm."""
return "nlpcloud"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to NLPCloud's create endpoint.
Args:

@ -3,7 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator

from langchain.llms.base import LLM, Generation, LLMResult
from langchain.llms.base import LLM, LLMResult
from langchain.schema import Generation
from langchain.utils import get_from_dict_or_env

@ -99,7 +100,7 @@ class OpenAI(LLM, BaseModel):
        }
        return {**normal_params, **self.model_kwargs}

    def generate(
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

@ -168,7 +169,7 @@
        """Return type of llm."""
        return "openai"

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to OpenAI's create endpoint.

        Args:

@ -9,3 +9,11 @@ class AgentAction(NamedTuple):
    tool: str
    tool_input: str
    log: str


class Generation(NamedTuple):
    """Output of a single generation."""

    text: str
    """Generated text output."""
    # TODO: add log probs
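
Moving `Generation` into `langchain.schema` lets both `langchain.cache` and `langchain.llms.base` import it without a circular import (base now imports the top-level `langchain` module, which in turn imports the cache). A tiny usage sketch; the text value is illustrative:

# Tiny usage sketch of the relocated type; the text value is illustrative.
from langchain.schema import Generation

gen = Generation(text="To get to the other side!")
print(gen.text)
# A List[Generation] is one arm of RETURN_VAL_TYPE in langchain.cache, and
# LLMResult.generations holds one such list per input prompt.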

@ -14,7 +14,7 @@ class FakeListLLM(LLM, BaseModel):
    responses: List[str]
    i: int = -1

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Increment counter, and then return response in that index."""
        self.i += 1
        print(self.i)

@ -33,7 +33,7 @@ class FakeListLLM(LLM, BaseModel):
"""Return type of llm."""
return "fake_list"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1
return self.responses[self.i]

@ -11,7 +11,7 @@ from langchain.llms.base import LLM
class FakeLLM(LLM, BaseModel):
    """Fake LLM wrapper for testing purposes."""

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Return `foo` if the prompt is longer than 10000 characters, else `bar`."""
        if len(prompt) > 10000:
            return "foo"

@ -16,7 +16,7 @@ class FakeLLM(LLM, BaseModel):
"""Return type of llm."""
return "fake"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
if self.queries is not None:
return self.queries[prompt]
