@@ -21,7 +21,7 @@ class LLMResult(NamedTuple):
     """For arbitrary LLM provider specific output."""
 
 
-class LLM(BaseModel, ABC):
+class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     class Config:
@@ -29,16 +29,11 @@ class LLM(BaseModel, ABC):
         extra = Extra.forbid
 
+    @abstractmethod
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        # TODO: add caching here.
-        generations = []
-        for prompt in prompts:
-            text = self(prompt, stop=stop)
-            generations.append([Generation(text=text)])
-        return LLMResult(generations=generations)
+        """Run the LLM on the given prompts."""
 
     def generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
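
With this hunk, _generate becomes the abstract hook that each provider
implements, while generate remains the shared public entry point. A minimal
sketch of what a subclass now supplies (the FakeListLLM name and its
canned-response behavior are invented for illustration; BaseLLM, Generation,
and LLMResult are the names defined in this module):

    from typing import List, Optional

    class FakeListLLM(BaseLLM):
        """Hypothetical test double that replays canned responses in order."""

        responses: List[str]
        i: int = 0

        def _generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
        ) -> LLMResult:
            # Build one inner list of Generation objects per input prompt.
            generations = []
            for _ in prompts:
                generations.append([Generation(text=self.responses[self.i])])
                self.i += 1
            return LLMResult(generations=generations)
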
@@ -88,28 +83,9 @@ class LLM(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
-    @abstractmethod
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        """Run the LLM on the given prompt and input."""
-
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        if langchain.llm_cache is None:
-            return self._call(prompt, stop=stop)
-        params = self._llm_dict()
-        params["stop"] = stop
-        llm_string = str(sorted([(k, v) for k, v in params.items()]))
-        if langchain.cache is not None:
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
-            if cache_val is not None:
-                if isinstance(cache_val, str):
-                    return cache_val
-                else:
-                    return cache_val[0].text
-        return_val = self._call(prompt, stop=stop)
-        if langchain.cache is not None:
-            langchain.llm_cache.update(prompt, llm_string, return_val)
-        return return_val
+        return self.generate([prompt], stop=stop).generations[0][0].text
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
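
After this hunk there is a single execution path: __call__ wraps the prompt in
a one-element list, delegates to generate, and unwraps the first generation.
A sketch of the resulting call shape, assuming llm is an instance whose
_generate is implemented (the prompt strings are illustrative):

    result = llm.generate(["Tell me a joke"], stop=["\n"])
    text = result.generations[0][0].text  # first prompt, first generation
    # The single-prompt call form now routes through the same generate() path:
    text2 = llm("Tell me a joke", stop=["\n"])
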
@@ -163,3 +139,26 @@ class LLM(BaseModel, ABC):
             yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             raise ValueError(f"{save_path} must be json or yaml")
+
+
+class LLM(BaseLLM):
+    """LLM class that expects subclasses to implement a simpler call method.
+
+    The purpose of this class is to expose a simpler interface for working
+    with LLMs, rather than expecting the user to implement the full _generate method.
+    """
+
+    @abstractmethod
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Run the LLM on the given prompt and input."""
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompts."""
+        # TODO: add caching here.
+        generations = []
+        for prompt in prompts:
+            text = self._call(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
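
The new LLM subclass means a provider only has to implement _call for a single
prompt; the inherited _generate fans a batch of prompts out to _call one at a
time. A minimal sketch of the subclass this enables (EchoLLM and its behavior
are invented for illustration; whether _identifying_params must be overridden
depends on code outside this diff, so it is implemented defensively):

    from typing import Any, List, Mapping, Optional

    class EchoLLM(LLM):
        """Hypothetical provider that returns the prompt unchanged."""

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            return prompt

        @property
        def _identifying_params(self) -> Mapping[str, Any]:
            return {}

    llm = EchoLLM()
    llm("hello")  # "hello": __call__ -> generate -> _generate -> _call
    llm.generate(["a", "b"]).generations[1][0].text  # "b"
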