|
|
|
@ -1,7 +1,8 @@
|
|
|
|
|
"""Wrapper around Cohere APIs."""
|
|
|
|
|
import os
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
import cohere
|
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
@ -28,7 +29,6 @@ class Cohere(BaseModel, LLM):
|
|
|
|
|
cohere = Cohere(model="gptd-instruct-tft")
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
|
model: str = "gptd-instruct-tft"
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
|
|
|
|
@ -63,15 +63,6 @@ class Cohere(BaseModel, LLM):
|
|
|
|
|
"Did not find Cohere API key, please add an environment variable"
|
|
|
|
|
" `COHERE_API_KEY` which contains it."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
import cohere
|
|
|
|
|
|
|
|
|
|
values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Could not import cohere python package. "
|
|
|
|
|
"Please it install it with `pip install cohere`."
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
|
|
@ -89,7 +80,8 @@ class Cohere(BaseModel, LLM):
|
|
|
|
|
|
|
|
|
|
response = cohere("Tell me a joke.")
|
|
|
|
|
"""
|
|
|
|
|
response = self.client.generate(
|
|
|
|
|
client = cohere.Client(os.environ["COHERE_API_KEY"])
|
|
|
|
|
response = client.generate(
|
|
|
|
|
model=self.model,
|
|
|
|
|
prompt=prompt,
|
|
|
|
|
max_tokens=self.max_tokens,
|
|
|
|
|