community[patch]:sparkllm standardize init args (#20194)

Related to https://github.com/langchain-ai/langchain/issues/20085
@baskaryan
This commit is contained in:
Guangdong Liu 2024-04-14 07:03:19 +08:00 committed by GitHub
parent 7d7a08e458
commit 4be7ca7b4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletions

View File

@ -141,11 +141,16 @@ class ChatSparkLLM(BaseChatModel):
spark_llm_domain: Optional[str] = None spark_llm_domain: Optional[str] = None
spark_user_id: str = "lc_user" spark_user_id: str = "lc_user"
streaming: bool = False streaming: bool = False
request_timeout: int = 30 request_timeout: int = Field(30, alias="timeout")
temperature: float = 0.5 temperature: float = 0.5
top_k: int = 4 top_k: int = 4
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""

View File

@ -3,6 +3,15 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
from langchain_community.chat_models.sparkllm import ChatSparkLLM from langchain_community.chat_models.sparkllm import ChatSparkLLM
def test_initialization() -> None:
"""Test chat model initialization."""
for model in [
ChatSparkLLM(timeout=30),
ChatSparkLLM(request_timeout=30),
]:
assert model.request_timeout == 30
def test_chat_spark_llm() -> None: def test_chat_spark_llm() -> None:
chat = ChatSparkLLM() chat = ChatSparkLLM()
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")