implement get_num_tokens to use google's count_tokens function (#10565)

can get the correct token count instead of using gpt-2 model

**Description:** 
Implement get_num_tokens within VertexLLM to use google's count_tokens
function.
(https://cloud.google.com/vertex-ai/docs/generative-ai/get-token-count).
So we don't need to download gpt-2 model from huggingface, also when we
do the mapreduce chain we can get correct token count.

**Tag maintainer:** 
@lkuligin 
**Twitter handle:** 
My twitter: @abehsu1992626

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12531/head^2
hsuyuming 9 months ago committed by GitHub
parent 33e77a1007
commit 630ae24b28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -276,6 +276,27 @@ class VertexAI(_VertexAICommon, BaseLLM):
raise ValueError("Only one candidate can be generated with streaming!")
return values
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
try:
result = self.client.count_tokens(text)
except AttributeError:
raise NotImplementedError(
"Your google-cloud-aiplatform version didn't implement count_tokens."
"Please, install it with pip install google-cloud-aiplatform>=1.35.0"
)
return result.total_tokens
def _generate(
self,
prompts: List[str],

@ -2,7 +2,7 @@
In order to run this test, you need to install VertexAI SDK (that is is the private
preview) and be whitelisted to list the models themselves:
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
pip install google-cloud-aiplatform>=1.35.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).

@ -1,6 +1,6 @@
"""Test Vertex AI API wrapper.
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
pip install google-cloud-aiplatform>=1.35.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).

@ -2,7 +2,7 @@
In order to run this test, you need to install VertexAI SDK (that is is the private
preview) and be whitelisted to list the models themselves:
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.25.0
pip install google-cloud-aiplatform>=1.35.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
@ -10,7 +10,10 @@ Your end-user credentials would be used to make the calls (make sure you've run
import os
import pytest
from pytest_mock import MockerFixture
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms import VertexAI, VertexAIModelGarden
from langchain.schema import LLMResult
@ -108,3 +111,37 @@ async def test_model_garden_agenerate() -> None:
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
def test_vertex_call_trigger_count_tokens() -> None:
llm = VertexAI()
output = llm.get_num_tokens("Hi")
assert output == 2
@pytest.mark.requires("google.cloud.aiplatform")
def test_get_num_tokens_be_called_when_using_mapreduce_chain(
mocker: MockerFixture,
) -> None:
from vertexai.language_models._language_models import CountTokensResponse
m1 = mocker.patch(
"vertexai.preview.language_models._PreviewTextGenerationModel.count_tokens",
return_value=CountTokensResponse(
total_tokens=2,
total_billable_characters=2,
_count_tokens_response={"total_tokens": 2, "total_billable_characters": 2},
),
)
llm = VertexAI()
chain = load_summarize_chain(
llm,
chain_type="map_reduce",
return_intermediate_steps=False,
)
doc = Document(page_content="Hi")
output = chain({"input_documents": [doc]})
assert isinstance(output["output_text"], str)
m1.assert_called_once()
assert llm._llm_type == "vertexai"
assert llm.model_name == llm.client._model_id

Loading…
Cancel
Save