google-vertexai[patch]: streaming bug (#16603)

Fixes the streaming errors seen in this CI run:
https://github.com/langchain-ai/langchain/actions/runs/7661680517/job/20881556592#step:9:229
Bagatur committed 5 months ago (via GitHub)
parent a989f82027
commit 5b5115c408
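For context, a minimal sketch of the streaming path this patch touches (assumes the langchain-google-vertexai package is installed and a Google Cloud project is configured; the model name is only illustrative):

```python
from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="gemini-pro", temperature=0)

# Before this patch, streaming could fail while chunks were merged, because each
# chunk's generation_info carried values (e.g. a bool) with no merge rule, and
# Gemini responses were not unwrapped to their first candidate (see the diff below).
for chunk in llm.stream("Say foo:"):
    print(chunk, end="", flush=True)
```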

@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
 def get_generation_info(
-    candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
+    candidate: Union[TextGenerationResponse, Candidate],
+    is_gemini: bool,
+    *,
+    stream: bool = False,
 ) -> Dict[str, Any]:
     if is_gemini:
         # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
-        return {
+        info = {
             "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
             "safety_ratings": [
                 {
                     "category": rating.category.name,
                     "probability_label": rating.probability.name,
                     "blocked": rating.blocked,
                 }
                 for rating in candidate.safety_ratings
             ],
             "citation_metadata": candidate.citation_metadata,
         }
     # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
-    candidate_dc = dataclasses.asdict(candidate)
-    candidate_dc.pop("text")
-    return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
+    else:
+        info = dataclasses.asdict(candidate)
+        info.pop("text")
+        info = {k: v for k, v in info.items() if not k.startswith("_")}
+    if stream:
+        # Remove non-streamable types, like bools.
+        info.pop("is_blocked")
+    return info
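To motivate the `stream=True` branch above: during streaming, each chunk's generation_info is merged into the previous chunks', and a bool like `is_blocked` has no sensible merge rule. The following is a rough standalone sketch of that constraint, not the library's actual merge code:

```python
from typing import Any, Dict


def merge_generation_info(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Toy merge of two streamed chunks' generation_info dicts."""
    merged = dict(left)
    for key, value in right.items():
        if key not in merged:
            merged[key] = value
        elif isinstance(merged[key], str):
            merged[key] += value  # strings concatenate across chunks
        elif isinstance(merged[key], list):
            merged[key] = merged[key] + value  # lists concatenate across chunks
        else:
            # e.g. the bool `is_blocked`: two per-chunk values cannot be combined,
            # which is why it is stripped from streamed generation_info above.
            raise TypeError(f"cannot merge key {key!r} of type {type(value).__name__}")
    return merged


merge_generation_info({"is_blocked": False}, {"is_blocked": False})  # raises TypeError
```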

@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens
 
     def _response_to_generation(
-        self, response: TextGenerationResponse
+        self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
         """Converts a stream response to a generation chunk."""
-        generation_info = get_generation_info(response, self._is_gemini_model)
+        generation_info = get_generation_info(
+            response, self._is_gemini_model, stream=stream
+        )
         try:
             text = response.text
         except AttributeError:
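For reference, the GenerationChunk objects returned here support `+`, which is how streamed pieces are typically accumulated by callers (a small sketch using langchain-core types; generation_info is omitted):

```python
from langchain_core.outputs import GenerationChunk

chunks = [GenerationChunk(text="Hello, "), GenerationChunk(text="world!")]
full = chunks[0]
for chunk in chunks[1:]:
    full += chunk  # concatenates text (and merges generation_info, when present)
assert full.text == "Hello, world!"
```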
@@ -401,7 +403,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
             run_manager=run_manager,
             **params,
         ):
-            chunk = self._response_to_generation(stream_resp)
+            # Gemini models return GenerationResponse even when streaming, which has a
+            # candidates field.
+            stream_resp = (
+                stream_resp
+                if isinstance(stream_resp, TextGenerationResponse)
+                else stream_resp.candidates[0]
+            )
+            chunk = self._response_to_generation(stream_resp, stream=True)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
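The isinstance expression above normalizes the two shapes the SDK can yield while streaming. A hypothetical standalone helper expressing the same idea (the name and function are illustrative, not part of this diff):

```python
def first_candidate(stream_resp):
    """PaLM text models stream TextGenerationResponse objects directly, while
    Gemini streams a response whose `candidates` list holds the per-candidate
    data; return whichever object carries the generation fields."""
    candidates = getattr(stream_resp, "candidates", None)
    return candidates[0] if candidates else stream_resp
```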

@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
     "model_name",
     model_names_to_test_with_default,
 )
-def test_vertex_call(model_name: str) -> None:
+def test_vertex_invoke(model_name: str) -> None:
     llm = (
         VertexAI(model_name=model_name, temperature=0)
         if model_name
         else VertexAI(temperature=0.0)
     )
-    output = llm("Say foo:")
+    output = llm.invoke("Say foo:")
     assert isinstance(output, str)
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_generate(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+
+
 @pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
-def test_vertex_generate() -> None:
+def test_vertex_generate_multiple_candidates() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
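A companion streaming check in the same style as the tests above (a sketch only, not part of this diff; it assumes the same module-level imports, fixtures, and credentials as the existing integration tests):

```python
@pytest.mark.parametrize(
    "model_name",
    model_names_to_test_with_default,
)
def test_vertex_stream(model_name: str) -> None:
    llm = (
        VertexAI(model_name=model_name, temperature=0)
        if model_name
        else VertexAI(temperature=0.0)
    )
    chunks = list(llm.stream("Say foo:"))
    assert chunks
    assert all(isinstance(chunk, str) for chunk in chunks)
```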
