From 4be7ca7b4c3886ae4707dd3edda809a583169f41 Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Sun, 14 Apr 2024 07:03:19 +0800 Subject: [PATCH] community[patch]:sparkllm standardize init args (#20194) Related to https://github.com/langchain-ai/langchain/issues/20085 @baskaryan --- .../langchain_community/chat_models/sparkllm.py | 7 ++++++- .../tests/integration_tests/chat_models/test_sparkllm.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index e3b36bed9b..f173519746 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -141,11 +141,16 @@ class ChatSparkLLM(BaseChatModel): spark_llm_domain: Optional[str] = None spark_user_id: str = "lc_user" streaming: bool = False - request_timeout: int = 30 + request_timeout: int = Field(30, alias="timeout") temperature: float = 0.5 top_k: int = 4 model_kwargs: Dict[str, Any] = Field(default_factory=dict) + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" diff --git a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py index fcb3a7a7f9..65fc38712c 100644 --- a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py +++ b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py @@ -3,6 +3,15 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_community.chat_models.sparkllm import ChatSparkLLM +def test_initialization() -> None: + """Test chat model initialization.""" + for model in [ + ChatSparkLLM(timeout=30), + ChatSparkLLM(request_timeout=30), + ]: + assert model.request_timeout == 30 + + def test_chat_spark_llm() -> None: chat = ChatSparkLLM() message = HumanMessage(content="Hello")