From 115a77142ae35c8349bed6b6051906f56d028dff Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 4 Aug 2023 06:52:02 -0700
Subject: [PATCH] support for arbitrary kwargs for llamacpp (#8727)

llamacpp params (per their own code) are unstable, so instead of
constantly adding/deleting them, this adds a model_kwargs parameter
that allows for arbitrary additional kwargs

cc @jsjolund and @zacps re #8599 and #8704
---
 libs/langchain/langchain/llms/llamacpp.py     | 17 ++++++++++
 libs/langchain/langchain/llms/openai.py       | 23 +++----------
 libs/langchain/langchain/utils/utils.py       | 32 +++++++++++++++++--
 .../integration_tests/llms/test_llamacpp.py   | 18 +++++++++++
 .../tests/unit_tests/llms/test_openai.py      |  6 ++++
 5 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/libs/langchain/langchain/llms/llamacpp.py b/libs/langchain/langchain/llms/llamacpp.py
index c3a00b6861..518206462a 100644
--- a/libs/langchain/langchain/llms/llamacpp.py
+++ b/libs/langchain/langchain/llms/llamacpp.py
@@ -6,6 +6,8 @@ from pydantic import Field, root_validator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.schema.output import GenerationChunk
+from langchain.utils import get_pydantic_field_names
+from langchain.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -106,6 +108,9 @@ class LlamaCpp(LLM):
     rope_freq_base: float = 10000.0
     """Base frequency for rope sampling."""
 
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Any additional parameters to pass to llama_cpp.Llama."""
+
     streaming: bool = True
     """Whether to stream the results, token by token."""
 
@@ -139,6 +144,8 @@ class LlamaCpp(LLM):
         if values["n_gpu_layers"] is not None:
             model_params["n_gpu_layers"] = values["n_gpu_layers"]
 
+        model_params.update(values["model_kwargs"])
+
         try:
             from llama_cpp import Llama
 
@@ -157,6 +164,16 @@ class LlamaCpp(LLM):
 
         return values
 
+    @root_validator(pre=True)
+    def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
+        return values
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling llama_cpp."""
diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py
index 52741d7f53..dacafe0f2e 100644
--- a/libs/langchain/langchain/llms/openai.py
+++ b/libs/langchain/langchain/llms/openai.py
@@ -30,6 +30,7 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
 from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
+from langchain.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -215,25 +216,9 @@ class BaseOpenAI(BaseLLM):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
-        for field_name in list(values):
-            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
-            if field_name not in all_required_field_names:
-                warnings.warn(
-                    f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
-                )
-                extra[field_name] = values.pop(field_name)
-
-        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
-        if invalid_model_kwargs:
-            raise ValueError(
-                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
-                f"Instead they were passed in as part of `model_kwargs` parameter."
-            )
-
-        values["model_kwargs"] = extra
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
         return values
 
     @root_validator()
diff --git a/libs/langchain/langchain/utils/utils.py b/libs/langchain/langchain/utils/utils.py
index a9390d6a66..6257ca330e 100644
--- a/libs/langchain/langchain/utils/utils.py
+++ b/libs/langchain/langchain/utils/utils.py
@@ -2,8 +2,9 @@
 import contextlib
 import datetime
 import importlib
+import warnings
 from importlib.metadata import version
-from typing import Any, Callable, Optional, Set, Tuple
+from typing import Any, Callable, Dict, Optional, Set, Tuple
 
 from packaging.version import parse
 from requests import HTTPError, Response
@@ -122,7 +123,7 @@ def check_package_version(
     )
 
 
-def get_pydantic_field_names(pydantic_cls: Any) -> Set:
+def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
     """Get field names, including aliases, for a pydantic class.
 
     Args:
@@ -133,3 +134,30 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set:
         if field.has_alias:
             all_required_field_names.add(field.alias)
     return all_required_field_names
+
+
+def build_extra_kwargs(
+    extra_kwargs: Dict[str, Any],
+    values: Dict[str, Any],
+    all_required_field_names: Set[str],
+) -> Dict[str, Any]:
+    """Build extra kwargs from values and extra_kwargs."""
+    for field_name in list(values):
+        if field_name in extra_kwargs:
+            raise ValueError(f"Found {field_name} supplied twice.")
+        if field_name not in all_required_field_names:
+            warnings.warn(
+                f"""WARNING! {field_name} is not default parameter.
+                {field_name} was transferred to model_kwargs.
+                Please confirm that {field_name} is what you intended."""
+            )
+            extra_kwargs[field_name] = values.pop(field_name)
+
+    invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
+    if invalid_model_kwargs:
+        raise ValueError(
+            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+            f"Instead they were passed in as part of `model_kwargs` parameter."
+        )
+
+    return extra_kwargs
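
Note: a minimal sketch of the semantics the shared build_extra_kwargs helper
above enforces, assuming the patch is applied. The field names used here
(model_path, n_ctx, n_gqa) are illustrative only, not taken from any real
model class:

    import warnings

    from langchain.utils.utils import build_extra_kwargs

    # Declared fields for a hypothetical model class.
    fields = {"model_path", "n_ctx", "model_kwargs"}

    # An undeclared param ("n_gqa") is popped out of values and collected
    # into the extra kwargs, with a "not default parameter" warning.
    values = {"model_path": "model.bin", "n_gqa": 8}
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        extra = build_extra_kwargs({}, values, fields)
    assert extra == {"n_gqa": 8}
    assert "n_gqa" not in values

    # A declared field smuggled in via model_kwargs raises instead.
    try:
        build_extra_kwargs({"n_ctx": 1024}, {"model_path": "model.bin"}, fields)
    except ValueError as err:
        print(err)  # Parameters {'n_ctx'} should be specified explicitly. ...
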
diff --git a/libs/langchain/tests/integration_tests/llms/test_llamacpp.py b/libs/langchain/tests/integration_tests/llms/test_llamacpp.py
index e1a28594a1..13a10398a2 100644
--- a/libs/langchain/tests/integration_tests/llms/test_llamacpp.py
+++ b/libs/langchain/tests/integration_tests/llms/test_llamacpp.py
@@ -4,6 +4,8 @@ import os
 from typing import Generator
 from urllib.request import urlretrieve
 
+import pytest
+
 from langchain.llms import LlamaCpp
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
@@ -68,3 +70,19 @@ def test_llamacpp_streaming_callback() -> None:
         )
     llm("Q: Can you count to 10? A:'1, ")
     assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
+
+
+def test_llamacpp_model_kwargs() -> None:
+    llm = LlamaCpp(model_path=get_model(), model_kwargs={"n_gqa": None})
+    assert llm.model_kwargs == {"n_gqa": None}
+
+
+def test_llamacpp_invalid_model_kwargs() -> None:
+    with pytest.raises(ValueError):
+        LlamaCpp(model_path=get_model(), model_kwargs={"n_ctx": 1024})
+
+
+def test_llamacpp_incorrect_field() -> None:
+    with pytest.warns(match="not default parameter"):
+        llm = LlamaCpp(model_path=get_model(), n_gqa=None)
+    assert llm.model_kwargs == {"n_gqa": None}
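
Note: a hedged usage sketch of the escape hatch the tests above exercise.
The model path is a placeholder, and whether llama_cpp.Llama accepts a
given kwarg (n_gqa here) depends on the installed llama-cpp-python version:

    from langchain.llms import LlamaCpp

    llm = LlamaCpp(
        model_path="/path/to/model.bin",  # declared field, required as before
        n_ctx=2048,                       # declared field, must stay explicit
        model_kwargs={"n_gqa": 8},        # forwarded as-is to llama_cpp.Llama
    )
    assert llm.model_kwargs == {"n_gqa": 8}
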
A:'1, ") assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE + + +def test_llamacpp_model_kwargs() -> None: + llm = LlamaCpp(model_path=get_model(), model_kwargs={"n_gqa": None}) + assert llm.model_kwargs == {"n_gqa": None} + + +def test_llamacpp_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + LlamaCpp(model_path=get_model(), model_kwargs={"n_ctx": 1024}) + + +def test_llamacpp_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = LlamaCpp(model_path=get_model(), n_gqa=None) + llm.model_kwargs == {"n_gqa": None} diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py index 54750a9592..7af941a432 100644 --- a/libs/langchain/tests/unit_tests/llms/test_openai.py +++ b/libs/langchain/tests/unit_tests/llms/test_openai.py @@ -22,6 +22,12 @@ def test_openai_model_param() -> None: assert llm.model_name == "foo" +@pytest.mark.requires("openai") +def test_openai_model_kwargs() -> None: + llm = OpenAI(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + @pytest.mark.requires("openai") def test_openai_invalid_model_kwargs() -> None: with pytest.raises(ValueError):