diff --git a/docs/docs/integrations/chat/sambastudio.ipynb b/docs/docs/integrations/chat/sambastudio.ipynb index e719354038..dcf1c450c5 100644 --- a/docs/docs/integrations/chat/sambastudio.ipynb +++ b/docs/docs/integrations/chat/sambastudio.ipynb @@ -386,16 +386,17 @@ " return f\"Current date: {date}, Current time: {time}\"\n", "\n", "\n", + "tools = [get_time]\n", + "\n", + "\n", "def invoke_tools(tool_calls, messages):\n", + " available_functions = {tool.name: tool for tool in tools}\n", " for tool_call in tool_calls:\n", - " selected_tool = {\"get_time\": get_time}[tool_call[\"name\"].lower()]\n", + " selected_tool = available_functions[tool_call[\"name\"]]\n", " tool_output = selected_tool.invoke(tool_call[\"args\"])\n", " print(f\"Tool output: {tool_output}\")\n", " messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n", - " return messages\n", - "\n", - "\n", - "tools = [get_time]" + " return messages" ] }, { @@ -429,11 +430,11 @@ ], "source": [ "response = llm_with_tools.invoke(messages)\n", - "if response.tool_calls:\n", + "while len(response.tool_calls) > 0:\n", " print(f\"Intermediate model response: {response.tool_calls}\")\n", " messages.append(response)\n", " messages = invoke_tools(response.tool_calls, messages)\n", - "response = llm.invoke(messages)\n", + "response = llm_with_tools.invoke(messages)\n", "\n", "print(f\"final response: {response.content}\")" ] diff --git a/libs/community/langchain_community/chat_models/sambanova.py b/libs/community/langchain_community/chat_models/sambanova.py index 0227a6a6e2..0ffabb9e85 100644 --- a/libs/community/langchain_community/chat_models/sambanova.py +++ b/libs/community/langchain_community/chat_models/sambanova.py @@ -796,6 +796,9 @@ class ChatSambaStudio(BaseChatModel): model_kwargs: Optional[Dict[str, Any]] = None """Key word arguments to pass to the model.""" + + additional_headers: Dict[str, Any] = Field(default={}) + """Additional headers to send in request""" class Config: populate_by_name = True @@ -942,6 +945,7 @@ class ChatSambaStudio(BaseChatModel): "content": message.content, } ) + #TODO add tools msgs id and assistant msgs tool calls messages_string = json.dumps(messages_dict) else: messages_string = self.special_tokens["start"] @@ -1020,6 +1024,7 @@ class ChatSambaStudio(BaseChatModel): "Authorization": f"Bearer " f"{self.sambastudio_api_key.get_secret_value()}", "Content-Type": "application/json", + **self.additional_headers } # create request payload for generic v1 API @@ -1039,7 +1044,7 @@ class ChatSambaStudio(BaseChatModel): params = {**params, **self.model_kwargs} params = {key: value for key, value in params.items() if value is not None} data = {"items": items, "params": params} - headers = {"key": self.sambastudio_api_key.get_secret_value()} + headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers} # create request payload for generic v1 API elif "api/predict/generic" in self.sambastudio_url: @@ -1075,7 +1080,7 @@ class ChatSambaStudio(BaseChatModel): "instances": [self._messages_to_string(messages)], "params": params, } - headers = {"key": self.sambastudio_api_key.get_secret_value()} + headers = {"key": self.sambastudio_api_key.get_secret_value(), **self.additional_headers} else: raise ValueError(