Harrison/handle stop tokens ai21 (#1077)

Co-authored-by: Andrew Huang <jhuang16888@gmail.com>
searx-api
Harrison Chase 1 year ago committed by GitHub
parent d8ed286200
commit 52753066ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
"""Wrapper around AI21 APIs.""" """Wrapper around AI21 APIs."""
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Optional
import requests import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -64,6 +64,8 @@ class AI21(LLM, BaseModel):
ai21_api_key: Optional[str] = None ai21_api_key: Optional[str] = None
stop: Optional[List[str]] = None
base_url: Optional[str] = None base_url: Optional[str] = None
"""Base url to use, if None decides based on model name.""" """Base url to use, if None decides based on model name."""
@ -80,7 +82,7 @@ class AI21(LLM, BaseModel):
return values return values
@property @property
def _default_params(self) -> Mapping[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling AI21 API.""" """Get the default parameters for calling AI21 API."""
return { return {
"temperature": self.temperature, "temperature": self.temperature,
@ -95,7 +97,7 @@ class AI21(LLM, BaseModel):
} }
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params} return {**{"model": self.model}, **self._default_params}
@ -119,7 +121,11 @@ class AI21(LLM, BaseModel):
response = ai21("Tell me a joke.") response = ai21("Tell me a joke.")
""" """
if stop is None: if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
elif stop is None:
stop = [] stop = []
if self.base_url is not None: if self.base_url is not None:
base_url = self.base_url base_url = self.base_url

Loading…
Cancel
Save