Fix LLM types so that they can be loaded from config dicts (#6235)

LLM configurations can be loaded from a Python dict (or from a JSON file
deserialized into a dict) using the
[load_llm_from_config](8e1a7a8646/langchain/llms/loading.py (L12))
function.
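
Usage looks roughly like this (a sketch: the `_type` key follows the convention in `loading.py` at this revision, the remaining config key is illustrative, and most LLM classes validate credentials or model files at construction time):

```python
from langchain.llms.loading import load_llm_from_config

# "_type" selects the LLM class via the type_to_cls_dict lookup;
# all other keys are forwarded to the class constructor.
config = {
    "_type": "aleph_alpha",               # key into type_to_cls_dict
    "aleph_alpha_api_key": "my-api-key",  # illustrative constructor arg
}
llm = load_llm_from_config(config)
```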

However, for several LLM classes the type string registered in the
`type_to_cls_dict` lookup dict differs from the type string returned by the
class's `_llm_type` property. As a result, such an LLM object can be saved,
but not loaded again, because the two type strings do not match.
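
For example, `LlamaCpp._llm_type` returned `"llama.cpp"` while the lookup key is `"llamacpp"`, so a save/load round trip failed (a sketch of the failure mode, assuming the `save`/`load_llm` helpers from the same module; the model path is a placeholder):

```python
from langchain.llms import LlamaCpp
from langchain.llms.loading import load_llm

llm = LlamaCpp(model_path="/path/to/model.bin")  # placeholder path
llm.save("llm.yaml")  # serializes _type as "llama.cpp" (from _llm_type)
llm = load_llm("llm.yaml")
# raises ValueError, because "llama.cpp" is not a key in type_to_cls_dict
```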
Jan Pawellek, 2023-06-19 02:46:22 +02:00 (committed by GitHub)
parent 46782ad79b
commit 3e3ed8c5c9
6 changed files with 8 additions and 8 deletions

langchain/llms/aleph_alpha.py

@@ -23,7 +23,7 @@ class AlephAlpha(LLM):
         .. code-block:: python
 
             from langchain.llms import AlephAlpha
-            alpeh_alpha = AlephAlpha(aleph_alpha_api_key="my-api-key")
+            aleph_alpha = AlephAlpha(aleph_alpha_api_key="my-api-key")
     """
 
     client: Any  #: :meta private:
@@ -199,7 +199,7 @@ class AlephAlpha(LLM):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "alpeh_alpha"
+        return "aleph_alpha"
 
     def _call(
         self,
@@ -220,7 +220,7 @@ class AlephAlpha(LLM):
         Example:
             .. code-block:: python
 
-                response = alpeh_alpha("Tell me a joke.")
+                response = aleph_alpha("Tell me a joke.")
         """
         from aleph_alpha_client import CompletionRequest, Prompt

langchain/llms/bananadev.py

@@ -80,7 +80,7 @@ class Banana(LLM):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "banana"
+        return "bananadev"
 
     def _call(
         self,

langchain/llms/huggingface_text_gen_inference.py

@@ -106,7 +106,7 @@ class HuggingFaceTextGenInference(LLM):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "hf_textgen_inference"
+        return "huggingface_textgen_inference"
 
     def _call(
         self,

langchain/llms/llamacpp.py

@@ -168,7 +168,7 @@ class LlamaCpp(LLM):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "llama.cpp"
+        return "llamacpp"
 
     def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
         """

langchain/llms/mosaicml.py

@@ -86,7 +86,7 @@ class MosaicML(LLM):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "mosaicml"
+        return "mosaic"
 
     def _transform_prompt(self, prompt: str) -> str:
         """Transform prompt."""

langchain/llms/rwkv.py

@@ -140,7 +140,7 @@ class RWKV(LLM, BaseModel):
     @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
-        return "rwkv-4"
+        return "rwkv"
 
     def run_rnn(self, _tokens: List[str], newline_adj: int = 0) -> Any:
        AVOID_REPEAT_TOKENS = []
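
With the six type strings aligned, the mapping can be sanity-checked without instantiating any model (a sketch; it assumes `type_to_cls_dict` is exported by `langchain.llms` and reads each `_llm_type` property directly off the class):

```python
from langchain.llms import type_to_cls_dict

for key in ("aleph_alpha", "bananadev", "huggingface_textgen_inference",
            "llamacpp", "mosaic", "rwkv"):
    cls = type_to_cls_dict[key]
    # _llm_type is a property; call its getter with a dummy self so no
    # API key or model file is needed.
    assert cls._llm_type.fget(None) == key, key
```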