@@ -7,7 +7,6 @@ from tenacity import (
     wait_random_exponential,  # type: ignore
 )
 import openai
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MessageRole = Literal["system", "user", "assistant"]
@@ -18,6 +17,14 @@ class Message():
     content: str
 
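+# Plain-text rendering helpers for Message objects, e.g. for logging or for
+# models that take a single string prompt rather than a chat structure.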
+def message_to_str(message: Message) -> str:
+    return f"{message.role}: {message.content}"
+
+
+def messages_to_str(messages: List[Message]) -> str:
+    return "\n".join([message_to_str(message) for message in messages])
+
+
 @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
 def gpt_completion(
     model: str,
@@ -152,7 +159,7 @@ class HFModelBase(ModelBase):
         else:
             return outs  # type: ignore
 
-    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+    def prepare_prompt(self, messages: List[Message]):
         raise NotImplementedError
 
     def extract_output(self, output: str) -> str:
@@ -162,6 +169,7 @@ class HFModelBase(ModelBase):
 class StarChat(HFModelBase):
     def __init__(self):
         import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
         model = AutoModelForCausalLM.from_pretrained(
             "HuggingFaceH4/starchat-beta",
             torch_dtype=torch.bfloat16,
@@ -170,9 +178,9 @@ class StarChat(HFModelBase):
         tokenizer = AutoTokenizer.from_pretrained(
             "HuggingFaceH4/starchat-beta",
         )
-        super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
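+        # NOTE: 49155 is assumed to be the id of StarChat's <|end|> token;
+        # verify with tokenizer.convert_tokens_to_ids("<|end|>").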
+        super().__init__("starchat", model, tokenizer, eos_token_id=49155)
 
-    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+    def prepare_prompt(self, messages: List[Message]):
         prompt = ""
         for i, message in enumerate(messages):
             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
@@ -198,27 +206,58 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
 
-    def __init__(self):
-        super().__init__("code-llama", "codellama/CodeLlama-34b-Instruct-hf", 2)
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_name,
+    def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"):
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            f"codellama/CodeLlama-{version}-Instruct-hf",
             add_eos_token=True,
             add_bos_token=True,
             padding_side='left'
         )
-    def prepare_prompt(self, messages: List[Message]) -> str:
-        prompt = ""
-        for i, message in enumerate(messages):
-            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
-            if i == len(messages) - 1:
-                prompt += "<|assistant|>\n"
-        return prompt
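+        # device_map="auto" lets accelerate shard the weights across available
+        # GPUs; bfloat16 halves memory use relative to fp32.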
+        model = AutoModelForCausalLM.from_pretrained(
+            f"codellama/CodeLlama-{version}-Instruct-hf",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        super().__init__("codellama", model, tokenizer)
+
+    def prepare_prompt(self, messages: List[Message]):
+        if messages[0].role != "system":
+            messages = [
+                Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)
+            ] + messages
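+        # Llama-2 chat format: fold the system prompt, wrapped in the B_SYS /
+        # E_SYS (<<SYS>>) markers, into the first user message.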
+        messages = [
+            Message(role=messages[1].role, content=self.B_SYS +
+                    messages[0].content + self.E_SYS + messages[1].content)
+        ] + messages[2:]
+        assert all([msg.role == "user" for msg in messages[::2]]) and all(
+            [msg.role == "assistant" for msg in messages[1::2]]
+        ), (
+            "model only supports 'system', 'user' and 'assistant' roles, "
+            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+        )
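+        # Encode each completed (user, assistant) turn as
+        # "[INST] user [/INST] assistant", concatenating the token lists.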
+        messages_tokens: List[int] = sum(
+            [
+                self.tokenizer.encode(
+                    f"{self.B_INST} {(prompt.content).strip()} {self.E_INST} {(answer.content).strip()} ",
+                )
+                for prompt, answer in zip(
+                    messages[::2],
+                    messages[1::2],
+                )
+            ],
+            [],
+        )
+        assert messages[-1].role == "user", f"Last message must be from user, got {messages[-1].role}"
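+        # The final user message is left "open" ([INST] ... [/INST] with no
+        # answer) so the model generates the assistant reply.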
+        messages_tokens += self.tokenizer.encode(
+            f"{self.B_INST} {(messages[-1].content).strip()} {self.E_INST}",
+        )
+        # drop the trailing eos token that add_eos_token=True appended to the
+        # last encode, so the model continues generating after [/INST]
+        messages_tokens = messages_tokens[:-1]
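+        # Return a batch of size 1 on the same device as the model.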
+        import torch
+        return torch.tensor([messages_tokens]).to(self.model.device)
 
     def extract_output(self, output: str) -> str:
-        out = output.split("<|assistant|>")[1]
-        if out.endswith("<|end|>"):
-            out = out[:-len("<|end|>")]
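+        # Keep only the text after the final [/INST] and before </s>.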
+        out = output.split("[/INST]")[-1].split("</s>")[0].strip()
         return out