added candidate_count for Vertex models (#11729)

- **Description:** added support for the `candidate_count` parameter on Vertex models
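
For reference, a minimal usage sketch of the new behavior, based on the tests in this diff (model names and values are illustrative; running it requires Google Cloud credentials and `google-cloud-aiplatform >= 1.35.0`):

```python
# Minimal sketch: the LangChain-standard `n` maps to Vertex's
# `candidate_count`, yielding n generations per prompt.
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.schema.messages import HumanMessage

llm = VertexAI(model_name="text-bison@001", temperature=0.3, n=2)
result = llm.generate(["Say foo:"])
assert len(result.generations[0]) == 2  # two candidates for one prompt

chat = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
chat_result = chat.generate(messages=[[HumanMessage(content="Hello")]])
assert len(chat_result.generations[0]) == 2
```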
Leonid Kuligin committed via GitHub · commit 9f0a718198 (parent 9d200e6cbe)

langchain/chat_models/vertexai.py
@@ -12,10 +12,7 @@ from langchain.callbacks.manager import (
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.pydantic_v1 import root_validator
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@@ -177,16 +174,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
-params = self._prepare_params(stop=stop, **kwargs)
+params = self._prepare_params(stop=stop, stream=False, **kwargs)
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
-chat = self._start_chat(history, params)
-response = chat.send_message(question.content)
-return ChatResult(
-    generations=[ChatGeneration(message=AIMessage(content=response.text))]
-)
+msg_params = {}
+if "candidate_count" in params:
+    msg_params["candidate_count"] = params.pop("candidate_count")
+chat = self._start_chat(history, **params)
+response = chat.send_message(question.content, **msg_params)
+generations = [
+    ChatGeneration(message=AIMessage(content=r.text))
+    for r in response.candidates
+]
+return ChatResult(generations=generations)
async def _agenerate(
self,
@@ -219,11 +222,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
if examples:
params["examples"] = _parse_examples(examples)
-chat = self._start_chat(history, params)
-response = await chat.send_message_async(question.content)
-return ChatResult(
-    generations=[ChatGeneration(message=AIMessage(content=response.text))]
-)
+msg_params = {}
+if "candidate_count" in params:
+    msg_params["candidate_count"] = params.pop("candidate_count")
+chat = self._start_chat(history, **params)
+response = await chat.send_message_async(question.content, **msg_params)
+generations = [
+    ChatGeneration(message=AIMessage(content=r.text))
+    for r in response.candidates
+]
+return ChatResult(generations=generations)
def _stream(
self,
@@ -239,7 +247,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
if examples:
params["examples"] = _parse_examples(examples)
-chat = self._start_chat(history, params)
+chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
@@ -247,11 +255,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
def _start_chat(
-self, history: _ChatHistory, params: dict
+self, history: _ChatHistory, **kwargs: Any
) -> Union[ChatSession, CodeChatSession]:
if not self.is_codey_model:
return self.client.start_chat(
-    context=history.context, message_history=history.history, **params
+    context=history.context, message_history=history.history, **kwargs
)
else:
-return self.client.start_chat(message_history=history.history, **params)
+return self.client.start_chat(message_history=history.history, **kwargs)
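
In the Vertex SDK, `candidate_count` is a per-message option rather than a session option, which is why the code above pops it out of `params` before `_start_chat` and forwards it to `send_message`. A rough sketch of the underlying SDK usage this relies on (not part of the diff; assumes `google-cloud-aiplatform >= 1.35.0`, per the version bump below):

```python
# Sketch of the SDK behavior the wrapper builds on: session-level
# params go to start_chat, candidate_count goes to send_message.
from vertexai.language_models import ChatModel

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat()  # session-level params (context, history)
response = chat.send_message("Hello", candidate_count=2)  # per-call param
for candidate in response.candidates:  # one entry per requested candidate
    print(candidate.text)
```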

langchain/llms/vertexai.py
@@ -175,7 +175,10 @@ class _VertexAICommon(_VertexAIBase):
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
+n: int = 1
+"""How many completions to generate for each prompt."""
streaming: bool = False
"""Whether to stream the results or not."""
@property
def _llm_type(self) -> str:
@@ -203,6 +206,7 @@ class _VertexAICommon(_VertexAIBase):
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"candidate_count": self.n,
}
@classmethod
@@ -215,10 +219,16 @@ class _VertexAICommon(_VertexAIBase):
def _prepare_params(
self,
stop: Optional[List[str]] = None,
+stream: bool = False,
**kwargs: Any,
) -> dict:
stop_sequences = stop or self.stop
-return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+params_mapping = {"n": "candidate_count"}
+params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
+params = {**self._default_params, "stop_sequences": stop_sequences, **params}
+if stream or self.streaming:
+    params.pop("candidate_count")
+return params
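
A standalone walk-through of the new `_prepare_params` logic (the dict values below are hypothetical stand-ins for `self._default_params`):

```python
# Hypothetical walk-through of the mapping above: the caller-facing `n`
# is renamed to Vertex's `candidate_count`, caller kwargs override the
# defaults, and the key is dropped entirely when streaming (streaming
# only ever yields a single candidate).
defaults = {"temperature": 0.3, "candidate_count": 1}  # stand-in defaults
kwargs = {"n": 3}  # caller passes the LangChain-standard name

params_mapping = {"n": "candidate_count"}
params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
params = {**defaults, "stop_sequences": None, **params}
assert params["candidate_count"] == 3  # the caller's `n` wins

stream = True
if stream:
    params.pop("candidate_count")  # the streaming path drops it
```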
class VertexAI(_VertexAICommon, BaseLLM):
@@ -260,6 +270,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
values["client"] = CodeGenerationModel.from_pretrained(model_name)
except ImportError:
raise_vertex_import_error()
if values["streaming"] and values["n"] > 1:
raise ValueError("Only one candidate can be generated with streaming!")
return values
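
With the validator above, the streaming restriction fails fast at construction time. A sketch (running it needs the Vertex AI SDK installed, and typically GCP credentials, since validation also builds the model client):

```python
# Sketch: streaming supports only one candidate, so this combination
# now raises at validation time instead of failing inside the SDK call.
from langchain.llms import VertexAI

try:
    VertexAI(model_name="text-bison@001", streaming=True, n=2)
except ValueError as err:
    print(err)  # Only one candidate can be generated with streaming!
```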
def _generate(
@@ -271,7 +284,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
**kwargs: Any,
) -> LLMResult:
should_stream = stream if stream is not None else self.streaming
-params = self._prepare_params(stop=stop, **kwargs)
+params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
generations = []
for prompt in prompts:
if should_stream:
@@ -285,7 +298,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
res = completion_with_retry(
self, prompt, run_manager=run_manager, **params
)
-generations.append([_response_to_generation(res)])
+generations.append([_response_to_generation(r) for r in res.candidates])
return LLMResult(generations=generations)
async def _agenerate(
@@ -301,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
res = await acompletion_with_retry(
self, prompt, run_manager=run_manager, **params
)
-generations.append([_response_to_generation(res)])
+generations.append([_response_to_generation(r) for r in res.candidates])
return LLMResult(generations=generations)
def _stream(
@@ -311,7 +324,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
-params = self._prepare_params(stop=stop, **kwargs)
+params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):

langchain/utilities/vertex.py
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
from google.auth.credentials import Credentials
-def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
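
The minimum version moves to 1.35.0, presumably the SDK release that accepts `candidate_count` on `send_message` (an assumption based on this diff, not stated in it). A quick local check:

```python
# Quick sanity check that the installed SDK meets the new minimum.
from google.cloud import aiplatform

print(aiplatform.__version__)  # should be >= 1.35.0
```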

tests/integration_tests/chat_models/test_vertexai.py
@@ -20,7 +20,10 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
def test_vertexai_instantiation(model_name: str) -> None:
-model = ChatVertexAI(model_name=model_name)
+if model_name:
+    model = ChatVertexAI(model_name=model_name)
+else:
+    model = ChatVertexAI()
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
@@ -38,6 +41,15 @@ def test_vertexai_single_call(model_name: str) -> None:
assert isinstance(response.content, str)
+def test_candidates() -> None:
+    model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
+    message = HumanMessage(content="Hello")
+    response = model.generate(messages=[[message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    assert len(response.generations[0]) == 2
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_vertexai_agenerate() -> None:
@@ -153,7 +165,8 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
with patch(
"vertexai.language_models._language_models.ChatModel.start_chat"
) as start_chat:
-mock_response = Mock(text=response_text)
+mock_response = MagicMock()
+mock_response.candidates = [Mock(text=response_text)]
mock_chat = MagicMock()
start_chat.return_value = mock_chat
mock_send_message = MagicMock(return_value=mock_response)
@@ -167,7 +180,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
response = model([message])
assert response.content == response_text
-mock_send_message.assert_called_once_with(user_prompt)
+mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
expected_stop_sequence = [stop] if stop else None
start_chat.assert_called_once_with(
context=None,

tests/integration_tests/llms/test_vertexai.py
@@ -29,29 +29,31 @@ def test_vertex_call() -> None:
@pytest.mark.scheduled
def test_vertex_generate() -> None:
-llm = VertexAI(temperate=0)
-output = llm.generate(["Please say foo:"])
+llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
+output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
+assert len(output.generations[0]) == 2
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_vertex_agenerate() -> None:
-llm = VertexAI(temperate=0)
+llm = VertexAI(temperature=0)
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
@pytest.mark.scheduled
def test_vertex_stream() -> None:
-llm = VertexAI(temperate=0)
+llm = VertexAI(temperature=0)
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
@pytest.mark.asyncio
async def test_vertex_consistency() -> None:
-llm = VertexAI(temperate=0)
+llm = VertexAI(temperature=0)
output = llm.generate(["Please say foo:"])
streaming_output = llm.generate(["Please say foo:"], stream=True)
async_output = await llm.agenerate(["Please say foo:"])
