community: sambastudio chat model integration minor fix (#27238)

**Description:** sambastudio chat model integration minor fix
 fix default params
 fix usage metadata when streaming
This commit is contained in:
Jorge Piedrahita Ortiz 2024-10-15 12:24:36 -05:00 committed by GitHub
parent fead4749b9
commit 12fea5b868
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -174,10 +174,10 @@ class ChatSambaNovaCloud(BaseChatModel):
temperature: float = Field(default=0.7)
"""model temperature"""
top_p: Optional[float] = Field()
top_p: Optional[float] = Field(default=None)
"""model top p"""
top_k: Optional[int] = Field()
top_k: Optional[int] = Field(default=None)
"""model top k"""
stream_options: dict = Field(default={"include_usage": True})
@ -593,7 +593,7 @@ class ChatSambaStudio(BaseChatModel):
streaming_url: str = Field(default="", exclude=True)
"""SambaStudio streaming Url"""
model: Optional[str] = Field()
model: Optional[str] = Field(default=None)
"""The name of the model or expert to use (for CoE endpoints)"""
streaming: bool = Field(default=False)
@ -605,16 +605,16 @@ class ChatSambaStudio(BaseChatModel):
temperature: Optional[float] = Field(default=0.7)
"""model temperature"""
top_p: Optional[float] = Field()
top_p: Optional[float] = Field(default=None)
"""model top p"""
top_k: Optional[int] = Field()
top_k: Optional[int] = Field(default=None)
"""model top k"""
do_sample: Optional[bool] = Field()
do_sample: Optional[bool] = Field(default=None)
"""whether to do sampling"""
process_prompt: Optional[bool] = Field()
process_prompt: Optional[bool] = Field(default=True)
"""whether process prompt (for CoE generic v1 and v2 endpoints)"""
stream_options: dict = Field(default={"include_usage": True})
@ -1012,6 +1012,16 @@ class ChatSambaStudio(BaseChatModel):
"system_fingerprint": data["system_fingerprint"],
"created": data["created"],
}
if data.get("usage") is not None:
content = ""
id = data["id"]
metadata = {
"finish_reason": finish_reason,
"usage": data.get("usage"),
"model_name": data["model"],
"system_fingerprint": data["system_fingerprint"],
"created": data["created"],
}
yield AIMessageChunk(
content=content,
id=id,