Add additional VertexAI Params (#5837)

## Changes

- Added the `stop` param to the `_VertexAICommon` class so it can be set
at LLM initialization

## Example Usage

```python
VertexAI(
    # ...
    temperature=0.15,
    max_output_tokens=128,
    top_p=1,
    top_k=40,
    stop=["\n```"],
)
```
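
Once `stop` is set at initialization it applies to every generation, while a call-time `stop` list still takes precedence (see the diff below). A brief, hedged usage sketch, assuming LangChain's standard `stop` keyword on LLM calls:

```python
llm = VertexAI(temperature=0.15, max_output_tokens=128, stop=["\n```"])

# Instance-level stop: the generation is cut at the closing code fence.
text = llm("Write a Python function inside a fenced code block.")

# A call-time stop list overrides the one set at initialization.
text = llm("Continue the agent trace.", stop=["\nObservation:"])
```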

## Possible Reviewers

- @hwchase17 
- @agola11

## Diff

```diff
@@ -29,6 +29,8 @@ class _VertexAICommon(BaseModel):
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens."
+    stop: Optional[List[str]] = None
+    "Optional list of stop words to use when generating."
     project: Optional[str] = None
     "The default GCP project to use when making Vertex API calls."
     location: str = "us-central1"
@@ -48,11 +50,13 @@ class _VertexAICommon(BaseModel):
         }
         return {**base_params}

-    def _predict(self, prompt: str, stop: Optional[List[str]]) -> str:
+    def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         res = self.client.predict(prompt, **self._default_params)
         return self._enforce_stop_words(res.text, stop)

-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str:
+    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
+        if stop is None and self.stop is not None:
+            stop = self.stop
         if stop:
             return enforce_stop_tokens(text, stop)
         return text
```
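
For reference, a minimal sketch of the fallback behaviour introduced in `_enforce_stop_words`: a call-time `stop` list wins, otherwise the instance-level `self.stop` is used, and LangChain's existing `enforce_stop_tokens` helper truncates the text at the first matching stop sequence. The `apply_stop` helper name below is hypothetical, for illustration only.

```python
from typing import List, Optional

from langchain.llms.utils import enforce_stop_tokens  # existing LangChain helper


def apply_stop(
    text: str,
    call_stop: Optional[List[str]],
    instance_stop: Optional[List[str]],
) -> str:
    """Mirror the precedence used in _enforce_stop_words (hypothetical helper)."""
    # A call-time stop list wins; otherwise fall back to the list set at init.
    stop = call_stop if call_stop is not None else instance_stop
    if stop:
        return enforce_stop_tokens(text, stop)
    return text


# The instance-level stop list truncates the generation at the closing fence:
print(apply_stop("def f():\n    pass\n```\nextra text", None, ["\n```"]))
# -> "def f():\n    pass"
```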
