AnthropicFunctions function_call compatibility (#13901)

- **Description:** Updates to `AnthropicFunctions` to be compatible with
the OpenAI `function_call` functionality.
- **Issue:** The functionality to indicate `auto`, `none` and a forced
function_call was not completely implemented in the existing code.
  - **Dependencies:** None
- **Tag maintainer:** @baskaryan , and any of the other maintainers if
needed.
  - **Twitter handle:** None

I have specifically tested this functionality via AWS Bedrock with the
Claude-2 and Claude-Instant models.
pull/13986/head
Johannes Foulds 10 months ago committed by GitHub
parent 14cc907d35
commit fc40bd4cdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -144,19 +144,30 @@ class AnthropicFunctions(BaseChatModel):
forced = False forced = False
function_call = "" function_call = ""
if "functions" in kwargs: if "functions" in kwargs:
content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2)) # get the function call method
system = SystemMessage(content=content) if "function_call" in kwargs:
messages = [system] + messages function_call = kwargs["function_call"]
del kwargs["function_call"]
else:
function_call = "auto"
# should function calling be used
if function_call != "none":
content = prompt.format(tools=json.dumps(kwargs["functions"], indent=2))
system = SystemMessage(content=content)
messages = [system] + messages
# is the function call a dictionary (forced function calling)
if isinstance(function_call, dict):
forced = True
function_call_name = function_call["name"]
messages.append(AIMessage(content=f"<tool>{function_call_name}</tool>"))
del kwargs["functions"] del kwargs["functions"]
if stop is None: if stop is None:
stop = ["</tool_input>"] stop = ["</tool_input>"]
else: else:
stop.append("</tool_input>") stop.append("</tool_input>")
if "function_call" in kwargs:
forced = True
function_call = kwargs["function_call"]["name"]
AIMessage(content=f"<tool>{function_call}</tool>")
del kwargs["function_call"]
else: else:
if "function_call" in kwargs: if "function_call" in kwargs:
raise ValueError( raise ValueError(
@ -168,12 +179,19 @@ class AnthropicFunctions(BaseChatModel):
completion = response.content completion = response.content
if forced: if forced:
tag_parser = TagParser() tag_parser = TagParser()
tag_parser.feed(completion.strip() + "</tool_input>")
v1 = tag_parser.parse_data["tool_input"][0] if "<tool_input>" in completion:
tag_parser.feed(completion.strip() + "</tool_input>")
v1 = tag_parser.parse_data["tool_input"][0]
arguments = json.dumps(_destrip(v1))
else:
v1 = completion
arguments = ""
kwargs = { kwargs = {
"function_call": { "function_call": {
"name": function_call, "name": function_call_name,
"arguments": json.dumps(_destrip(v1)), "arguments": arguments,
} }
} }
message = AIMessage(content="", additional_kwargs=kwargs) message = AIMessage(content="", additional_kwargs=kwargs)
@ -181,7 +199,7 @@ class AnthropicFunctions(BaseChatModel):
elif "<tool>" in completion: elif "<tool>" in completion:
tag_parser = TagParser() tag_parser = TagParser()
tag_parser.feed(completion.strip() + "</tool_input>") tag_parser.feed(completion.strip() + "</tool_input>")
msg = completion.split("<tool>")[0] msg = completion.split("<tool>")[0].strip()
v1 = tag_parser.parse_data["tool_input"][0] v1 = tag_parser.parse_data["tool_input"][0]
kwargs = { kwargs = {
"function_call": { "function_call": {
@ -192,6 +210,7 @@ class AnthropicFunctions(BaseChatModel):
message = AIMessage(content=msg, additional_kwargs=kwargs) message = AIMessage(content=msg, additional_kwargs=kwargs)
return ChatResult(generations=[ChatGeneration(message=message)]) return ChatResult(generations=[ChatGeneration(message=message)])
else: else:
response.content = response.content.strip()
return ChatResult(generations=[ChatGeneration(message=response)]) return ChatResult(generations=[ChatGeneration(message=response)])
@property @property

Loading…
Cancel
Save