From 3df2d831f93ae075c5cc38549ebc42d2ed0286bb Mon Sep 17 00:00:00 2001 From: Jari Bakken Date: Fri, 19 May 2023 01:32:27 +0200 Subject: [PATCH] Fix get_num_tokens for Anthropic models (#4911) The Anthropic classes used `BaseLanguageModel.get_num_tokens` because of an issue with multiple inheritance. Fixed by moving the method from `_AnthropicCommon` to both its subclasses. This change will significantly speed up token counting for Anthropic users. --- langchain/chat_models/anthropic.py | 6 ++++++ langchain/llms/anthropic.py | 12 ++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py index daed935b..3ac59507 100644 --- a/langchain/chat_models/anthropic.py +++ b/langchain/chat_models/anthropic.py @@ -141,3 +141,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): completion = response["completion"] message = AIMessage(content=completion) return ChatResult(generations=[ChatGeneration(message=message)]) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + if not self.count_tokens: + raise NameError("Please ensure the anthropic package is loaded") + return self.count_tokens(text) diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index b71fe682..5c2349f4 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -97,12 +97,6 @@ class _AnthropicCommon(BaseModel): return stop - def get_num_tokens(self, text: str) -> int: - """Calculate number of tokens.""" - if not self.count_tokens: - raise NameError("Please ensure the anthropic package is loaded") - return self.count_tokens(text) - class Anthropic(LLM, _AnthropicCommon): r"""Wrapper around Anthropic's large language models. @@ -263,3 +257,9 @@ class Anthropic(LLM, _AnthropicCommon): stop_sequences=stop, **self._default_params, ) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + if not self.count_tokens: + raise NameError("Please ensure the anthropic package is loaded") + return self.count_tokens(text)