mirror of
https://github.com/hwchase17/langchain
synced 2024-11-20 03:25:56 +00:00
community[patch]:sparkllm standardize init args (#20194)
Related to https://github.com/langchain-ai/langchain/issues/20085 @baskaryan
This commit is contained in:
parent
7d7a08e458
commit
4be7ca7b4c
@ -141,11 +141,16 @@ class ChatSparkLLM(BaseChatModel):
|
|||||||
spark_llm_domain: Optional[str] = None
|
spark_llm_domain: Optional[str] = None
|
||||||
spark_user_id: str = "lc_user"
|
spark_user_id: str = "lc_user"
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
request_timeout: int = 30
|
request_timeout: int = Field(30, alias="timeout")
|
||||||
temperature: float = 0.5
|
temperature: float = 0.5
|
||||||
top_k: int = 4
|
top_k: int = 4
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Build extra kwargs from additional params that were passed in."""
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
|
@ -3,6 +3,15 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
|||||||
from langchain_community.chat_models.sparkllm import ChatSparkLLM
|
from langchain_community.chat_models.sparkllm import ChatSparkLLM
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialization() -> None:
|
||||||
|
"""Test chat model initialization."""
|
||||||
|
for model in [
|
||||||
|
ChatSparkLLM(timeout=30),
|
||||||
|
ChatSparkLLM(request_timeout=30),
|
||||||
|
]:
|
||||||
|
assert model.request_timeout == 30
|
||||||
|
|
||||||
|
|
||||||
def test_chat_spark_llm() -> None:
|
def test_chat_spark_llm() -> None:
|
||||||
chat = ChatSparkLLM()
|
chat = ChatSparkLLM()
|
||||||
message = HumanMessage(content="Hello")
|
message = HumanMessage(content="Hello")
|
||||||
|
Loading…
Reference in New Issue
Block a user