Mypy bug fix

Laurel Orr 2022-08-06 06:38:07 +00:00
parent 55e0be83e2
commit 5d79281c4c
2 changed files with 19 additions and 18 deletions

View File

@@ -47,11 +47,12 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
+        config = model.config  # type: ignore
         # Used for GPT
-        self.max_length = getattr(model.config, "max_position_embeddings", None)
+        self.max_length = getattr(config, "max_position_embeddings", None)
         if self.max_length is None:
             # Used for T0
-            self.max_length = model.config.d_model
+            self.max_length = config.d_model
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
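The fix here binds `model.config` to a local variable so a single `# type: ignore` suppresses mypy's complaint once, instead of being repeated at every attribute access. Below is a standalone sketch of the same max-length lookup, not part of the diff; it assumes the Hugging Face `transformers` package, and `resolve_max_length` plus the model names are illustrative only. GPT-style configs expose `max_position_embeddings`, while T0/T5-style configs expose `d_model`, so the `getattr` fallback covers both families.

from transformers import AutoConfig

def resolve_max_length(model_name: str) -> int:
    # Hypothetical helper mirroring the Pipeline logic above
    config = AutoConfig.from_pretrained(model_name)
    # GPT-style configs carry an explicit position limit
    max_length = getattr(config, "max_position_embeddings", None)
    if max_length is None:
        # T0/T5-style configs fall back to d_model, as in the diff
        max_length = config.d_model
    return max_length

print(resolve_max_length("gpt2"))      # 1024, via max_position_embeddings
print(resolve_max_length("t5-small"))  # 512, via d_model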
@@ -341,26 +342,26 @@ class HuggingFaceModel(Model):
             tokenized_targets[k]["attention_mask"] for k in range(len(gold_choices))
         ]
         # Convert to tensors
+        tensor_features = {}
         for k in features:
-            features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
-            print(k, features[k].shape)
+            tensor_features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
         # Reduce GPU memory by feeding one at a time
         logits = [
-            self.pipeline.model(
-                input_ids=features["input_ids"][bs].unsqueeze(0),
-                attention_mask=features["attention_mask"][bs].unsqueeze(0),
-                labels=features["labels"][bs].unsqueeze(0),
+            self.pipeline.model(  # type: ignore
+                input_ids=tensor_features["input_ids"][bs].unsqueeze(0),
+                attention_mask=tensor_features["attention_mask"][bs].unsqueeze(0),
+                labels=tensor_features["labels"][bs].unsqueeze(0),
             ).logits
-            for bs in range(len(features["input_ids"]))
+            for bs in range(len(tensor_features["input_ids"]))
         ]
-        logits = torch.vstack(logits)
+        stacked_logits = torch.vstack(logits)
         # Compute most likely option
-        masked_log_probs = features["labels_attention_mask"].unsqueeze(
+        masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
             -1
-        ) * torch.log_softmax(logits, dim=-1)
+        ) * torch.log_softmax(stacked_logits, dim=-1)
         seq_token_log_probs = torch.gather(
-            masked_log_probs, -1, features["labels"].unsqueeze(-1)
+            masked_log_probs, -1, tensor_features["labels"].unsqueeze(-1)
         )
         seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
         prediction = seq_log_prob.argmax(dim=-1).item()
-        return gold_choices[prediction]
+        return gold_choices[int(prediction)]
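Two type errors drive this hunk: overwriting `features[k]` in place changed the dict's value type from lists of ints to tensors, which mypy rejects, so the converted tensors now go into a fresh `tensor_features` dict (the stray debug `print` is dropped along the way); and `Tensor.item()` is typed as returning a general Python number, so `int(prediction)` is needed before indexing the list of choices. Below, a standalone sketch of the scoring step itself, runnable as-is; `pick_choice` and the dummy shapes are illustrative, not part of the commit.

import torch

def pick_choice(choice_logits, labels, labels_attention_mask, gold_choices):
    # choice_logits: one (1, seq_len, vocab) tensor per gold choice
    stacked_logits = torch.vstack(choice_logits)  # (n_choices, seq_len, vocab)
    # Zero out positions where the label is padding
    masked_log_probs = labels_attention_mask.unsqueeze(
        -1
    ) * torch.log_softmax(stacked_logits, dim=-1)
    # Gather the log-prob assigned to each gold label token
    seq_token_log_probs = torch.gather(
        masked_log_probs, -1, labels.unsqueeze(-1)
    )
    # Sum per-token log-probs: total log-likelihood of each choice
    seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
    prediction = seq_log_prob.argmax(dim=-1).item()
    # int(...) mirrors the fix above: .item() is not typed as int
    return gold_choices[int(prediction)]

n_choices, seq_len, vocab = 3, 4, 10
print(pick_choice(
    [torch.randn(1, seq_len, vocab) for _ in range(n_choices)],
    torch.randint(vocab, (n_choices, seq_len)),
    torch.ones(n_choices, seq_len),
    ["yes", "no", "maybe"],
))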

View File

@@ -1,6 +1,6 @@
 """Manifest class."""
 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union, cast

 from tqdm.auto import tqdm
@@ -134,9 +134,9 @@ class Manifest:
             possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
         else:
             try:
-                possible_request, full_kwargs = self.client.get_choice_logit_request(
-                    prompt_str, gold_choices, kwargs
-                )
+                possible_request, full_kwargs = cast(
+                    HuggingFaceClient, self.client
+                ).get_choice_logit_request(prompt_str, gold_choices, kwargs)
             except AttributeError:
                 raise ValueError("`gold_choices` only supported for HF models.")
         if len(kwargs) > 0:
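`get_choice_logit_request` exists only on `HuggingFaceClient`, while `self.client` is annotated with the base client type, so the direct call fails mypy even though only HF clients reach this branch at runtime. `typing.cast` narrows the static type with no runtime effect, and the existing `AttributeError` guard still turns a non-HF client into a clear error. A minimal sketch of the pattern; the base class `Client` and the simplified method signature are assumptions, not the library's actual API:

from typing import cast

class Client:
    """Assumed base class; has no get_choice_logit_request."""

class HuggingFaceClient(Client):
    def get_choice_logit_request(self, prompt: str) -> dict:
        # Simplified stand-in for the real request builder
        return {"prompt": prompt, "request_type": "choice_logits"}

def build_choice_request(client: Client, prompt: str) -> dict:
    try:
        # cast() is a no-op at runtime; it only narrows the type for mypy
        return cast(HuggingFaceClient, client).get_choice_logit_request(prompt)
    except AttributeError:
        raise ValueError("`gold_choices` only supported for HF models.")

print(build_choice_request(HuggingFaceClient(), "Is the sky blue?"))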