Add model kwargs to handle stop tokens from Cohere (#773)

pull/782/head
Harrison Chase 2 years ago committed by GitHub
parent 7198a1cb22
commit 966611bbfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,6 @@
"""Wrapper around Cohere APIs.""" """Wrapper around Cohere APIs."""
from typing import Any, Dict, List, Mapping, Optional import logging
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -7,6 +8,8 @@ from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class Cohere(LLM, BaseModel): class Cohere(LLM, BaseModel):
"""Wrapper around Cohere large language models. """Wrapper around Cohere large language models.
@ -46,6 +49,8 @@ class Cohere(LLM, BaseModel):
cohere_api_key: Optional[str] = None cohere_api_key: Optional[str] = None
stop: Optional[List[str]] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -69,7 +74,7 @@ class Cohere(LLM, BaseModel):
return values return values
@property @property
def _default_params(self) -> Mapping[str, Any]: def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API.""" """Get the default parameters for calling Cohere API."""
return { return {
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
@ -81,7 +86,7 @@ class Cohere(LLM, BaseModel):
} }
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params} return {**{"model": self.model}, **self._default_params}
@ -105,9 +110,15 @@ class Cohere(LLM, BaseModel):
response = cohere("Tell me a joke.") response = cohere("Tell me a joke.")
""" """
response = self.client.generate( params = self._default_params
model=self.model, prompt=prompt, stop_sequences=stop, **self._default_params if self.stop is not None and stop is not None:
) raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop_sequences"] = self.stop
else:
params["stop_sequences"] = stop
response = self.client.generate(model=self.model, prompt=prompt, **params)
text = response.generations[0].text text = response.generations[0].text
# If stop tokens are provided, Cohere's endpoint returns them. # If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them. # In order to make this consistent with other endpoints, we strip them.

Loading…
Cancel
Save