From 630ae24b282b5e1d5d60b4150bef0808062b3a14 Mon Sep 17 00:00:00 2001 From: hsuyuming Date: Tue, 31 Oct 2023 07:10:05 +0900 Subject: [PATCH] implement get_num_tokens to use google's count_tokens function (#10565) can get the correct token count instead of using gpt-2 model **Description:** Implement get_num_tokens within VertexLLM to use google's count_tokens function. (https://cloud.google.com/vertex-ai/docs/generative-ai/get-token-count). So we don't need to download gpt-2 model from huggingface, also when we do the mapreduce chain we can get correct token count. **Tag maintainer:** @lkuligin **Twitter handle:** My twitter: @abehsu1992626 --------- Co-authored-by: Bagatur --- libs/langchain/langchain/llms/vertexai.py | 21 ++++++++++ .../chat_models/test_vertexai.py | 2 +- .../embeddings/test_vertexai.py | 2 +- .../integration_tests/llms/test_vertexai.py | 39 ++++++++++++++++++- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py index 70a6ab663f..f2f95035ac 100644 --- a/libs/langchain/langchain/llms/vertexai.py +++ b/libs/langchain/langchain/llms/vertexai.py @@ -276,6 +276,27 @@ class VertexAI(_VertexAICommon, BaseLLM): raise ValueError("Only one candidate can be generated with streaming!") return values + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text. + + Useful for checking if an input will fit in a model's context window. + + Args: + text: The string input to tokenize. + + Returns: + The integer number of tokens in the text. + """ + try: + result = self.client.count_tokens(text) + except AttributeError: + raise NotImplementedError( + "Your google-cloud-aiplatform version didn't implement count_tokens." + "Please, install it with pip install google-cloud-aiplatform>=1.35.0" + ) + + return result.total_tokens + def _generate( self, prompts: List[str], diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py index 17a7c0ac72..b0bc62ffa3 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py @@ -2,7 +2,7 @@ In order to run this test, you need to install VertexAI SDK (that is is the private preview) and be whitelisted to list the models themselves: In order to run this test, you need to install VertexAI SDK -pip install google-cloud-aiplatform>=1.25.0 +pip install google-cloud-aiplatform>=1.35.0 Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). diff --git a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py b/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py index e59547af8b..bdfe1d24b2 100644 --- a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py @@ -1,6 +1,6 @@ """Test Vertex AI API wrapper. In order to run this test, you need to install VertexAI SDK -pip install google-cloud-aiplatform>=1.25.0 +pip install google-cloud-aiplatform>=1.35.0 Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py index 20cfb6d85b..85e8ca7a26 100644 --- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py @@ -2,7 +2,7 @@ In order to run this test, you need to install VertexAI SDK (that is is the private preview) and be whitelisted to list the models themselves: In order to run this test, you need to install VertexAI SDK -pip install google-cloud-aiplatform>=1.25.0 +pip install google-cloud-aiplatform>=1.35.0 Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). @@ -10,7 +10,10 @@ Your end-user credentials would be used to make the calls (make sure you've run import os import pytest +from pytest_mock import MockerFixture +from langchain.chains.summarize import load_summarize_chain +from langchain.docstore.document import Document from langchain.llms import VertexAI, VertexAIModelGarden from langchain.schema import LLMResult @@ -108,3 +111,37 @@ async def test_model_garden_agenerate() -> None: output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"]) assert isinstance(output, LLMResult) assert len(output.generations) == 2 + + +def test_vertex_call_trigger_count_tokens() -> None: + llm = VertexAI() + output = llm.get_num_tokens("Hi") + assert output == 2 + + +@pytest.mark.requires("google.cloud.aiplatform") +def test_get_num_tokens_be_called_when_using_mapreduce_chain( + mocker: MockerFixture, +) -> None: + from vertexai.language_models._language_models import CountTokensResponse + + m1 = mocker.patch( + "vertexai.preview.language_models._PreviewTextGenerationModel.count_tokens", + return_value=CountTokensResponse( + total_tokens=2, + total_billable_characters=2, + _count_tokens_response={"total_tokens": 2, "total_billable_characters": 2}, + ), + ) + llm = VertexAI() + chain = load_summarize_chain( + llm, + chain_type="map_reduce", + return_intermediate_steps=False, + ) + doc = Document(page_content="Hi") + output = chain({"input_documents": [doc]}) + assert isinstance(output["output_text"], str) + m1.assert_called_once() + assert llm._llm_type == "vertexai" + assert llm.model_name == llm.client._model_id