Harrison/improve cache (#368)

Make it so everything goes through generate, which removes the need for
two types of caches.
Harrison Chase committed 1 year ago (via GitHub)
parent 8d0869c6d3
commit 3474f39e21
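In outline, the refactor splits the old LLM base class in two: BaseLLM owns generate (now the single choke point where caching happens), while a new, slimmer LLM subclass asks implementers for only a per-prompt _call. Because every call is routed through generate, the cache can traffic in List[Generation] uniformly, which is what lets the string-only LLMCache table and the Union return type below disappear. A minimal sketch of the resulting control flow; the Generation/LLMResult stand-ins are reduced to what the sketch needs (the real ones live in langchain.schema and langchain.llms.base):

    from abc import ABC, abstractmethod
    from typing import List, Optional


    class Generation:
        """Stand-in for langchain.schema.Generation."""

        def __init__(self, text: str) -> None:
            self.text = text


    class LLMResult:
        """Stand-in for langchain.llms.base.LLMResult."""

        def __init__(self, generations: List[List[Generation]]) -> None:
            self.generations = generations


    class BaseLLM(ABC):
        """Everything funnels through generate(); a cache hooks in there."""

        @abstractmethod
        def _generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
        ) -> LLMResult:
            """Run the LLM on the given prompts."""

        def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
        ) -> LLMResult:
            # Single cache point: lookup/update would wrap this call, always
            # trafficking in List[Generation] (no str special case anymore).
            return self._generate(prompts, stop=stop)

        def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            # The old cache-aware __call__ body collapses to one line.
            return self.generate([prompt], stop=stop).generations[0][0].text


    class LLM(BaseLLM):
        """Simpler interface: subclasses implement _call for a single prompt."""

        @abstractmethod
        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            """Run the LLM on a single prompt."""

        def _generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
        ) -> LLMResult:
            return LLMResult([[Generation(self._call(p, stop=stop))] for p in prompts])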

@@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.schema import AgentAction
@@ -87,7 +87,9 @@ class Agent(Chain, BaseModel, ABC):
         pass
 
     @classmethod
-    def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
+    def from_llm_and_tools(
+        cls, llm: BaseLLM, tools: List[Tool], **kwargs: Any
+    ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))

@@ -6,7 +6,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
 from langchain.agents.react.base import ReActDocstoreAgent
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
 from langchain.agents.tools import Tool
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
 AGENT_TO_CLASS = {
     "zero-shot-react-description": ZeroShotAgent,
@@ -17,7 +17,7 @@ AGENT_TO_CLASS = {
 
 def initialize_agent(
     tools: List[Tool],
-    llm: LLM,
+    llm: BaseLLM,
     agent: str = "zero-shot-react-description",
     **kwargs: Any,
 ) -> Agent:
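With the widened annotation, initialize_agent accepts any BaseLLM. A usage sketch; the import locations and the Echo tool's fields are assumptions for illustration, not part of this diff:

    from langchain.agents import Tool, initialize_agent  # assumed re-exports
    from langchain.llms import OpenAI

    # Illustrative tool: name/func/description are the fields agents rely on.
    tools = [Tool(name="Echo", func=lambda q: q, description="Echoes the input.")]

    agent = initialize_agent(
        tools, OpenAI(temperature=0), agent="zero-shot-react-description"
    )
    # agent.run("Repeat after me: hello")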

@@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
 from langchain.agents.agent import Agent
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 
 FINAL_ANSWER_ACTION = "Final Answer: "
@@ -116,7 +116,9 @@ class MRKLChain(ZeroShotAgent):
     """
 
     @classmethod
-    def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
+    def from_chains(
+        cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any
+    ) -> Agent:
         """User friendly way to initialize the MRKL chain.
 
         This is intended to be an easy way to get up and running with the

@@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -123,7 +123,7 @@ class ReActChain(ReActDocstoreAgent):
             react = ReAct(llm=OpenAI())
     """
 
-    def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
+    def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
         """Initialize with the LLM and a docstore."""
         docstore_explorer = DocstoreExplorer(docstore)
         tools = [

@@ -5,7 +5,7 @@ from langchain.agents.agent import Agent
 from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.serpapi import SerpAPIWrapper
@@ -72,7 +72,7 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
             self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
     """
 
-    def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
+    def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any):
         """Initialize with just an LLM and a search chain."""
         search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
         llm_chain = LLMChain(llm=llm, prompt=PROMPT)

@@ -1,6 +1,6 @@
 """Beta Feature: base interface for cache."""
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple
 
 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 from langchain.schema import Generation
 
-RETURN_VAL_TYPE = Union[List[Generation], str]
+RETURN_VAL_TYPE = List[Generation]
 
 
 class BaseCache(ABC):
@@ -43,15 +43,6 @@ class InMemoryCache(BaseCache):
 Base = declarative_base()
 
 
-class LLMCache(Base):  # type: ignore
-    """SQLite table for simple LLM cache (string only)."""
-
-    __tablename__ = "llm_cache"
-    prompt = Column(String, primary_key=True)
-    llm = Column(String, primary_key=True)
-    response = Column(String)
-
-
 class FullLLMCache(Base):  # type: ignore
     """SQLite table for full LLM Cache (all generations)."""
@@ -84,29 +75,16 @@ class SQLAlchemyCache(BaseCache):
                 generations.append(Generation(text=row[0]))
         if len(generations) > 0:
             return generations
-        stmt = (
-            select(LLMCache.response)
-            .where(LLMCache.prompt == prompt)
-            .where(LLMCache.llm == llm_string)
-        )
-        with Session(self.engine) as session:
-            for row in session.execute(stmt):
-                return row[0]
         return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Look up based on prompt and llm_string."""
-        if isinstance(return_val, str):
-            item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
+        for i, generation in enumerate(return_val):
+            item = FullLLMCache(
+                prompt=prompt, llm=llm_string, response=generation.text, idx=i
+            )
             with Session(self.engine) as session, session.begin():
                 session.add(item)
-        else:
-            for i, generation in enumerate(return_val):
-                item = FullLLMCache(
-                    prompt=prompt, llm=llm_string, response=generation.text, idx=i
-                )
-                with Session(self.engine) as session, session.begin():
-                    session.add(item)
 
 
 class SQLiteCache(SQLAlchemyCache):
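With RETURN_VAL_TYPE collapsed to List[Generation], every backend stores and returns one shape: SQLAlchemyCache writes one FullLLMCache row per generation and lookup reassembles the list. An in-memory sketch of the same contract (illustrative only; the real InMemoryCache lives in this module, outside the hunk):

    from typing import Dict, List, Optional, Tuple

    from langchain.schema import Generation

    CacheKey = Tuple[str, str]  # (prompt, llm_string)


    class InMemoryCacheSketch:
        """Illustrative cache honoring the new List[Generation]-only contract."""

        def __init__(self) -> None:
            self._data: Dict[CacheKey, List[Generation]] = {}

        def lookup(self, prompt: str, llm_string: str) -> Optional[List[Generation]]:
            # A hit is always a list of generations; callers no longer
            # branch on str vs list.
            return self._data.get((prompt, llm_string))

        def update(
            self, prompt: str, llm_string: str, return_val: List[Generation]
        ) -> None:
            self._data[(prompt, llm_string)] = return_val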

@@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.requests import RequestsWrapper
@@ -81,7 +81,7 @@ class APIChain(Chain, BaseModel):
 
     @classmethod
     def from_llm_and_api_docs(
-        cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
+        cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
     ) -> APIChain:
         """Load chain from just an LLM and the api docs."""
         get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, root_validator
 from langchain.chains.base import Memory
 from langchain.chains.conversation.prompt import SUMMARY_PROMPT
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -88,7 +88,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
     """Conversation summarizer to memory."""
 
     buffer: str = ""
-    llm: LLM
+    llm: BaseLLM
     prompt: BasePromptTemplate = SUMMARY_PROMPT
     memory_key: str = "history"  #: :meta private:

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
 
 import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -25,7 +25,7 @@ class LLMChain(Chain, BaseModel):
 
     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     output_key: str = "text"  #: :meta private:

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_bash.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.utilities.bash import BashProcess
@@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
             llm_bash = LLMBashChain(llm=OpenAI())
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:

@@ -14,7 +14,7 @@ from langchain.chains.llm_checker.prompt import (
     REVISED_ANSWER_PROMPT,
 )
 from langchain.chains.sequential import SequentialChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
@@ -28,7 +28,7 @@ class LLMCheckerChain(Chain, BaseModel):
             checker_chain = LLMCheckerChain(llm=llm)
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
     list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.python import PythonREPL
@@ -21,7 +21,7 @@ class LLMMathChain(Chain, BaseModel):
             llm_math = LLMMathChain(llm=OpenAI())
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:

@@ -15,7 +15,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.text_splitter import TextSplitter
@@ -32,7 +32,7 @@ class MapReduceChain(Chain, BaseModel):
 
     @classmethod
     def from_params(
-        cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
+        cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt)

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Extra
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.natbot.prompt import PROMPT
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.openai import OpenAI
@@ -22,7 +22,7 @@ class NatBotChain(Chain, BaseModel):
             natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     objective: str
     """Objective that NatBot is tasked with completing."""

@@ -13,7 +13,7 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
 from langchain.chains.pal.math_prompt import MATH_PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.python import PythonREPL
@@ -21,7 +21,7 @@ from langchain.python import PythonREPL
 class PALChain(Chain, BaseModel):
     """Implements Program-Aided Language Models."""
 
-    llm: LLM
+    llm: BaseLLM
     prompt: BasePromptTemplate
     stop: str = "\n\n"
     get_answer_expr: str = "print(solution())"
@@ -59,7 +59,7 @@ class PALChain(Chain, BaseModel):
         return {self.output_key: res.strip()}
 
     @classmethod
-    def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
+    def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
         """Load PAL from math prompt."""
         return cls(
             llm=llm,
@@ -70,7 +70,7 @@ class PALChain(Chain, BaseModel):
         )
 
     @classmethod
-    def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
+    def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
         """Load PAL from colored object prompt."""
         return cls(
             llm=llm,

@@ -11,19 +11,19 @@ from langchain.chains.qa_with_sources import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "summaries",
     **kwargs: Any,
@@ -38,7 +38,7 @@ def _load_stuff_chain(
 
 
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
     document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
@@ -72,7 +72,7 @@ def _load_map_reduce_chain(
 
 
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
     document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
@@ -93,7 +93,7 @@ def _load_refine_chain(
 
 
 def load_qa_with_sources_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering with sources chain.

@@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
     QUESTION_PROMPT,
 )
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -35,7 +35,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
     @classmethod
     def from_llm(
         cls,
-        llm: LLM,
+        llm: BaseLLM,
         document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
         question_prompt: BasePromptTemplate = QUESTION_PROMPT,
         combine_prompt: BasePromptTemplate = COMBINE_PROMPT,

@@ -11,19 +11,19 @@ from langchain.chains.question_answering import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "context",
     **kwargs: Any,
@@ -36,7 +36,7 @@ def _load_stuff_chain(
 
 
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
     combine_document_variable_name: str = "summaries",
@@ -67,7 +67,7 @@ def _load_map_reduce_chain(
 
 
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
     document_variable_name: str = "context_str",
@@ -86,7 +86,7 @@ def _load_refine_chain(
 
 
 def load_qa_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering chain.

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sql_database.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.sql_database import SQLDatabase
@@ -22,7 +22,7 @@ class SQLDatabaseChain(Chain, BaseModel):
             db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     database: SQLDatabase
     """SQL Database to connect to."""

@@ -7,19 +7,19 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "text",
     **kwargs: Any,
@@ -32,7 +32,7 @@ def _load_stuff_chain(
 
 
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
     combine_document_variable_name: str = "text",
@@ -63,7 +63,7 @@ def _load_map_reduce_chain(
 
 
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
     document_variable_name: str = "text",
@@ -82,7 +82,7 @@ def _load_refine_chain(
 
 
 def load_summarize_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load summarizing chain.

@@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.vector_db_qa.prompt import PROMPT
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores.base import VectorStore
@@ -84,7 +84,7 @@ class VectorDBQA(Chain, BaseModel):
 
     @classmethod
     def from_llm(
-        cls, llm: LLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
+        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
     ) -> VectorDBQA:
         """Initialize from LLM."""
         llm_chain = LLMChain(llm=llm, prompt=prompt)

@@ -2,7 +2,7 @@
 from typing import List
 
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
@@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
 
 def generate_example(
-    examples: List[dict], llm: LLM, prompt_template: PromptTemplate
+    examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate
 ) -> str:
     """Return another example given a list of examples for a prompt."""
     prompt = FewShotPromptTemplate(

@@ -2,7 +2,7 @@
 from typing import Dict, Type
 
 from langchain.llms.ai21 import AI21
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
@@ -18,7 +18,7 @@ __all__ = [
     "AI21",
 ]
 
-type_to_cls_dict: Dict[str, Type[LLM]] = {
+type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "ai21": AI21,
     "cohere": Cohere,
     "huggingface_hub": HuggingFaceHub,

@@ -21,7 +21,7 @@ class LLMResult(NamedTuple):
     """For arbitrary LLM provider specific output."""
 
 
-class LLM(BaseModel, ABC):
+class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     class Config:
@@ -29,16 +29,11 @@ class LLM(BaseModel, ABC):
         extra = Extra.forbid
 
-    def generate(
-        self, prompts: List[str], stop: Optional[List[str]] = None
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        # TODO: add caching here.
-        generations = []
-        for prompt in prompts:
-            text = self(prompt, stop=stop)
-            generations.append([Generation(text=text)])
-        return LLMResult(generations=generations)
+    @abstractmethod
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompts."""
+
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
@@ -88,28 +83,9 @@ class LLM(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
-    @abstractmethod
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        """Run the LLM on the given prompt and input."""
-
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        if langchain.llm_cache is None:
-            return self._call(prompt, stop=stop)
-        params = self._llm_dict()
-        params["stop"] = stop
-        llm_string = str(sorted([(k, v) for k, v in params.items()]))
-        if langchain.cache is not None:
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
-            if cache_val is not None:
-                if isinstance(cache_val, str):
-                    return cache_val
-                else:
-                    return cache_val[0].text
-        return_val = self._call(prompt, stop=stop)
-        if langchain.cache is not None:
-            langchain.llm_cache.update(prompt, llm_string, return_val)
-        return return_val
+        return self.generate([prompt], stop=stop).generations[0][0].text
@@ -163,3 +139,26 @@
             yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             raise ValueError(f"{save_path} must be json or yaml")
+
+
+class LLM(BaseLLM):
+    """LLM class that expect subclasses to implement a simpler call method.
+
+    The purpose of this class is to expose a simpler interface for working
+    with LLMs, rather than expect the user to implement the full _generate method.
+    """
+
+    @abstractmethod
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Run the LLM on the given prompt and input."""
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        # TODO: add caching here.
+        generations = []
+        for prompt in prompts:
+            text = self._call(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
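The upshot of the split: a provider that only knows how to answer a single prompt subclasses LLM and implements _call, while batching, __call__, and the cached generate path all come from BaseLLM. A toy subclass, not part of the commit; the _llm_type property is included on the assumption that the base class of this era declares it abstract:

    from typing import List, Optional

    from langchain.llms.base import LLM


    class EchoLLM(LLM):
        """Toy provider: implements only the per-prompt _call."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            return f"echo: {prompt}"


    # echo = EchoLLM()
    # echo("hi")             # BaseLLM.__call__ -> generate -> _generate -> _call
    # echo.generate(["hi"])  # same path, minus the final [0][0].text unwrap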

@@ -6,10 +6,10 @@ from typing import Union
 
 import yaml
 
 from langchain.llms import type_to_cls_dict
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
 
-def load_llm_from_config(config: dict) -> LLM:
+def load_llm_from_config(config: dict) -> BaseLLM:
     """Load LLM from Config Dict."""
     if "_type" not in config:
         raise ValueError("Must specify an LLM Type in config")
@@ -22,7 +22,7 @@ def load_llm_from_config(config: dict) -> LLM:
     return llm_cls(**config)
 
 
-def load_llm(file: Union[str, Path]) -> LLM:
+def load_llm(file: Union[str, Path]) -> BaseLLM:
     """Load LLM from file."""
     # Convert file to Path object.
     if isinstance(file, str):
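For reference, the loader keys off _type and now returns whichever BaseLLM subclass is registered in type_to_cls_dict. A hedged usage sketch; the module path is assumed from the imports above, and the config fields are illustrative (an "openai" entry also needs OPENAI_API_KEY at construction time):

    from langchain.llms.loading import load_llm_from_config  # assumed path

    # "_type" picks the class out of type_to_cls_dict; the remaining keys
    # are handed to the constructor (fields here are illustrative).
    config = {"_type": "openai", "temperature": 0.0}
    llm = load_llm_from_config(config)  # annotated as BaseLLM; an OpenAI instance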

@@ -4,12 +4,12 @@ from typing import Any, Dict, Generator, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import LLM, LLMResult
+from langchain.llms.base import BaseLLM, LLMResult
 from langchain.schema import Generation
 from langchain.utils import get_from_dict_or_env
 
 
-class OpenAI(LLM, BaseModel):
+class OpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.
 
     To use, you should have the ``openai`` python package installed, and the
@@ -197,23 +197,6 @@ class OpenAI(LLM, BaseModel):
         """Return type of llm."""
         return "openai"
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        """Call out to OpenAI's create endpoint.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-
-        Returns:
-            The string generated by the model.
-
-        Example:
-            .. code-block:: python
-
-                response = openai("Tell me a joke.")
-        """
-        return self.generate([prompt], stop=stop).generations[0][0].text
-
     def get_num_tokens(self, text: str) -> int:
         """Calculate num tokens with tiktoken package."""
         # tiktoken NOT supported for Python 3.8 or below
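The deleted _call had already been reduced to a round-trip through generate; now that OpenAI subclasses BaseLLM, the same behavior comes from BaseLLM.__call__. A usage sketch (assumes OPENAI_API_KEY is set; with langchain.llm_cache configured, the repeated call should be served from the cache):

    import langchain
    from langchain.cache import InMemoryCache
    from langchain.llms import OpenAI

    langchain.llm_cache = InMemoryCache()
    llm = OpenAI(temperature=0)

    first = llm("Tell me a joke.")   # __call__ -> generate -> _generate (cache miss)
    second = llm("Tell me a joke.")  # same prompt and params -> cache hit
    assert first == second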

@@ -6,7 +6,7 @@ from typing import List, Optional, Sequence
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.prompt import PromptTemplate
@@ -46,7 +46,7 @@ class ModelLaboratory:
 
     @classmethod
     def from_llms(
-        cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
+        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
     ) -> ModelLaboratory:
         """Initialize with LLMs to experiment with and optional prompt.

@@ -1,9 +1,9 @@
 """Utils for LLM Tests."""
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
 
-def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
+def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
     """Assert LLM Equality for tests."""
     # Check that they are the same type.
     assert type(llm) == type(loaded_llm)
