diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index efbb8889..e0b804c8 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -1,5 +1,5 @@ """Wrapper around AI21 APIs.""" -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Optional import requests from pydantic import BaseModel, Extra, root_validator @@ -64,6 +64,8 @@ class AI21(LLM, BaseModel): ai21_api_key: Optional[str] = None + stop: Optional[List[str]] = None + base_url: Optional[str] = None """Base url to use, if None decides based on model name.""" @@ -80,7 +82,7 @@ class AI21(LLM, BaseModel): return values @property - def _default_params(self) -> Mapping[str, Any]: + def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling AI21 API.""" return { "temperature": self.temperature, @@ -95,7 +97,7 @@ class AI21(LLM, BaseModel): } @property - def _identifying_params(self) -> Mapping[str, Any]: + def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" return {**{"model": self.model}, **self._default_params} @@ -119,7 +121,11 @@ class AI21(LLM, BaseModel): response = ai21("Tell me a joke.") """ - if stop is None: + if self.stop is not None and stop is not None: + raise ValueError("`stop` found in both the input and default params.") + elif self.stop is not None: + stop = self.stop + elif stop is None: stop = [] if self.base_url is not None: base_url = self.base_url