@@ -7,8 +7,11 @@ from tenacity import (
     wait_random_exponential,  # type: ignore
 )
 import openai
+from transformers import AutoModelForCausalLM, AutoTokenizer
 MessageRole = Literal["system", "user", "assistant"]
 @dataclasses.dataclass()
 class Message():
     role: MessageRole
@@ -64,6 +67,7 @@ def gpt_chat(
     return [choice.message.content for choice in response.choices]  # type: ignore
 class ModelBase():
     def __init__(self, name: str):
         self.name = name
@@ -106,47 +110,115 @@ class GPTDavinci(ModelBase):
         return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
-class StarChat(ModelBase):
-    def __init__(self):
-        import torch
-        from transformers import pipeline
-        self.name = "star-chat"
-        self.pipe = pipeline(
-            "text-generation", model="HuggingFaceH4/starchat-beta", torch_dtype=torch.bfloat16, device_map="auto")
+class HFModelBase(ModelBase):
+    """
+    Base for huggingface chat models
+    """
+    def __init__(self, model_name: str, model, tokenizer, eos_token_id=None):
+        self.name = model_name
+        self.model = model
+        self.tokenizer = tokenizer
+        self.eos_token_id = eos_token_id if eos_token_id is not None else self.tokenizer.eos_token_id
         self.is_chat = True
     def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
         # NOTE: HF does not like temp of 0.0.
         if temperature < 0.0001:
             temperature = 0.0001
-        prompt = ""
-        for i, message in enumerate(messages):
-            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
-            if i == len(messages) - 1:
-                prompt += "<|assistant|>\n"
-        outputs = self.pipe(
+        prompt = self.prepare_prompt(messages)
+        outputs = self.model.generate(
             prompt,
-            max_new_tokens=max_tokens,
+            max_new_tokens=min(
+                max_tokens, self.model.config.max_position_embeddings),
             use_cache=True,
             do_sample=True,
             temperature=temperature,
             top_p=0.95,
-            eos_token_id=49155,
+            eos_token_id=self.eos_token_id,
             num_return_sequences=num_comps,
         )
-        outs = [output['generated_text'] for output in outputs]  # type: ignore
+        outs = self.tokenizer.batch_decode(outputs, skip_special_tokens=False)
         assert isinstance(outs, list)
         for i, out in enumerate(outs):
             assert isinstance(out, str)
-            out = out.split("<|assistant|>")[1]
-            if out.endswith("<|end|>"):
-                out = out[:-len("<|end|>")]
-            outs[i] = out
+            outs[i] = self.extract_output(out)
         if len(outs) == 1:
             return outs[0]  # type: ignore
         else:
             return outs  # type: ignore
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        raise NotImplementedError
+    def extract_output(self, output: str) -> str:
+        raise NotImplementedError
+class StarChat(HFModelBase):
+    def __init__(self):
+        import torch
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+        )
+        super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        prompt = ""
+        for i, message in enumerate(messages):
+            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+            if i == len(messages) - 1:
+                prompt += "<|assistant|>\n"
+        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
+    def extract_output(self, output: str) -> str:
+        out = output.split("<|assistant|>")[1]
+        if out.endswith("<|end|>"):
+            out = out[:-len("<|end|>")]
+        return out
+class CodeLlama(HFModelBase):
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    def __init__(self):
+        import torch
+        model = AutoModelForCausalLM.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            add_eos_token=True,
+            add_bos_token=True,
+            padding_side='left',
+        )
+        super().__init__("code-llama", model, tokenizer, eos_token_id=2)
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        prompt = ""
+        for i, message in enumerate(messages):
+            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+            if i == len(messages) - 1:
+                prompt += "<|assistant|>\n"
+        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
+    def extract_output(self, output: str) -> str:
+        out = output.split("<|assistant|>")[1]
+        if out.endswith("<|end|>"):
+            out = out[:-len("<|end|>")]
+        return out
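For context, a minimal usage sketch of the chat interface this patch introduces, not part of the diff itself: the import path, the message contents, and the generation settings are assumptions, and it presumes enough GPU memory to load starchat-beta in bfloat16.

# Hypothetical usage of the HFModelBase subclasses above (illustration only).
from generators.model import Message, StarChat  # import path is an assumption

model = StarChat()  # downloads HuggingFaceH4/starchat-beta and shards it with device_map="auto"
messages = [
    Message(role="system", content="You are a helpful coding assistant."),
    Message(role="user", content="Write a Python function that reverses a string."),
]
# generate_chat returns a single string when num_comps == 1, otherwise a list of strings
out = model.generate_chat(messages, max_tokens=256, temperature=0.2, num_comps=1)
print(out)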