Add shallow prefix-tuned inference (#55)

* Add prefix-tuned inference

* Add pre_seq_len to the prefix length
pull/57/head
Artem Chumachenko authored 2 years ago, committed by GitHub
parent d271b75dd4
commit 77220c718c

@@ -63,6 +63,7 @@ class RemoteGenerationMixin:
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -104,6 +105,9 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
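Context for the second hunk: the learned prompt embeddings are prepended only once, at the first step of the generation loop (len(outputs) == 1), which is what makes this prefix tuning "shallow": the prompts enter through the input embeddings rather than being re-injected at every transformer block. Below is a minimal, self-contained sketch of that idea. The names pre_seq_len, get_prompt, and word_embeddings follow the diff; the ShallowPrefix class, the toy sizes, and the standalone structure are illustrative assumptions, not the repository's actual API.

import torch
import torch.nn as nn

class ShallowPrefix(nn.Module):
    """Learned prompt embeddings prepended to the token embeddings (illustrative)."""

    def __init__(self, pre_seq_len: int, hidden_size: int):
        super().__init__()
        # One trainable embedding per virtual prompt token.
        self.prefix = nn.Parameter(torch.randn(pre_seq_len, hidden_size) * 0.02)

    def get_prompt(self, batch_size: int) -> torch.Tensor:
        # Expand the shared prefix over the batch: [batch, pre_seq_len, hidden].
        return self.prefix.unsqueeze(0).expand(batch_size, -1, -1)


# Toy first generation step mirroring the hunk above (assumed shapes, not the real model).
pre_seq_len, hidden_size, vocab = 4, 16, 100
word_embeddings = nn.Embedding(vocab, hidden_size)
prefix_tuning = ShallowPrefix(pre_seq_len, hidden_size)

input_ids = torch.randint(0, vocab, (2, 5))           # [batch, length]
embs = word_embeddings(input_ids)                      # [2, 5, hidden]

first_step = True
if pre_seq_len > 0 and first_step:
    prompts = prefix_tuning.get_prompt(embs.size(0))   # [2, 4, hidden]
    embs = torch.cat([prompts, embs], dim=1)           # [2, 4 + 5, hidden]

# The effective prefix length now includes the prompt tokens, which is why the
# first hunk adds `prefix_length += self.config.pre_seq_len`.
assert embs.size(1) == input_ids.size(1) + pre_seq_len

On later steps only the newest token is embedded and fed to the session, so the prompts are never concatenated again; the remote attention cache already holds their contribution.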
