support for arbitrary kwargs for llamacpp (#8727)

llama.cpp params (per their own code) are unstable, so instead of constantly
adding/deleting them, this adds a model_kwargs parameter that allows for
arbitrary additional kwargs to be passed through.

cc @jsjolund and @zacps re #8599 and #8704
This commit is contained in:
Bagatur 2023-08-04 06:52:02 -07:00 committed by GitHub
parent f0b0c72d98
commit 115a77142a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 21 deletions

View File

@ -6,6 +6,8 @@ from pydantic import Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.utils import get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -106,6 +108,9 @@ class LlamaCpp(LLM):
rope_freq_base: float = 10000.0 rope_freq_base: float = 10000.0
"""Base frequency for rope sampling.""" """Base frequency for rope sampling."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Any additional parameters to pass to llama_cpp.Llama."""
streaming: bool = True streaming: bool = True
"""Whether to stream the results, token by token.""" """Whether to stream the results, token by token."""
@ -139,6 +144,8 @@ class LlamaCpp(LLM):
if values["n_gpu_layers"] is not None: if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"] model_params["n_gpu_layers"] = values["n_gpu_layers"]
model_params.update(values["model_kwargs"])
try: try:
from llama_cpp import Llama from llama_cpp import Llama
@ -157,6 +164,16 @@ class LlamaCpp(LLM):
return values return values
@root_validator(pre=True)
def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Collect unrecognized constructor kwargs into ``model_kwargs``.

    Runs before field validation; any key in ``values`` that is not a
    declared pydantic field is transferred into ``model_kwargs`` so it
    can later be forwarded to ``llama_cpp.Llama``.
    """
    field_names = get_pydantic_field_names(cls)
    explicit_extra = values.get("model_kwargs", {})
    values["model_kwargs"] = build_extra_kwargs(
        explicit_extra, values, field_names
    )
    return values
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling llama_cpp.""" """Get the default parameters for calling llama_cpp."""

View File

@ -30,6 +30,7 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -215,25 +216,9 @@ class BaseOpenAI(BaseLLM):
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
for field_name in list(values): values["model_kwargs"] = build_extra_kwargs(
if field_name in extra: extra, values, all_required_field_names
raise ValueError(f"Found {field_name} supplied twice.") )
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values return values
@root_validator() @root_validator()

View File

@ -2,8 +2,9 @@
import contextlib import contextlib
import datetime import datetime
import importlib import importlib
import warnings
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Callable, Optional, Set, Tuple from typing import Any, Callable, Dict, Optional, Set, Tuple
from packaging.version import parse from packaging.version import parse
from requests import HTTPError, Response from requests import HTTPError, Response
@ -122,7 +123,7 @@ def check_package_version(
) )
def get_pydantic_field_names(pydantic_cls: Any) -> Set: def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class. """Get field names, including aliases, for a pydantic class.
Args: Args:
@ -133,3 +134,30 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set:
if field.has_alias: if field.has_alias:
all_required_field_names.add(field.alias) all_required_field_names.add(field.alias)
return all_required_field_names return all_required_field_names
def build_extra_kwargs(
    extra_kwargs: Dict[str, Any],
    values: Dict[str, Any],
    all_required_field_names: Set[str],
) -> Dict[str, Any]:
    """Move unrecognized keys from ``values`` into ``extra_kwargs``.

    Any key in ``values`` that is not a declared model field is popped from
    ``values`` and transferred into ``extra_kwargs``, with a warning so the
    caller notices likely typos.

    Args:
        extra_kwargs: Extra kwargs supplied explicitly (e.g. via a
            ``model_kwargs`` parameter). Mutated in place.
        values: All keyword arguments passed to the model constructor.
            Unknown keys are popped from this dict.
        all_required_field_names: Declared field names (including aliases)
            of the model class.

    Returns:
        The updated ``extra_kwargs`` dict.

    Raises:
        ValueError: If a key appears in both ``values`` and ``extra_kwargs``,
            or if ``extra_kwargs`` contains a declared field name (declared
            fields must be passed explicitly, not via ``model_kwargs``).
    """
    # Iterate over a snapshot since we pop from ``values`` while looping.
    for field_name in list(values):
        if field_name in extra_kwargs:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"""WARNING! {field_name} is not default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
            )
            extra_kwargs[field_name] = values.pop(field_name)
    # Declared fields hidden inside model_kwargs would silently shadow the
    # explicit parameters, so reject them outright.
    invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            f"Instead they were passed in as part of `model_kwargs` parameter."
        )
    return extra_kwargs

View File

@ -4,6 +4,8 @@ import os
from typing import Generator from typing import Generator
from urllib.request import urlretrieve from urllib.request import urlretrieve
import pytest
from langchain.llms import LlamaCpp from langchain.llms import LlamaCpp
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@ -68,3 +70,19 @@ def test_llamacpp_streaming_callback() -> None:
) )
llm("Q: Can you count to 10? A:'1, ") llm("Q: Can you count to 10? A:'1, ")
assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
def test_llamacpp_model_kwargs() -> None:
    """Arbitrary extra kwargs round-trip into ``model_kwargs`` unchanged."""
    extra = {"n_gqa": None}
    llm = LlamaCpp(model_path=get_model(), model_kwargs=extra)
    assert llm.model_kwargs == {"n_gqa": None}
def test_llamacpp_invalid_model_kwargs() -> None:
    """Declared fields may not be smuggled in via ``model_kwargs``."""
    bad_kwargs = {"n_ctx": 1024}
    with pytest.raises(ValueError):
        LlamaCpp(model_path=get_model(), model_kwargs=bad_kwargs)
def test_llamacpp_incorrect_field() -> None:
    """Unknown constructor kwargs are moved into ``model_kwargs`` with a warning."""
    with pytest.warns(match="not default parameter"):
        llm = LlamaCpp(model_path=get_model(), n_gqa=None)
    # Bug fix: the original final line was a bare ``==`` comparison with no
    # ``assert``, so the check was a no-op and never actually ran.
    assert llm.model_kwargs == {"n_gqa": None}

View File

@ -22,6 +22,12 @@ def test_openai_model_param() -> None:
assert llm.model_name == "foo" assert llm.model_name == "foo"
@pytest.mark.requires("openai")
def test_openai_model_kwargs() -> None:
    """Extra kwargs supplied via ``model_kwargs`` are stored verbatim."""
    extra = {"foo": "bar"}
    llm = OpenAI(model_kwargs=extra)
    assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None: def test_openai_invalid_model_kwargs() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):