Use correct tokenizer for Bedrock/Anthropic LLMs (#11561)

**Description**

This PR makes Bedrock LLMs use the correct tokenizer when the underlying
model is an Anthropic model.

**Issue:** #11560

**Dependencies:** optional dependency on the `anthropic` Python library.

**Twitter handle:** jtolgyesi


---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11718/head
Janos Tolgyesi 12 months ago committed by GitHub
parent 467b082c34
commit 15687a28d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,10 @@ from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain.utilities.anthropic import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
class ChatPromptAdapter:
@ -86,3 +90,15 @@ class BedrockChat(BaseChatModel, BedrockBase):
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
    """Count the tokens in *text*.

    Uses Anthropic's own tokenizer when the underlying Bedrock model is an
    Anthropic model; otherwise defers to the base-class approximation.
    """
    if not self._model_is_anthropic:
        return super().get_num_tokens(text)
    return get_num_tokens_anthropic(text)
def get_token_ids(self, text: str) -> List[int]:
    """Return the token ids of *text*.

    Uses Anthropic's own tokenizer when the underlying Bedrock model is an
    Anthropic model; otherwise defers to the base-class implementation.
    """
    if not self._model_is_anthropic:
        return super().get_token_ids(text)
    return get_token_ids_anthropic(text)

@ -8,6 +8,10 @@ from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utilities.anthropic import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
@ -222,6 +226,10 @@ class BedrockBase(BaseModel, ABC):
def _get_provider(self) -> str:
    """Return the provider prefix of ``model_id`` (the text before the first dot)."""
    provider, _, _ = self.model_id.partition(".")
    return provider
@property
def _model_is_anthropic(self) -> bool:
    """True when the configured model id belongs to the ``anthropic`` provider."""
    provider = self._get_provider()
    return provider == "anthropic"
def _prepare_input_and_invoke(
self,
prompt: str,
@ -318,7 +326,7 @@ class Bedrock(LLM, BedrockBase):
from bedrock_langchain.bedrock_llm import BedrockLLM
llm = BedrockLLM(
credentials_profile_name="default",
credentials_profile_name="default",
model_id="amazon.titan-text-express-v1",
streaming=True
)
@ -393,3 +401,15 @@ class Bedrock(LLM, BedrockBase):
return completion
return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
def get_num_tokens(self, text: str) -> int:
    """Count the tokens in *text*.

    Anthropic-provider models are tokenized with Anthropic's tokenizer;
    every other provider falls back to the base-class count.
    """
    if self._model_is_anthropic:
        return get_num_tokens_anthropic(text)
    return super().get_num_tokens(text)
def get_token_ids(self, text: str) -> List[int]:
    """Return the token ids of *text*.

    Anthropic-provider models are tokenized with Anthropic's tokenizer;
    every other provider falls back to the base-class implementation.
    """
    if self._model_is_anthropic:
        return get_token_ids_anthropic(text)
    return super().get_token_ids(text)

@ -0,0 +1,25 @@
from typing import Any, List
def _get_anthropic_client() -> Any:
    """Return a fresh ``anthropic.Anthropic`` client used for tokenization.

    The ``anthropic`` package is an optional dependency, imported lazily so
    that non-Anthropic Bedrock usage never requires it.

    Raises:
        ImportError: If the ``anthropic`` package is not installed.
    """
    try:
        import anthropic
    except ImportError as e:
        # Chain the original error (`from e`) so the underlying import
        # failure is preserved in the traceback instead of being swallowed.
        raise ImportError(
            "Could not import anthropic python package. "
            "This is needed in order to accurately tokenize the text "
            "for anthropic models. Please install it with `pip install anthropic`."
        ) from e
    return anthropic.Anthropic()
def get_num_tokens_anthropic(text: str) -> int:
    """Return the number of tokens in *text* per Anthropic's tokenizer.

    Requires the optional ``anthropic`` package to be installed.
    """
    return _get_anthropic_client().count_tokens(text=text)
def get_token_ids_anthropic(text: str) -> List[int]:
    """Return the Anthropic token ids for *text*.

    Requires the optional ``anthropic`` package to be installed.
    """
    tokenizer = _get_anthropic_client().get_tokenizer()
    return tokenizer.encode(text).ids
Loading…
Cancel
Save