interface

This commit is contained in:
William Fu-Hinthorn 2022-10-27 06:40:20 -07:00
parent ad53a2ef81
commit 86fdeaf4ec
3 changed files with 11 additions and 11 deletions

View File

@ -3,18 +3,18 @@
from typing import List, Optional from typing import List, Optional
from langchain.chains.natbot.base import NatBotChain from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM from langchain.llms.base import LLM, CompletionOutput
class FakeLLM(LLM): class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Return `foo` if longer than 10000 words, else `bar`.""" """Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000: if len(prompt) > 10000:
return "foo" return [CompletionOutput("foo")]
else: else:
return "bar" return [CompletionOutput("bar")]
def test_proper_inputs() -> None: def test_proper_inputs() -> None:

View File

@ -8,7 +8,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.react.base import ReActChain, predict_until_observation from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import LLM from langchain.llms.base import LLM, CompletionOutput
from langchain.prompt import Prompt from langchain.prompt import Prompt
_PAGE_CONTENT = """This is a page about LangChain. _PAGE_CONTENT = """This is a page about LangChain.
@ -30,10 +30,10 @@ class FakeListLLM(LLM):
self.responses = responses self.responses = responses
self.i = -1 self.i = -1
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Increment counter, and then return response in that index.""" """Increment counter, and then return response in that index."""
self.i += 1 self.i += 1
return self.responses[self.i] return [CompletionOutput(self.responses[self.i])]
class FakeDocstore(Docstore): class FakeDocstore(Docstore):

View File

@ -1,7 +1,7 @@
"""Fake LLM wrapper for testing purposes.""" """Fake LLM wrapper for testing purposes."""
from typing import List, Mapping, Optional from typing import List, Mapping, Optional
from langchain.llms.base import LLM from langchain.llms.base import LLM, CompletionOutput
class FakeLLM(LLM): class FakeLLM(LLM):
@ -11,11 +11,11 @@ class FakeLLM(LLM):
"""Initialize with optional lookup of queries.""" """Initialize with optional lookup of queries."""
self._queries = queries self._queries = queries
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""First try to lookup in queries, else return 'foo' or 'bar'.""" """First try to lookup in queries, else return 'foo' or 'bar'."""
if self._queries is not None: if self._queries is not None:
return self._queries[prompt] return self._queries[prompt]
if stop is None: if stop is None:
return "foo" return [CompletionOutput("foo")]
else: else:
return "bar" return [CompletionOutput("bar")]