diff --git a/langchain/llms/writer.py b/langchain/llms/writer.py
index 2cec183515..d704205d65 100644
--- a/langchain/llms/writer.py
+++ b/langchain/llms/writer.py
@@ -13,8 +13,8 @@ from langchain.utils import get_from_dict_or_env
 class Writer(LLM):
     """Wrapper around Writer large language models.
 
-    To use, you should have the environment variable ``WRITER_API_KEY``
-    set with your API key.
+    To use, you should have the environment variable ``WRITER_API_KEY`` and
+    ``WRITER_ORG_ID`` set with your API key and organization ID respectively.
 
     Example:
         .. code-block:: python
@@ -23,56 +23,44 @@ class Writer(LLM):
             writer = Writer(model_id="palmyra-base")
     """
 
-    model_id: str = "palmyra-base"
+    writer_org_id: Optional[str] = None
+    """Writer organization ID."""
+
+    model_id: str = "palmyra-instruct"
     """Model name to use."""
 
-    tokens_to_generate: int = 24
-    """Max number of tokens to generate."""
+    min_tokens: Optional[int] = None
+    """Minimum number of tokens to generate."""
+
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
+
+    temperature: Optional[float] = None
+    """What sampling temperature to use."""
+
+    top_p: Optional[float] = None
+    """Total probability mass of tokens to consider at each step."""
+
+    stop: Optional[List[str]] = None
+    """Sequences when completion generation will stop."""
+
+    presence_penalty: Optional[float] = None
+    """Penalizes repeated tokens regardless of frequency."""
+
+    repetition_penalty: Optional[float] = None
+    """Penalizes repeated tokens according to frequency."""
+
+    best_of: Optional[int] = None
+    """Generates this many completions server-side and returns the "best"."""
 
     logprobs: bool = False
     """Whether to return log probabilities."""
 
-    temperature: float = 1.0
-    """What sampling temperature to use."""
-
-    length: int = 256
-    """The maximum number of tokens to generate in the completion."""
-
-    top_p: float = 1.0
-    """Total probability mass of tokens to consider at each step."""
-
-    top_k: int = 1
-    """The number of highest probability vocabulary tokens to
-    keep for top-k-filtering."""
-
-    repetition_penalty: float = 1.0
-    """Penalizes repeated tokens according to frequency."""
-
-    random_seed: int = 0
-    """The model generates random results.
-    Changing the random seed alone will produce a different response
-    with similar characteristics. It is possible to reproduce results
-    by fixing the random seed (assuming all other hyperparameters
-    are also fixed)"""
-
-    beam_search_diversity_rate: float = 1.0
-    """Only applies to beam search, i.e. when the beam width is >1.
-    A higher value encourages beam search to return a more diverse
-    set of candidates"""
-
-    beam_width: Optional[int] = None
-    """The number of concurrent candidates to keep track of during
-    beam search"""
-
-    length_pentaly: float = 1.0
-    """Only applies to beam search, i.e. when the beam width is >1.
-    Larger values penalize long candidates more heavily, thus preferring
-    shorter candidates"""
+    n: Optional[int] = None
+    """How many completions to generate."""
 
     writer_api_key: Optional[str] = None
-
-    stop: Optional[List[str]] = None
-    """Sequences when completion generation will stop"""
+    """Writer API key."""
 
     base_url: Optional[str] = None
     """Base url to use, if None decides based on model name."""
@@ -84,34 +72,41 @@ class Writer(LLM):
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key exists in environment."""
+        """Validate that api key and organization id exist in environment."""
+
         writer_api_key = get_from_dict_or_env(
             values, "writer_api_key", "WRITER_API_KEY"
         )
         values["writer_api_key"] = writer_api_key
+
+        writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID")
+        values["writer_org_id"] = writer_org_id
+
         return values
 
     @property
     def _default_params(self) -> Mapping[str, Any]:
         """Get the default parameters for calling Writer API."""
         return {
-            "tokens_to_generate": self.tokens_to_generate,
-            "stop": self.stop,
-            "logprobs": self.logprobs,
+            "minTokens": self.min_tokens,
+            "maxTokens": self.max_tokens,
             "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "repetition_penalty": self.repetition_penalty,
-            "random_seed": self.random_seed,
-            "beam_search_diversity_rate": self.beam_search_diversity_rate,
-            "beam_width": self.beam_width,
-            "length_pentaly": self.length_pentaly,
+            "topP": self.top_p,
+            "stop": self.stop,
+            "presencePenalty": self.presence_penalty,
+            "repetitionPenalty": self.repetition_penalty,
+            "bestOf": self.best_of,
+            "logprobs": self.logprobs,
+            "n": self.n,
         }
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
-        return {**{"model_id": self.model_id}, **self._default_params}
+        return {
+            **{"model_id": self.model_id, "writer_org_id": self.writer_org_id},
+            **self._default_params,
+        }
 
     @property
     def _llm_type(self) -> str:
@@ -124,7 +119,7 @@ class Writer(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> str:
-        """Call out to Writer's complete endpoint.
+        """Call out to Writer's completions endpoint.
 
         Args:
             prompt: The prompt to pass into the model.
@@ -142,12 +137,15 @@ class Writer(LLM):
             base_url = self.base_url
         else:
             base_url = (
-                "https://api.llm.writer.com/v1/models/{self.model_id}/completions"
+                "https://enterprise-api.writer.com/llm"
+                f"/organization/{self.writer_org_id}"
+                f"/model/{self.model_id}/completions"
            )
+
         response = requests.post(
             url=base_url,
             headers={
-                "Authorization": f"Bearer {self.writer_api_key}",
+                "Authorization": f"{self.writer_api_key}",
                 "Content-Type": "application/json",
                 "Accept": "application/json",
             },
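
A minimal usage sketch of the wrapper after this patch, assuming ``WRITER_API_KEY`` and ``WRITER_ORG_ID`` are set in the environment; the model and generation parameters below are illustrative only, not defaults from the Writer API.

    from langchain.llms.writer import Writer

    # Credentials fall back to the WRITER_API_KEY / WRITER_ORG_ID environment
    # variables when not passed explicitly (see validate_environment above).
    writer = Writer(
        model_id="palmyra-instruct",
        min_tokens=10,     # illustrative values; all generation params are optional
        max_tokens=256,
        temperature=0.7,
    )

    # LLM.__call__ sends the prompt to the completions endpoint and returns text.
    print(writer("Write a one-line product description for a reusable water bottle."))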