diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py index 39e79ba4f4..a29b037944 100644 --- a/libs/community/langchain_community/llms/yandex.py +++ b/libs/community/langchain_community/llms/yandex.py @@ -57,7 +57,7 @@ class _BaseYandexGPT(Serializable): disable_request_logging: bool = False """YandexGPT API logs all request data by default. If you provide personal data, confidential information, disable logging.""" - _grpc_metadata: Sequence + _grpc_metadata: Optional[Sequence] = None @property def _llm_type(self) -> str: diff --git a/libs/community/tests/unit_tests/llms/test_yandex.py b/libs/community/tests/unit_tests/llms/test_yandex.py new file mode 100644 index 0000000000..aadbc6ad84 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_yandex.py @@ -0,0 +1,38 @@ +import pytest + +from langchain_community.llms.yandex import YandexGPT + + +def test_yandexgpt_initialization() -> None: + llm = YandexGPT( + iam_token="your_iam_token", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + folder_id="your_folder_id", + ) + assert llm.model_name == "yandexgpt-lite" + assert llm.model_uri.startswith("gpt://your_folder_id/yandexgpt-lite/") + + +def test_yandexgpt_model_params() -> None: + llm = YandexGPT( + model_name="custom-model", + model_version="v1", + iam_token="your_iam_token", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + folder_id="your_folder_id", + ) + assert llm.model_name == "custom-model" + assert llm.model_version == "v1" + assert llm.iam_token.get_secret_value() == "your_iam_token" + assert llm.model_uri == "gpt://your_folder_id/custom-model/v1" + + +def test_yandexgpt_invalid_model_params() -> None: + with pytest.raises(ValueError): + YandexGPT(model_uri="", iam_token="your_iam_token") # type: ignore[arg-type] + with pytest.raises(ValueError): + YandexGPT( + iam_token="", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + model_uri="", + )