From 9f4b720a632a373c87d6ee5ba3e393c662dbf2dc Mon Sep 17 00:00:00 2001
From: Aidan Holland
Date: Wed, 7 Jun 2023 22:20:37 -0400
Subject: [PATCH] Add additional VertexAI Params (#5837)

## Changes

- Added a `stop` param to the `_VertexAICommon` class so stop words can be set at LLM initialization (a `stop` list passed at call time still takes precedence over the instance-level value).

## Example Usage

```python
VertexAI(
    # ...
    temperature=0.15,
    max_output_tokens=128,
    top_p=1,
    top_k=40,
    stop=["\n```"],
)
```

## Possible Reviewers

- @hwchase17
- @agola11
---
 langchain/llms/vertexai.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py
index d9d67b37..58ea7359 100644
--- a/langchain/llms/vertexai.py
+++ b/langchain/llms/vertexai.py
@@ -29,6 +29,8 @@ class _VertexAICommon(BaseModel):
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens."
+    stop: Optional[List[str]] = None
+    "Optional list of stop words to use when generating."
     project: Optional[str] = None
     "The default GCP project to use when making Vertex API calls."
     location: str = "us-central1"
@@ -48,11 +50,13 @@ class _VertexAICommon(BaseModel):
         }
         return {**base_params}
 
-    def _predict(self, prompt: str, stop: Optional[List[str]]) -> str:
+    def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         res = self.client.predict(prompt, **self._default_params)
         return self._enforce_stop_words(res.text, stop)
 
-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str:
+    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
+        if stop is None and self.stop is not None:
+            stop = self.stop
         if stop:
             return enforce_stop_tokens(text, stop)
         return text
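
For context, here is a minimal standalone sketch of how the instance-level `stop` and a per-call `stop` interact after this change. The `FakeVertexLLM` class and `truncate_at_stop` helper are hypothetical stand-ins for illustration only; `truncate_at_stop` only mimics what LangChain's `enforce_stop_tokens` does and is not the library function.

```python
import re
from typing import List, Optional


def truncate_at_stop(text: str, stop: List[str]) -> str:
    # Hypothetical stand-in for langchain's enforce_stop_tokens:
    # cut the text at the first occurrence of any stop sequence.
    return re.split("|".join(re.escape(s) for s in stop), text)[0]


class FakeVertexLLM:
    """Toy model illustrating the stop-word precedence added in this patch."""

    def __init__(self, stop: Optional[List[str]] = None):
        self.stop = stop  # instance-level default, set at initialization

    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
        # Per-call stop wins; otherwise fall back to the instance-level default.
        if stop is None and self.stop is not None:
            stop = self.stop
        if stop:
            return truncate_at_stop(text, stop)
        return text


llm = FakeVertexLLM(stop=["\n```"])
text = "print('hi')\n```\ntrailing prose"

# Instance-level default applies when no per-call stop is given.
print(llm._enforce_stop_words(text))           # -> "print('hi')"
# A per-call stop overrides the instance-level value.
print(llm._enforce_stop_words(text, ["zzz"]))  # -> full text; "zzz" never occurs
```

The fallback lives in `_enforce_stop_words` rather than `_predict`, so any code path that post-processes model output picks up the instance-level default without further changes.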