support for arbitrary kwargs for llamacpp (#8727)

llamacpp params (per their own code) are unstable, so instead of
adding/deleting them constantly adding a model_kwargs parameter that
allows for arbitrary additional kwargs

cc @jsjolund and @zacps re #8599 and #8704
pull/8759/head
Bagatur 11 months ago committed by GitHub
parent f0b0c72d98
commit 115a77142a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,6 +6,8 @@ from pydantic import Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
from langchain.utils import get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@ -106,6 +108,9 @@ class LlamaCpp(LLM):
rope_freq_base: float = 10000.0
"""Base frequency for rope sampling."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Any additional parameters to pass to llama_cpp.Llama."""
streaming: bool = True
"""Whether to stream the results, token by token."""
@ -139,6 +144,8 @@ class LlamaCpp(LLM):
if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"]
model_params.update(values["model_kwargs"])
try:
from llama_cpp import Llama
@ -157,6 +164,16 @@ class LlamaCpp(LLM):
return values
@root_validator(pre=True)
def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling llama_cpp."""

@ -30,6 +30,7 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
@ -215,25 +216,9 @@ class BaseOpenAI(BaseLLM):
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()

@ -2,8 +2,9 @@
import contextlib
import datetime
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Optional, Set, Tuple
from typing import Any, Callable, Dict, Optional, Set, Tuple
from packaging.version import parse
from requests import HTTPError, Response
@ -122,7 +123,7 @@ def check_package_version(
)
def get_pydantic_field_names(pydantic_cls: Any) -> Set:
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class.
Args:
@ -133,3 +134,30 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set:
if field.has_alias:
all_required_field_names.add(field.alias)
return all_required_field_names
def build_extra_kwargs(
extra_kwargs: Dict[str, Any],
values: Dict[str, Any],
all_required_field_names: Set[str],
) -> Dict[str, Any]:
""""""
for field_name in list(values):
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra_kwargs[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
return extra_kwargs

@ -4,6 +4,8 @@ import os
from typing import Generator
from urllib.request import urlretrieve
import pytest
from langchain.llms import LlamaCpp
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@ -68,3 +70,19 @@ def test_llamacpp_streaming_callback() -> None:
)
llm("Q: Can you count to 10? A:'1, ")
assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
def test_llamacpp_model_kwargs() -> None:
llm = LlamaCpp(model_path=get_model(), model_kwargs={"n_gqa": None})
assert llm.model_kwargs == {"n_gqa": None}
def test_llamacpp_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
LlamaCpp(model_path=get_model(), model_kwargs={"n_ctx": 1024})
def test_llamacpp_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = LlamaCpp(model_path=get_model(), n_gqa=None)
llm.model_kwargs == {"n_gqa": None}

@ -22,6 +22,12 @@ def test_openai_model_param() -> None:
assert llm.model_name == "foo"
@pytest.mark.requires("openai")
def test_openai_model_kwargs() -> None:
llm = OpenAI(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):

Loading…
Cancel
Save