From 5b5115c408e9e8d622b74d1391b29c1d840b709c Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 26 Jan 2024 09:45:34 -0800
Subject: [PATCH] google-vertexai[patch]: streaming bug (#16603)

Fixes the errors seen here:
https://github.com/langchain-ai/langchain/actions/runs/7661680517/job/20881556592#step:9:229
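
For context, a rough sketch of the underlying failure (not part of the diff;
it assumes GenerationChunk addition merges generation_info via
langchain_core's merge_dicts, which is what the linked CI failure indicates):
when an LLM streams, the per-chunk generation_info dicts get merged, and a
repeated bool key like is_blocked cannot be merged:

    # Hypothetical repro of the linked CI failure.
    from langchain_core.outputs import GenerationChunk

    a = GenerationChunk(text="foo", generation_info={"is_blocked": False})
    b = GenerationChunk(text="bar", generation_info={"is_blocked": False})
    # Raises: TypeError: Additional kwargs key is_blocked already exists in
    # left dict and value has unsupported type <class 'bool'>.
    merged = a + b

get_generation_info therefore gains a stream flag and drops the
non-mergeable key from per-chunk info.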
---
 .../langchain_google_vertexai/_utils.py       | 19 ++++++++++++-----
 .../langchain_google_vertexai/llms.py         | 15 ++++++++++---
 .../tests/integration_tests/test_llms.py      | 21 ++++++++++++++++---
 3 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py b/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
index 4fe052a56b..4f0bdf3723 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
 
 
 def get_generation_info(
-    candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
+    candidate: Union[TextGenerationResponse, Candidate],
+    is_gemini: bool,
+    *,
+    stream: bool = False,
 ) -> Dict[str, Any]:
     if is_gemini:
         # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
-        return {
+        info = {
             "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
             "safety_ratings": [
                 {
                     "category": rating.category.name,
                     "probability_label": rating.probability.name,
+                    "blocked": rating.blocked,
                 }
                 for rating in candidate.safety_ratings
             ],
             "citation_metadata": candidate.citation_metadata,
         }
     # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
-    candidate_dc = dataclasses.asdict(candidate)
-    candidate_dc.pop("text")
-    return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
+    else:
+        info = dataclasses.asdict(candidate)
+        info.pop("text")
+        info = {k: v for k, v in info.items() if not k.startswith("_")}
+    if stream:
+        # Remove non-streamable types, like bools.
+        info.pop("is_blocked")
+    return info
diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
index b4274c2488..bd4d346a01 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens
 
     def _response_to_generation(
-        self, response: TextGenerationResponse
+        self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
         """Converts a stream response to a generation chunk."""
-        generation_info = get_generation_info(response, self._is_gemini_model)
+        generation_info = get_generation_info(
+            response, self._is_gemini_model, stream=stream
+        )
         try:
             text = response.text
         except AttributeError:
@@ -401,7 +403,14 @@
             run_manager=run_manager,
             **params,
         ):
-            chunk = self._response_to_generation(stream_resp)
+            # Gemini models return GenerationResponse even when streaming, which
+            # has a candidates field.
+            stream_resp = (
+                stream_resp
+                if isinstance(stream_resp, TextGenerationResponse)
+                else stream_resp.candidates[0]
+            )
+            chunk = self._response_to_generation(stream_resp, stream=True)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_llms.py b/libs/partners/google-vertexai/tests/integration_tests/test_llms.py
index 823c8671dc..ae10d9377c 100644
--- a/libs/partners/google-vertexai/tests/integration_tests/test_llms.py
+++ b/libs/partners/google-vertexai/tests/integration_tests/test_llms.py
@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
     "model_name",
     model_names_to_test_with_default,
 )
-def test_vertex_call(model_name: str) -> None:
+def test_vertex_invoke(model_name: str) -> None:
     llm = (
         VertexAI(model_name=model_name, temperature=0)
         if model_name
         else VertexAI(temperature=0.0)
     )
-    output = llm("Say foo:")
+    output = llm.invoke("Say foo:")
     assert isinstance(output, str)
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_generate(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+
+
 @pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
-def test_vertex_generate() -> None:
+def test_vertex_generate_multiple_candidates() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
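
Reviewer note (not part of the patch): a minimal manual check of the fixed
streaming path, assuming Vertex AI credentials are configured; "gemini-pro"
stands in here for any Gemini model name:

    from langchain_google_vertexai import VertexAI

    llm = VertexAI(model_name="gemini-pro", temperature=0)
    # Before this patch, streaming a Gemini model raised while merging
    # per-chunk generation_info; it should now yield text chunks cleanly.
    for chunk in llm.stream("Say foo:"):
        print(chunk, end="")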