forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
5.7 KiB
Python
165 lines
5.7 KiB
Python
"""Base interface for large language models to expose."""
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
|
|
|
|
import yaml
|
|
from pydantic import BaseModel, Extra
|
|
|
|
import langchain
|
|
from langchain.schema import Generation
|
|
|
|
|
|
class LLMResult(NamedTuple):
|
|
"""Class that contains all relevant information for an LLM Result."""
|
|
|
|
generations: List[List[Generation]]
|
|
"""List of the things generated. This is List[List[]] because
|
|
each input could have multiple generations."""
|
|
llm_output: Optional[dict] = None
|
|
"""For arbitrary LLM provider specific output."""
|
|
|
|
|
|
class BaseLLM(BaseModel, ABC):
|
|
"""LLM wrapper should take in a prompt and return a string."""
|
|
|
|
class Config:
|
|
"""Configuration for this pydantic object."""
|
|
|
|
extra = Extra.forbid
|
|
|
|
@abstractmethod
|
|
def _generate(
|
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
|
) -> LLMResult:
|
|
"""Run the LLM on the given prompts."""
|
|
|
|
def generate(
|
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
|
) -> LLMResult:
|
|
"""Run the LLM on the given prompt and input."""
|
|
if langchain.llm_cache is None:
|
|
return self._generate(prompts, stop=stop)
|
|
params = self._llm_dict()
|
|
params["stop"] = stop
|
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
|
missing_prompts = []
|
|
missing_prompt_idxs = []
|
|
existing_prompts = {}
|
|
for i, prompt in enumerate(prompts):
|
|
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
|
|
if isinstance(cache_val, list):
|
|
existing_prompts[i] = cache_val
|
|
else:
|
|
missing_prompts.append(prompt)
|
|
missing_prompt_idxs.append(i)
|
|
new_results = self._generate(missing_prompts, stop=stop)
|
|
for i, result in enumerate(new_results.generations):
|
|
existing_prompts[i] = result
|
|
prompt = prompts[i]
|
|
langchain.llm_cache.update(prompt, llm_string, result)
|
|
generations = [existing_prompts[i] for i in range(len(prompts))]
|
|
return LLMResult(generations=generations, llm_output=new_results.llm_output)
|
|
|
|
def get_num_tokens(self, text: str) -> int:
|
|
"""Get the number of tokens present in the text."""
|
|
# TODO: this method may not be exact.
|
|
# TODO: this method may differ based on model (eg codex).
|
|
try:
|
|
from transformers import GPT2TokenizerFast
|
|
except ImportError:
|
|
raise ValueError(
|
|
"Could not import transformers python package. "
|
|
"This is needed in order to calculate get_num_tokens. "
|
|
"Please it install it with `pip install transformers`."
|
|
)
|
|
# create a GPT-3 tokenizer instance
|
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
|
|
# tokenize the text using the GPT-3 tokenizer
|
|
tokenized_text = tokenizer.tokenize(text)
|
|
|
|
# calculate the number of tokens in the tokenized text
|
|
return len(tokenized_text)
|
|
|
|
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
"""Check Cache and run the LLM on the given prompt and input."""
|
|
return self.generate([prompt], stop=stop).generations[0][0].text
|
|
|
|
@property
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
"""Get the identifying parameters."""
|
|
return {}
|
|
|
|
def __str__(self) -> str:
|
|
"""Get a string representation of the object for printing."""
|
|
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
|
|
return f"{cls_name}\nParams: {self._identifying_params}"
|
|
|
|
@property
|
|
@abstractmethod
|
|
def _llm_type(self) -> str:
|
|
"""Return type of llm."""
|
|
|
|
def _llm_dict(self) -> Dict:
|
|
"""Return a dictionary of the prompt."""
|
|
starter_dict = dict(self._identifying_params)
|
|
starter_dict["_type"] = self._llm_type
|
|
return starter_dict
|
|
|
|
def save(self, file_path: Union[Path, str]) -> None:
|
|
"""Save the LLM.
|
|
|
|
Args:
|
|
file_path: Path to file to save the LLM to.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
llm.save(file_path="path/llm.yaml")
|
|
"""
|
|
# Convert file to Path object.
|
|
if isinstance(file_path, str):
|
|
save_path = Path(file_path)
|
|
else:
|
|
save_path = file_path
|
|
|
|
directory_path = save_path.parent
|
|
directory_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Fetch dictionary to save
|
|
prompt_dict = self._llm_dict()
|
|
|
|
if save_path.suffix == ".json":
|
|
with open(file_path, "w") as f:
|
|
json.dump(prompt_dict, f, indent=4)
|
|
elif save_path.suffix == ".yaml":
|
|
with open(file_path, "w") as f:
|
|
yaml.dump(prompt_dict, f, default_flow_style=False)
|
|
else:
|
|
raise ValueError(f"{save_path} must be json or yaml")
|
|
|
|
|
|
class LLM(BaseLLM):
|
|
"""LLM class that expect subclasses to implement a simpler call method.
|
|
|
|
The purpose of this class is to expose a simpler interface for working
|
|
with LLMs, rather than expect the user to implement the full _generate method.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
"""Run the LLM on the given prompt and input."""
|
|
|
|
def _generate(
|
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
|
) -> LLMResult:
|
|
"""Run the LLM on the given prompt and input."""
|
|
# TODO: add caching here.
|
|
generations = []
|
|
for prompt in prompts:
|
|
text = self._call(prompt, stop=stop)
|
|
generations.append([Generation(text=text)])
|
|
return LLMResult(generations=generations)
|