mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
openai[patch]: fix get_num_tokens for function calls (#25785)
Closes https://github.com/langchain-ai/langchain/issues/25784
See additional discussion
[here](0a4ee864e9 (r145147380)
).
This commit is contained in:
parent
2aa35d80a0
commit
2e5c379632
@ -947,7 +947,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
else:
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
num_tokens += len(encoding.encode(value))
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
|
@ -677,7 +677,10 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
AIMessage(
|
||||
"",
|
||||
additional_kwargs={
|
||||
"function_call": json.dumps({"arguments": "old", "name": "fun"})
|
||||
"function_call": {
|
||||
"arguments": json.dumps({"arg1": "arg1"}),
|
||||
"name": "fun",
|
||||
}
|
||||
},
|
||||
),
|
||||
AIMessage(
|
||||
@ -688,6 +691,6 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
),
|
||||
ToolMessage("foobar", tool_call_id="foo"),
|
||||
]
|
||||
expected = 170
|
||||
expected = 176
|
||||
actual = llm.get_num_tokens_from_messages(messages)
|
||||
assert expected == actual
|
||||
|
Loading…
Reference in New Issue
Block a user