Fix get_num_tokens for Anthropic models (#4911)

The Anthropic classes used `BaseLanguageModel.get_num_tokens` because of
an issue with multiple inheritance. Fixed by moving the method from
`_AnthropicCommon` to both its subclasses.

This change will significantly speed up token counting for Anthropic
users.
This commit is contained in:
Jari Bakken 2023-05-19 01:32:27 +02:00 committed by GitHub
parent c8c2276ccb
commit 3df2d831f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 6 deletions

View File

@ -141,3 +141,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
completion = response["completion"]
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)

View File

@ -97,12 +97,6 @@ class _AnthropicCommon(BaseModel):
return stop
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
class Anthropic(LLM, _AnthropicCommon):
r"""Wrapper around Anthropic's large language models.
@ -263,3 +257,9 @@ class Anthropic(LLM, _AnthropicCommon):
stop_sequences=stop,
**self._default_params,
)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)