langchain_google_vertexai: added logic to override get_num_tokens_from_messages() for ChatVertexAI (#16784)

- **Description:** Added logic to override `get_num_tokens_from_messages()` for `ChatVertexAI`. Previously `ChatVertexAI` inherited `get_num_tokens_from_messages()` from `BaseChatModel`, which in turn fell back to a GPT-2 tokenizer instead of the Vertex AI one. The fix overrides `get_num_tokens()` on the shared `_VertexAICommon` base (via a new `client_preview` handle) so that both `VertexAI` and `ChatVertexAI` count tokens through the service's `count_tokens()` API; see the sketch below for why that is sufficient.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** @aditya_rane
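
Why overriding `get_num_tokens()` is enough, as a rough sketch (a simplified stand-in for `BaseChatModel`'s default behavior, not the exact langchain-core source):

```python
# Simplified stand-in, not the exact langchain-core source: the default
# get_num_tokens_from_messages() flattens each message to text and
# delegates to get_num_tokens(), which by default uses a GPT-2 tokenizer.
# Overriding get_num_tokens() on _VertexAICommon therefore reroutes the
# chat path through Vertex AI's tokenizer as well.
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string


class BaseChatModelSketch:
    def get_num_tokens(self, text: str) -> int:
        raise NotImplementedError  # subclasses override; default uses GPT-2

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Flatten each message and delegate; a subclass override of
        # get_num_tokens() is picked up automatically here.
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
```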

@lkuligin for review

---------

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>

@@ -45,6 +45,12 @@ from vertexai.preview.generative_models import (  # type: ignore
     Image,
     Part,
 )
+from vertexai.preview.language_models import (  # type: ignore
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
 from langchain_google_vertexai._utils import (
     get_generation_info,
@@ -316,12 +322,20 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = GenerativeModel(
                 model_name=values["model_name"], safety_settings=safety_settings
             )
+            values["client_preview"] = GenerativeModel(
+                model_name=values["model_name"], safety_settings=safety_settings
+            )
         else:
             if is_codey_model(values["model_name"]):
                 model_cls = CodeChatModel
+                model_cls_preview = PreviewCodeChatModel
             else:
                 model_cls = ChatModel
+                model_cls_preview = PreviewChatModel
             values["client"] = model_cls.from_pretrained(values["model_name"])
+            values["client_preview"] = model_cls_preview.from_pretrained(
+                values["model_name"]
+            )
         return values

     def _generate(
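
A design note on the hunk above: the extra `client_preview` handle exists because token counting is exposed on the preview model classes, which `get_num_tokens()` (added further down) relies on. A hypothetical standalone illustration, assuming configured Vertex AI credentials and that `chat-bison` is an available PALM chat model:

```python
# Hypothetical illustration; assumes Vertex AI credentials are configured
# and "chat-bison" is available. The preview classes expose count_tokens(),
# which is why the validator above builds a client_preview alongside the
# serving client.
from vertexai.preview.language_models import ChatModel as PreviewChatModel

preview_client = PreviewChatModel.from_pretrained("chat-bison")
result = preview_client.start_chat().count_tokens("Hello")
print(result.total_tokens)
```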

@@ -31,6 +31,12 @@ from vertexai.preview.generative_models import (  # type: ignore[import-untyped]
     Image,
 )
+from vertexai.preview.language_models import (  # type: ignore[import-untyped]
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
 from vertexai.preview.language_models import (
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
 from vertexai.preview.language_models import (
@@ -239,6 +245,27 @@ class _VertexAICommon(_VertexAIBase):
             params.pop("candidate_count")
         return params

+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        is_palm_chat_model = isinstance(
+            self.client_preview, PreviewChatModel
+        ) or isinstance(self.client_preview, PreviewCodeChatModel)
+        if is_palm_chat_model:
+            result = self.client_preview.start_chat().count_tokens(text)
+        else:
+            result = self.client_preview.count_tokens([text])
+        return result.total_tokens
+

 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
@@ -300,20 +327,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
             raise ValueError("Only one candidate can be generated with streaming!")
         return values

-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        result = self.client_preview.count_tokens([text])
-        return result.total_tokens
-
     def _response_to_generation(
         self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:

@@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
     assert isinstance(response.content, str)


+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_get_num_tokens_from_messages(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name, temperature=0.0)
+    else:
+        model = ChatVertexAI(temperature=0.0)
+    message = HumanMessage(content="Hello")
+    token = model.get_num_tokens_from_messages(messages=[message])
+    assert isinstance(token, int)
+    assert token == 3
+
+
 def test_chat_vertexai_gemini_function_calling() -> None:
     class MyModel(BaseModel):
         name: str
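
The new integration test above pins the token count for "Hello" at 3. Calling the override directly looks like this (a sketch mirroring the test; assumes configured credentials and the package's default model):

```python
# Sketch mirroring the new test; assumes configured Vertex AI credentials.
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

model = ChatVertexAI(temperature=0.0)
tokens = model.get_num_tokens_from_messages(messages=[HumanMessage(content="Hello")])
assert isinstance(tokens, int)
```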

@@ -81,7 +81,6 @@ def test_tools() -> None:
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
     response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
-    print(response)
     assert isinstance(response, dict)
     assert response["input"] == "What is 6 raised to the 0.43 power?"
@@ -106,7 +105,6 @@ def test_stream() -> None:
     ]
     response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
     assert len(response) == 1
-    # for chunk in response:
     assert isinstance(response[0], AIMessageChunk)
     assert "function_call" in response[0].additional_kwargs
