From 78b31e596611d2a55dc719c640532290b42ca535 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 15 Dec 2022 07:53:32 -0800 Subject: [PATCH] Harrison/cache (#343) --- docs/examples/prompts/custom_llm.ipynb | 23 +- docs/examples/prompts/llm_functionality.ipynb | 198 ++++++++++++++++-- langchain/__init__.py | 4 + langchain/cache.py | 108 ++++++++++ langchain/llms/ai21.py | 2 +- langchain/llms/base.py | 60 +++++- langchain/llms/cohere.py | 2 +- langchain/llms/huggingface_hub.py | 2 +- langchain/llms/manifest.py | 2 +- langchain/llms/nlpcloud.py | 2 +- langchain/llms/openai.py | 7 +- langchain/schema.py | 8 + tests/unit_tests/agents/test_agent.py | 2 +- tests/unit_tests/agents/test_react.py | 2 +- tests/unit_tests/chains/test_natbot.py | 2 +- tests/unit_tests/llms/fake_llm.py | 2 +- 16 files changed, 380 insertions(+), 46 deletions(-) create mode 100644 langchain/cache.py diff --git a/docs/examples/prompts/custom_llm.ipynb b/docs/examples/prompts/custom_llm.ipynb index fd9e6a7a..bb3938aa 100644 --- a/docs/examples/prompts/custom_llm.ipynb +++ b/docs/examples/prompts/custom_llm.ipynb @@ -11,7 +11,7 @@ "\n", "There is only one required thing that a custom LLM needs to implement:\n", "\n", - "1. A `__call__` method that takes in a string, some optional stop words, and returns a string\n", + "1. A `_call` method that takes in a string, some optional stop words, and returns a string\n", "\n", "There is a second optional thing it can implement:\n", "\n", @@ -33,17 +33,20 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "id": "d5ceff02", "metadata": {}, "outputs": [], "source": [ "class CustomLLM(LLM):\n", " \n", - " def __init__(self, n: int):\n", - " self.n = n\n", + " n: int\n", + " \n", + " @property\n", + " def _llm_type(self) -> str:\n", + " return \"custom\"\n", " \n", - " def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n", + " def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n", " if stop is not None:\n", " raise ValueError(\"stop kwargs are not permitted.\")\n", " return prompt[:self.n]\n", @@ -64,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "id": "10e5ece6", "metadata": {}, "outputs": [], @@ -74,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "8cd49199", "metadata": {}, "outputs": [ @@ -84,7 +87,7 @@ "'This is a '" ] }, - "execution_count": 4, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -103,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "id": "9c33fa19", "metadata": {}, "outputs": [ @@ -145,7 +148,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/docs/examples/prompts/llm_functionality.ipynb b/docs/examples/prompts/llm_functionality.ipynb index 05eec44e..a4b00eef 100644 --- a/docs/examples/prompts/llm_functionality.ipynb +++ b/docs/examples/prompts/llm_functionality.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "740392f6", "metadata": {}, "outputs": [ @@ -91,7 +91,7 @@ "30" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -102,18 +102,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "id": "ab6cdcf1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'),\n", - " Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!')]" + "[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'),\n", + " Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.')]" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -124,18 +124,18 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "4946a778", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[Generation(text='\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode'),\n", - " Generation(text=\"\\n\\nWhen I was younger\\nI thought that love\\nI was something like a fairytale\\nI would find my prince\\nAnd we would be together\\nForever\\nI was naïve\\nAnd I was wrong\\nLove is not a fairytale\\nIt's something else entirely\\nSomething that should be cherished\\nAnd loved\\nAnd never taken for granted\\nLove is something that you have to work for\\nIt doesn't come easy\\nYou have to sacrifice\\nYour time, your effort\\nAnd sometimes you have to give up \\nYou have to do what's best for yourself\\nAnd sometimes that means giving love up\")]" + "[Generation(text=\"\\n\\nA rose by the side of the road\\n\\nIs all I need to find my way\\n\\nTo the place I've been searching for\\n\\nAnd my heart is singing with joy\\n\\nWhen I look at this rose\\n\\nIt reminds me of the love I've found\\n\\nAnd I know that wherever I go\\n\\nI'll always find my rose by the side of the road.\"),\n", + " Generation(text=\"\\n\\nA rose by the side of the road\\n\\nIs all I need to find my way\\n\\nTo the place I've been searching for\\n\\nAnd my heart is singing with joy\\n\\nWhen I look at this rose\\n\\nIt tells me that true love is nigh\\n\\nAnd I know that this is the day\\n\\nWhen I look at this rose\\n\\nI am sure of what I am doing\\n\\nWhen I look at this rose\\n\\nI am confident in my love for you\\n\\nAnd I know that I am in love with you\\n\\nSo let it be, the rose by the side of the road\\n\\nAnd let it be what you do, what you are\\n\\nAnd you do it well, for this is what we want\\n\\nAnd we want to be with you\\n\\nAnd we want to be with you\\n\\nAnd we want to be with you\\n\\nWhen we find our way home\")]" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -153,9 +153,9 @@ { "data": { "text/plain": [ - "{'token_usage': {'completion_tokens': 3721,\n", + "{'token_usage': {'completion_tokens': 4108,\n", " 'prompt_tokens': 120,\n", - " 'total_tokens': 3841}}" + " 'total_tokens': 4228}}" ] }, "execution_count": 8, @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "b623c774", "metadata": {}, "outputs": [ @@ -197,7 +197,7 @@ "3" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -206,10 +206,178 @@ "llm.get_num_tokens(\"what a joke\")" ] }, + { + "cell_type": "markdown", + "id": "ee6fcf8d", + "metadata": {}, + "source": [ + "### Caching\n", + "With LangChain, you can also enable caching of LLM calls. Note that currently this only applies for individual LLM calls." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2626ca48", + "metadata": {}, + "outputs": [], + "source": [ + "import langchain\n", + "from langchain.cache import InMemoryCache\n", + "langchain.llm_cache = InMemoryCache()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "97762272", + "metadata": {}, + "outputs": [], + "source": [ + "# To make the caching really obvious, lets use a slower model.\n", + "llm = OpenAI(model_name=\"text-davinci-002\", n=2, best_of=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e80c65e4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 31.2 ms, sys: 11.8 ms, total: 43.1 ms\n", + "Wall time: 1.75 s\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "678408ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 51 µs, sys: 1 µs, total: 52 µs\n", + "Wall time: 67.2 µs\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The second time it is, so it goes faster\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3f0ac8d2", + "metadata": {}, + "outputs": [], + "source": [ + "# We can do the same thing with a SQLite cache\n", + "from langchain.cache import SQLiteCache\n", + "langchain.llm_cache = SQLiteCache(database_path=\".langchain.db\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0e1dcce3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 26.6 ms, sys: 11.2 ms, total: 37.7 ms\n", + "Wall time: 1.89 s\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "efadd750", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.69 ms, sys: 1.57 ms, total: 4.27 ms\n", + "Wall time: 2.73 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The second time it is, so it goes faster\n", + "llm(\"Tell me a joke\")" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "4196efd9", + "id": "6053408b", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/__init__.py b/langchain/__init__.py index 31ec09eb..a32ffa95 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -1,6 +1,9 @@ """Main entrypoint into package.""" +from typing import Optional + from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain +from langchain.cache import BaseCache from langchain.chains import ( ConversationChain, LLMBashChain, @@ -28,6 +31,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch logger: BaseLogger = StdOutLogger() verbose: bool = False +llm_cache: Optional[BaseCache] = None __all__ = [ "LLMChain", diff --git a/langchain/cache.py b/langchain/cache.py new file mode 100644 index 00000000..86a9dc6b --- /dev/null +++ b/langchain/cache.py @@ -0,0 +1,108 @@ +"""Beta Feature: base interface for cache.""" +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union + +from sqlalchemy import Column, Integer, String, create_engine, select +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session + +from langchain.schema import Generation + +RETURN_VAL_TYPE = Union[List[Generation], str] + + +class BaseCache(ABC): + """Base interface for cache.""" + + @abstractmethod + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + + @abstractmethod + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + + +class InMemoryCache(BaseCache): + """Cache that stores things in memory.""" + + def __init__(self) -> None: + """Initialize with empty cache.""" + self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + return self._cache.get((prompt, llm_string), None) + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + self._cache[(prompt, llm_string)] = return_val + + +Base = declarative_base() + + +class LLMCache(Base): # type: ignore + """SQLite table for simple LLM cache (string only).""" + + __tablename__ = "llm_cache" + prompt = Column(String, primary_key=True) + llm = Column(String, primary_key=True) + response = Column(String) + + +class FullLLMCache(Base): # type: ignore + """SQLite table for full LLM Cache (all generations).""" + + __tablename__ = "full_llm_cache" + prompt = Column(String, primary_key=True) + llm = Column(String, primary_key=True) + idx = Column(Integer, primary_key=True) + response = Column(String) + + +class SQLiteCache(BaseCache): + """Cache that uses SQLite as a backend.""" + + def __init__(self, database_path: str = ".langchain.db"): + """Initialize by creating the engine and all tables.""" + self.engine = create_engine(f"sqlite:///{database_path}") + Base.metadata.create_all(self.engine) + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + stmt = ( + select(FullLLMCache.response) + .where(FullLLMCache.prompt == prompt) + .where(FullLLMCache.llm == llm_string) + .order_by(FullLLMCache.idx) + ) + with Session(self.engine) as session: + generations = [] + for row in session.execute(stmt): + generations.append(Generation(text=row[0])) + if len(generations) > 0: + return generations + stmt = ( + select(LLMCache.response) + .where(LLMCache.prompt == prompt) + .where(LLMCache.llm == llm_string) + ) + with Session(self.engine) as session: + for row in session.execute(stmt): + return row[0] + return None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Look up based on prompt and llm_string.""" + if isinstance(return_val, str): + item = LLMCache(prompt=prompt, llm=llm_string, response=return_val) + with Session(self.engine) as session, session.begin(): + session.add(item) + else: + for i, generation in enumerate(return_val): + item = FullLLMCache( + prompt=prompt, llm=llm_string, response=generation.text, idx=i + ) + with Session(self.engine) as session, session.begin(): + session.add(item) diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 3678473d..77a9300d 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -101,7 +101,7 @@ class AI21(LLM, BaseModel): """Return type of llm.""" return "ai21" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to AI21's complete endpoint. Args: diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 0eb8b73b..3b030838 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -7,13 +7,8 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union import yaml from pydantic import BaseModel, Extra - -class Generation(NamedTuple): - """Output of a single generation.""" - - text: str - """Generated text output.""" - # TODO: add log probs +import langchain +from langchain.schema import Generation class LLMResult(NamedTuple): @@ -34,16 +29,44 @@ class LLM(BaseModel, ABC): extra = Extra.forbid - def generate( + def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: """Run the LLM on the given prompt and input.""" + # TODO: add caching here. generations = [] for prompt in prompts: text = self(prompt, stop=stop) generations.append([Generation(text=text)]) return LLMResult(generations=generations) + def generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + if langchain.llm_cache is None: + return self._generate(prompts, stop=stop) + params = self._llm_dict() + params["stop"] = stop + llm_string = str(sorted([(k, v) for k, v in params.items()])) + missing_prompts = [] + missing_prompt_idxs = [] + existing_prompts = {} + for i, prompt in enumerate(prompts): + cache_val = langchain.llm_cache.lookup(prompt, llm_string) + if isinstance(cache_val, list): + existing_prompts[i] = cache_val + else: + missing_prompts.append(prompt) + missing_prompt_idxs.append(i) + new_results = self._generate(missing_prompts, stop=stop) + for i, result in enumerate(new_results.generations): + existing_prompts[i] = result + prompt = prompts[i] + langchain.llm_cache.update(prompt, llm_string, result) + generations = [existing_prompts[i] for i in range(len(prompts))] + return LLMResult(generations=generations, llm_output=new_results.llm_output) + def get_num_tokens(self, text: str) -> int: """Get the number of tokens present in the text.""" # TODO: this method may not be exact. @@ -66,9 +89,28 @@ class LLM(BaseModel, ABC): return len(tokenized_text) @abstractmethod - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Run the LLM on the given prompt and input.""" + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Check Cache and run the LLM on the given prompt and input.""" + if langchain.llm_cache is None: + return self._call(prompt, stop=stop) + params = self._llm_dict() + params["stop"] = stop + llm_string = str(sorted([(k, v) for k, v in params.items()])) + if langchain.cache is not None: + cache_val = langchain.llm_cache.lookup(prompt, llm_string) + if cache_val is not None: + if isinstance(cache_val, str): + return cache_val + else: + return cache_val[0].text + return_val = self._call(prompt, stop=stop) + if langchain.cache is not None: + langchain.llm_cache.update(prompt, llm_string, return_val) + return return_val + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 3853cfb9..d9a3f51a 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -90,7 +90,7 @@ class Cohere(LLM, BaseModel): """Return type of llm.""" return "cohere" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to Cohere's generate endpoint. Args: diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 74f0545a..ef53275d 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -84,7 +84,7 @@ class HuggingFaceHub(LLM, BaseModel): """Return type of llm.""" return "huggingface_hub" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to HuggingFace Hub's inference endpoint. Args: diff --git a/langchain/llms/manifest.py b/langchain/llms/manifest.py index 665f0c5b..b9a4ce14 100644 --- a/langchain/llms/manifest.py +++ b/langchain/llms/manifest.py @@ -42,7 +42,7 @@ class ManifestWrapper(LLM, BaseModel): """Return type of llm.""" return "manifest" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to LLM through Manifest.""" if stop is not None and len(stop) != 1: raise NotImplementedError( diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index ae3dc07b..94f0df7d 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -111,7 +111,7 @@ class NLPCloud(LLM, BaseModel): """Return type of llm.""" return "nlpcloud" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to NLPCloud's create endpoint. Args: diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 2e9485ca..ac137c7a 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -3,7 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, Field, root_validator -from langchain.llms.base import LLM, Generation, LLMResult +from langchain.llms.base import LLM, LLMResult +from langchain.schema import Generation from langchain.utils import get_from_dict_or_env @@ -99,7 +100,7 @@ class OpenAI(LLM, BaseModel): } return {**normal_params, **self.model_kwargs} - def generate( + def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: """Call out to OpenAI's endpoint with k unique prompts. @@ -168,7 +169,7 @@ class OpenAI(LLM, BaseModel): """Return type of llm.""" return "openai" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Call out to OpenAI's create endpoint. Args: diff --git a/langchain/schema.py b/langchain/schema.py index 67a64c01..4e255e3e 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -9,3 +9,11 @@ class AgentAction(NamedTuple): tool: str tool_input: str log: str + + +class Generation(NamedTuple): + """Output of a single generation.""" + + text: str + """Generated text output.""" + # TODO: add log probs diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index d407f4b4..48d0e61f 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -14,7 +14,7 @@ class FakeListLLM(LLM, BaseModel): responses: List[str] i: int = -1 - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Increment counter, and then return response in that index.""" self.i += 1 print(self.i) diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 917fa745..f3dc8da5 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -33,7 +33,7 @@ class FakeListLLM(LLM, BaseModel): """Return type of llm.""" return "fake_list" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Increment counter, and then return response in that index.""" self.i += 1 return self.responses[self.i] diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index a9817f71..0beaa409 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -11,7 +11,7 @@ from langchain.llms.base import LLM class FakeLLM(LLM, BaseModel): """Fake LLM wrapper for testing purposes.""" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Return `foo` if longer than 10000 words, else `bar`.""" if len(prompt) > 10000: return "foo" diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index a3896b6f..dd8b3462 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -16,7 +16,7 @@ class FakeLLM(LLM, BaseModel): """Return type of llm.""" return "fake" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """First try to lookup in queries, else return 'foo' or 'bar'.""" if self.queries is not None: return self.queries[prompt]