refactor llama/model.py

pull/554/head
Denis Mazur 4 months ago
parent 3b00036ec6
commit a945711f58

@ -90,7 +90,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
past_key_values = past_key_values if past_key_values is not None else RemotePastKeyValues()
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
# Remove prefix

Loading…
Cancel
Save