mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
minor changes
This commit is contained in:
parent
5b40dbe180
commit
43ae94d99e
@ -386,16 +386,17 @@
|
||||
" return f\"Current date: {date}, Current time: {time}\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [get_time]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def invoke_tools(tool_calls, messages):\n",
|
||||
" available_functions = {tool.name: tool for tool in tools}\n",
|
||||
" for tool_call in tool_calls:\n",
|
||||
" selected_tool = {\"get_time\": get_time}[tool_call[\"name\"].lower()]\n",
|
||||
" selected_tool = available_functions[tool_call[\"name\"]]\n",
|
||||
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
|
||||
" print(f\"Tool output: {tool_output}\")\n",
|
||||
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
|
||||
" return messages\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [get_time]"
|
||||
" return messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -429,11 +430,11 @@
|
||||
],
|
||||
"source": [
|
||||
"response = llm_with_tools.invoke(messages)\n",
|
||||
"if response.tool_calls:\n",
|
||||
"while len(response.tool_calls) > 0:\n",
|
||||
" print(f\"Intermediate model response: {response.tool_calls}\")\n",
|
||||
" messages.append(response)\n",
|
||||
" messages = invoke_tools(response.tool_calls, messages)\n",
|
||||
"response = llm.invoke(messages)\n",
|
||||
"response = llm_with_tools.invoke(messages)\n",
|
||||
"\n",
|
||||
"print(f\"final response: {response.content}\")"
|
||||
]
|
||||
|
@ -797,6 +797,9 @@ class ChatSambaStudio(BaseChatModel):
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
additional_headers: Dict[str, Any] = Field(default={})
|
||||
"""Additional headers to send in request"""
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@ -942,6 +945,7 @@ class ChatSambaStudio(BaseChatModel):
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
#TODO add tools msgs id and assistant msgs tool calls
|
||||
messages_string = json.dumps(messages_dict)
|
||||
else:
|
||||
messages_string = self.special_tokens["start"]
|
||||
@ -1020,6 +1024,7 @@ class ChatSambaStudio(BaseChatModel):
|
||||
"Authorization": f"Bearer "
|
||||
f"{self.sambastudio_api_key.get_secret_value()}",
|
||||
"Content-Type": "application/json",
|
||||
**self.additional_headers
|
||||
}
|
||||
|
||||
# create request payload for generic v1 API
|
||||
@ -1039,7 +1044,7 @@ class ChatSambaStudio(BaseChatModel):
|
||||
params = {**params, **self.model_kwargs}
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
data = {"items": items, "params": params}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value()}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
|
||||
|
||||
# create request payload for generic v1 API
|
||||
elif "api/predict/generic" in self.sambastudio_url:
|
||||
@ -1075,7 +1080,7 @@ class ChatSambaStudio(BaseChatModel):
|
||||
"instances": [self._messages_to_string(messages)],
|
||||
"params": params,
|
||||
}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value()}
|
||||
headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user