@ -104,17 +104,18 @@ class RemoteGenerationMixin:
elif max_length is None and max_new_tokens is not None :
max_length = prefix_length + max_new_tokens
if num_beams > 1 and session is not None :
resuming_session = session is not None and session . last_token_id is not None
if num_beams > 1 and resuming_session :
raise NotImplementedError (
" Re us ing inference session in .generate() along with beam search is not supported yet"
" Re sum ing inference session in .generate() along with beam search is not supported yet"
)
if inputs is not None :
assert isinstance ( inputs , torch . Tensor ) and inputs . ndim == 2 , " inputs must be a 2d tensor [batch, length] "
if session is not None and session . last_token_id is not None :
if resuming_session :
inputs = torch . cat ( [ session . last_token_id , inputs ] , dim = 1 )
else :
if session is not None and session . last_token_id is not None :
if resuming_session :
inputs = session . last_token_id
else :
assert bos_token_id is not None , " You have to provide a bos_token_id if you do not provide inputs "
@ -207,6 +208,8 @@ class RemoteGenerationMixin:
outputs = torch . cat ( outputs , dim = - 1 )
if resuming_session :
outputs = outputs [ : , 1 : ]
if num_beams > 1 :
pre_return_idx = [
torch . arange ( idx , num_return_sequences * batch_size , batch_size ) for idx in range ( batch_size )