Mirror of https://github.com/hwchase17/langchain
support for arbitrary kwargs for llamacpp (#8727)
llamacpp's params are (per their own code) unstable, so instead of constantly adding and deleting them, this adds a model_kwargs parameter that accepts arbitrary additional kwargs and forwards them to llama_cpp.Llama. cc @jsjolund and @zacps re #8599 and #8704
commit 115a77142a
parent f0b0c72d98
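For orientation, a minimal usage sketch of the pass-through this commit enables. The model path is a placeholder, and n_gqa stands in for any option that llama_cpp.Llama accepts but LlamaCpp does not expose as an explicit field:

from langchain.llms import LlamaCpp

# Extra llama.cpp options are forwarded via model_kwargs instead of
# requiring a dedicated field on the LlamaCpp wrapper.
llm = LlamaCpp(
    model_path="/path/to/model.bin",  # placeholder path
    model_kwargs={"n_gqa": 8},  # example extra option
)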
@@ -6,6 +6,8 @@ from pydantic import Field, root_validator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.schema.output import GenerationChunk
+from langchain.utils import get_pydantic_field_names
+from langchain.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -106,6 +108,9 @@ class LlamaCpp(LLM):
     rope_freq_base: float = 10000.0
     """Base frequency for rope sampling."""
 
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Any additional parameters to pass to llama_cpp.Llama."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
 
@@ -139,6 +144,8 @@ class LlamaCpp(LLM):
         if values["n_gpu_layers"] is not None:
             model_params["n_gpu_layers"] = values["n_gpu_layers"]
 
+        model_params.update(values["model_kwargs"])
+
         try:
             from llama_cpp import Llama
 
@@ -157,6 +164,16 @@ class LlamaCpp(LLM):
 
         return values
 
+    @root_validator(pre=True)
+    def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
+        return values
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling llama_cpp."""
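The validator above runs with pre=True, so it receives the raw constructor kwargs before pydantic parses any fields; that is what lets unknown keys be intercepted and rerouted into model_kwargs. A self-contained sketch of the same pattern, using an invented Example class (pydantic v1 style, as langchain used at the time):

from typing import Any, Dict

from pydantic import BaseModel, Field, root_validator

class Example(BaseModel):
    known: int = 0
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @root_validator(pre=True)
    def collect_extras(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Move every key that is not a declared field into model_kwargs.
        extra = values.get("model_kwargs", {})
        for name in list(values):
            if name not in cls.__fields__:
                extra[name] = values.pop(name)
        values["model_kwargs"] = extra
        return values

assert Example(known=1, mystery=2).model_kwargs == {"mystery": 2}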
@@ -30,6 +30,7 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
 from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
+from langchain.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -215,25 +216,9 @@ class BaseOpenAI(BaseLLM):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                warnings.warn(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-
-        values["model_kwargs"] = extra
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
         return values
 
     @root_validator()
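With BaseOpenAI routed through the same helper, declared fields can no longer be smuggled in via model_kwargs. A hedged sketch of the resulting error path (temperature is a declared BaseOpenAI field; pydantic v1 wraps the ValueError in a ValidationError, which still satisfies pytest.raises because ValidationError subclasses ValueError):

import pytest

from langchain.llms import OpenAI

# A declared field passed inside model_kwargs trips the
# invalid_model_kwargs check in build_extra_kwargs and raises.
with pytest.raises(ValueError):
    OpenAI(model_kwargs={"temperature": 0.2})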
@@ -2,8 +2,9 @@
 import contextlib
 import datetime
 import importlib
+import warnings
 from importlib.metadata import version
-from typing import Any, Callable, Optional, Set, Tuple
+from typing import Any, Callable, Dict, Optional, Set, Tuple
 
 from packaging.version import parse
 from requests import HTTPError, Response
@@ -122,7 +123,7 @@ def check_package_version(
     )
 
 
-def get_pydantic_field_names(pydantic_cls: Any) -> Set:
+def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
     """Get field names, including aliases, for a pydantic class.
 
     Args:
@@ -133,3 +134,30 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set:
         if field.has_alias:
             all_required_field_names.add(field.alias)
     return all_required_field_names
+
+
+def build_extra_kwargs(
+    extra_kwargs: Dict[str, Any],
+    values: Dict[str, Any],
+    all_required_field_names: Set[str],
+) -> Dict[str, Any]:
+    """"""
+    for field_name in list(values):
+        if field_name in extra_kwargs:
+            raise ValueError(f"Found {field_name} supplied twice.")
+        if field_name not in all_required_field_names:
+            warnings.warn(
+                f"""WARNING! {field_name} is not default parameter.
+                {field_name} was transferred to model_kwargs.
+                Please confirm that {field_name} is what you intended."""
+            )
+            extra_kwargs[field_name] = values.pop(field_name)
+
+    invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
+    if invalid_model_kwargs:
+        raise ValueError(
+            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+            f"Instead they were passed in as part of `model_kwargs` parameter."
+        )
+
+    return extra_kwargs
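An illustrative walkthrough of the helper's three behaviors; the mystery_opt name is invented for the example:

import warnings

from langchain.utils.utils import build_extra_kwargs

# 1) Unknown keys in values are moved into extra_kwargs, with a warning.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    extra = build_extra_kwargs(
        extra_kwargs={},
        values={"model_path": "m.bin", "mystery_opt": 1},
        all_required_field_names={"model_path"},
    )
assert extra == {"mystery_opt": 1}

# 2) Supplying the same key both directly and in extra_kwargs raises.
try:
    build_extra_kwargs({"mystery_opt": 2}, {"mystery_opt": 1}, set())
except ValueError as err:
    assert "supplied twice" in str(err)

# 3) A declared field passed through extra_kwargs raises.
try:
    build_extra_kwargs({"model_path": "m.bin"}, {}, {"model_path"})
except ValueError as err:
    assert "specified explicitly" in str(err)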
@@ -4,6 +4,8 @@ import os
 from typing import Generator
 from urllib.request import urlretrieve
 
+import pytest
+
 from langchain.llms import LlamaCpp
 
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -68,3 +70,19 @@ def test_llamacpp_streaming_callback() -> None:
     )
     llm("Q: Can you count to 10? A:'1, ")
     assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
+
+
+def test_llamacpp_model_kwargs() -> None:
+    llm = LlamaCpp(model_path=get_model(), model_kwargs={"n_gqa": None})
+    assert llm.model_kwargs == {"n_gqa": None}
+
+
+def test_llamacpp_invalid_model_kwargs() -> None:
+    with pytest.raises(ValueError):
+        LlamaCpp(model_path=get_model(), model_kwargs={"n_ctx": 1024})
+
+
+def test_llamacpp_incorrect_field() -> None:
+    with pytest.warns(match="not default parameter"):
+        llm = LlamaCpp(model_path=get_model(), n_gqa=None)
+    assert llm.model_kwargs == {"n_gqa": None}
@@ -22,6 +22,12 @@ def test_openai_model_param() -> None:
     assert llm.model_name == "foo"
 
 
+@pytest.mark.requires("openai")
+def test_openai_model_kwargs() -> None:
+    llm = OpenAI(model_kwargs={"foo": "bar"})
+    assert llm.model_kwargs == {"foo": "bar"}
+
+
 @pytest.mark.requires("openai")
 def test_openai_invalid_model_kwargs() -> None:
     with pytest.raises(ValueError):