This commit is contained in:
jhpiedrahitao 2024-11-13 13:18:50 -05:00
parent 43ae94d99e
commit 7b53800e27

View File

@ -796,7 +796,7 @@ class ChatSambaStudio(BaseChatModel):
model_kwargs: Optional[Dict[str, Any]] = None
"""Key word arguments to pass to the model."""
additional_headers: Dict[str, Any] = Field(default={})
"""Additional headers to send in request"""
@ -945,7 +945,7 @@ class ChatSambaStudio(BaseChatModel):
"content": message.content,
}
)
#TODO add tools msgs id and assistant msgs tool calls
# TODO add tools msgs id and assistant msgs tool calls
messages_string = json.dumps(messages_dict)
else:
messages_string = self.special_tokens["start"]
@ -1024,7 +1024,7 @@ class ChatSambaStudio(BaseChatModel):
"Authorization": f"Bearer "
f"{self.sambastudio_api_key.get_secret_value()}",
"Content-Type": "application/json",
**self.additional_headers
**self.additional_headers,
}
# create request payload for generic v1 API
@ -1044,7 +1044,10 @@ class ChatSambaStudio(BaseChatModel):
params = {**params, **self.model_kwargs}
params = {key: value for key, value in params.items() if value is not None}
data = {"items": items, "params": params}
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
headers = {
"key": self.sambastudio_api_key.get_secret_value(),
**self.additional_headers,
}
# create request payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
@ -1080,7 +1083,10 @@ class ChatSambaStudio(BaseChatModel):
"instances": [self._messages_to_string(messages)],
"params": params,
}
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
headers = {
"key": self.sambastudio_api_key.get_secret_value(),
**self.additional_headers,
}
else:
raise ValueError(