@ -13,8 +13,8 @@ from langchain.utils import get_from_dict_or_env
class Writer ( LLM ) :
""" Wrapper around Writer large language models.
To use , you should have the environment variable ` ` WRITER_API_KEY ` `
set with your API ke y.
To use , you should have the environment variable ` ` WRITER_API_KEY ` ` and
` ` WRITER_ORG_ID ` ` set with your API ke y and organization ID respectivel y.
Example :
. . code - block : : python
@ -23,56 +23,44 @@ class Writer(LLM):
writer = Writer ( model_id = " palmyra-base " )
"""
model_id : str = " palmyra-base "
writer_org_id : Optional [ str ] = None
""" Writer organization ID. """
model_id : str = " palmyra-instruct "
""" Model name to use. """
tokens_to_generate : int = 24
""" M ax number of tokens to generate."""
min_tokens: Optional [ int ] = None
""" M inimum number of tokens to generate."""
logprobs: bool = Fals e
""" Whether to return log probabilities ."""
max_tokens: Optional [ int ] = Non e
""" Maximum number of tokens to generate ."""
temperature : float = 1.0
temperature : Optional [ float ] = None
""" What sampling temperature to use. """
length : int = 256
""" The maximum number of tokens to generate in the completion. """
top_p : float = 1.0
top_p : Optional [ float ] = None
""" Total probability mass of tokens to consider at each step. """
top_k : int = 1
""" The number of highest probability vocabulary tokens to
keep for top - k - filtering . """
stop : Optional [ List [ str ] ] = None
""" Sequences when completion generation will stop. """
presence_penalty : Optional [ float ] = None
""" Penalizes repeated tokens regardless of frequency. """
repetition_penalty : float = 1.0
repetition_penalty : Optional [ float ] = None
""" Penalizes repeated tokens according to frequency. """
random_seed : int = 0
""" The model generates random results.
Changing the random seed alone will produce a different response
with similar characteristics . It is possible to reproduce results
by fixing the random seed ( assuming all other hyperparameters
are also fixed ) """
beam_search_diversity_rate : float = 1.0
""" Only applies to beam search, i.e. when the beam width is >1.
A higher value encourages beam search to return a more diverse
set of candidates """
best_of : Optional [ int ] = None
""" Generates this many completions server-side and returns the " best " . """
beam_width : Optional [ int ] = None
""" The number of concurrent candidates to keep track of during
beam search """
logprobs : bool = False
""" Whether to return log probabilities. """
length_pentaly : float = 1.0
""" Only applies to beam search, i.e. when the beam width is >1.
Larger values penalize long candidates more heavily , thus preferring
shorter candidates """
n : Optional [ int ] = None
""" How many completions to generate. """
writer_api_key : Optional [ str ] = None
stop : Optional [ List [ str ] ] = None
""" Sequences when completion generation will stop """
""" Writer API key. """
base_url : Optional [ str ] = None
""" Base url to use, if None decides based on model name. """
@ -84,34 +72,41 @@ class Writer(LLM):
@root_validator ( )
def validate_environment ( cls , values : Dict ) - > Dict :
""" Validate that api key exists in environment. """
""" Validate that api key and organization id exist in environment. """
writer_api_key = get_from_dict_or_env (
values , " writer_api_key " , " WRITER_API_KEY "
)
values [ " writer_api_key " ] = writer_api_key
writer_org_id = get_from_dict_or_env ( values , " writer_org_id " , " WRITER_ORG_ID " )
values [ " writer_org_id " ] = writer_org_id
return values
@property
def _default_params ( self ) - > Mapping [ str , Any ] :
""" Get the default parameters for calling Writer API. """
return {
" tokens_to_generate " : self . tokens_to_generate ,
" minTokens " : self . min_tokens ,
" maxTokens " : self . max_tokens ,
" temperature " : self . temperature ,
" topP " : self . top_p ,
" stop " : self . stop ,
" presencePenalty " : self . presence_penalty ,
" repetitionPenalty " : self . repetition_penalty ,
" bestOf " : self . best_of ,
" logprobs " : self . logprobs ,
" temperature " : self . temperature ,
" top_p " : self . top_p ,
" top_k " : self . top_k ,
" repetition_penalty " : self . repetition_penalty ,
" random_seed " : self . random_seed ,
" beam_search_diversity_rate " : self . beam_search_diversity_rate ,
" beam_width " : self . beam_width ,
" length_pentaly " : self . length_pentaly ,
" n " : self . n ,
}
@property
def _identifying_params ( self ) - > Mapping [ str , Any ] :
""" Get the identifying parameters. """
return { * * { " model_id " : self . model_id } , * * self . _default_params }
return {
* * { " model_id " : self . model_id , " writer_org_id " : self . writer_org_id } ,
* * self . _default_params ,
}
@property
def _llm_type ( self ) - > str :
@ -124,7 +119,7 @@ class Writer(LLM):
stop : Optional [ List [ str ] ] = None ,
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
) - > str :
""" Call out to Writer ' s complet e endpoint.
""" Call out to Writer ' s complet ions endpoint.
Args :
prompt : The prompt to pass into the model .
@ -142,12 +137,15 @@ class Writer(LLM):
base_url = self . base_url
else :
base_url = (
" https://api.llm.writer.com/v1/models/ {self.model_id} /completions "
" https://enterprise-api.writer.com/llm "
f " /organization/ { self . writer_org_id } "
f " /model/ { self . model_id } /completions "
)
response = requests . post (
url = base_url ,
headers = {
" Authorization " : f " Bearer { self . writer_api_key } " ,
" Authorization " : f " { self . writer_api_key } " ,
" Content-Type " : " application/json " ,
" Accept " : " application/json " ,
} ,