minor changes

This commit is contained in:
jhpiedrahitao 2024-11-13 13:03:22 -05:00
parent 5b40dbe180
commit 43ae94d99e
2 changed files with 15 additions and 9 deletions

View File

@ -386,16 +386,17 @@
" return f\"Current date: {date}, Current time: {time}\"\n",
"\n",
"\n",
"tools = [get_time]\n",
"\n",
"\n",
"def invoke_tools(tool_calls, messages):\n",
" available_functions = {tool.name: tool for tool in tools}\n",
" for tool_call in tool_calls:\n",
" selected_tool = {\"get_time\": get_time}[tool_call[\"name\"].lower()]\n",
" selected_tool = available_functions[tool_call[\"name\"]]\n",
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
" print(f\"Tool output: {tool_output}\")\n",
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
" return messages\n",
"\n",
"\n",
"tools = [get_time]"
" return messages"
]
},
{
@ -429,11 +430,11 @@
],
"source": [
"response = llm_with_tools.invoke(messages)\n",
"if response.tool_calls:\n",
"while len(response.tool_calls) > 0:\n",
" print(f\"Intermediate model response: {response.tool_calls}\")\n",
" messages.append(response)\n",
" messages = invoke_tools(response.tool_calls, messages)\n",
"response = llm.invoke(messages)\n",
"response = llm_with_tools.invoke(messages)\n",
"\n",
"print(f\"final response: {response.content}\")"
]

View File

@ -796,6 +796,9 @@ class ChatSambaStudio(BaseChatModel):
model_kwargs: Optional[Dict[str, Any]] = None
"""Key word arguments to pass to the model."""
additional_headers: Dict[str, Any] = Field(default={})
"""Additional headers to send in request"""
class Config:
populate_by_name = True
@ -942,6 +945,7 @@ class ChatSambaStudio(BaseChatModel):
"content": message.content,
}
)
#TODO add tools msgs id and assistant msgs tool calls
messages_string = json.dumps(messages_dict)
else:
messages_string = self.special_tokens["start"]
@ -1020,6 +1024,7 @@ class ChatSambaStudio(BaseChatModel):
"Authorization": f"Bearer "
f"{self.sambastudio_api_key.get_secret_value()}",
"Content-Type": "application/json",
**self.additional_headers
}
# create request payload for generic v1 API
@ -1039,7 +1044,7 @@ class ChatSambaStudio(BaseChatModel):
params = {**params, **self.model_kwargs}
params = {key: value for key, value in params.items() if value is not None}
data = {"items": items, "params": params}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
# create request payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
@ -1075,7 +1080,7 @@ class ChatSambaStudio(BaseChatModel):
"instances": [self._messages_to_string(messages)],
"params": params,
}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
else:
raise ValueError(