Mypy bug fix

Laurel Orr 2022-08-06 06:38:07 +00:00
parent 55e0be83e2
commit 5d79281c4c
2 changed files with 19 additions and 18 deletions

View File

@@ -47,11 +47,12 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
+        config = model.config  # type: ignore
         # Used for GPT
-        self.max_length = getattr(model.config, "max_position_embeddings", None)
+        self.max_length = getattr(config, "max_position_embeddings", None)
         if self.max_length is None:
             # Used for T0
-            self.max_length = model.config.d_model
+            self.max_length = config.d_model
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
@@ -341,26 +342,26 @@ class HuggingFaceModel(Model):
             tokenized_targets[k]["attention_mask"] for k in range(len(gold_choices))
         ]
         # Convert to tensors
+        tensor_features = {}
         for k in features:
-            features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
-            print(k, features[k].shape)
+            tensor_features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
         # Reduce GPU memory by feeding one at a time
         logits = [
-            self.pipeline.model(
-                input_ids=features["input_ids"][bs].unsqueeze(0),
-                attention_mask=features["attention_mask"][bs].unsqueeze(0),
-                labels=features["labels"][bs].unsqueeze(0),
+            self.pipeline.model(  # type: ignore
+                input_ids=tensor_features["input_ids"][bs].unsqueeze(0),
+                attention_mask=tensor_features["attention_mask"][bs].unsqueeze(0),
+                labels=tensor_features["labels"][bs].unsqueeze(0),
             ).logits
-            for bs in range(len(features["input_ids"]))
+            for bs in range(len(tensor_features["input_ids"]))
         ]
-        logits = torch.vstack(logits)
+        stacked_logits = torch.vstack(logits)
         # Compute most likely option
-        masked_log_probs = features["labels_attention_mask"].unsqueeze(
+        masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
             -1
-        ) * torch.log_softmax(logits, dim=-1)
+        ) * torch.log_softmax(stacked_logits, dim=-1)
         seq_token_log_probs = torch.gather(
-            masked_log_probs, -1, features["labels"].unsqueeze(-1)
+            masked_log_probs, -1, tensor_features["labels"].unsqueeze(-1)
         )
         seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
         prediction = seq_log_prob.argmax(dim=-1).item()
-        return gold_choices[prediction]
+        return gold_choices[int(prediction)]
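
To make the retyped scoring block easier to follow, here is a hypothetical standalone version of the same computation with toy shapes; the tensor names mirror the diff, but the shapes and random data are made up for illustration. The final int(...) mirrors the fix above, since Tensor.item() is typed as returning a general number rather than an int.

    import torch

    # Toy shapes: 3 candidate choices, 4 label tokens, vocab of 10.
    num_choices, seq_len, vocab = 3, 4, 10
    logits = torch.randn(num_choices, seq_len, vocab)
    labels = torch.randint(vocab, (num_choices, seq_len))
    labels_attention_mask = torch.ones(num_choices, seq_len)

    # Zero out padding positions, then take per-token log-probabilities.
    masked_log_probs = labels_attention_mask.unsqueeze(-1) * torch.log_softmax(
        logits, dim=-1
    )
    # Pick the log-prob assigned to each gold label token ...
    seq_token_log_probs = torch.gather(masked_log_probs, -1, labels.unsqueeze(-1))
    # ... sum over the sequence, and take the highest-scoring choice.
    seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
    prediction = int(seq_log_prob.argmax(dim=-1).item())
    print(prediction)  # index into the candidate choices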

View File

@@ -1,6 +1,6 @@
 """Manifest class."""
 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union, cast
 
 from tqdm.auto import tqdm
 
@@ -134,9 +134,9 @@ class Manifest:
             possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
         else:
             try:
-                possible_request, full_kwargs = self.client.get_choice_logit_request(
-                    prompt_str, gold_choices, kwargs
-                )
+                possible_request, full_kwargs = cast(
+                    HuggingFaceClient, self.client
+                ).get_choice_logit_request(prompt_str, gold_choices, kwargs)
             except AttributeError:
                 raise ValueError("`gold_choices` only supported for HF models.")
         if len(kwargs) > 0:
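
A minimal sketch of the cast idiom used above, with stand-in classes rather than the repo's real ones: typing.cast is a runtime no-op that narrows the static type of self.client, so mypy accepts the subclass-only method while the existing AttributeError path still catches non-HF clients at runtime.

    from typing import cast


    class Client:
        """Base client: no choice-logit support."""


    class HuggingFaceClient(Client):
        def get_choice_logit_request(self, prompt: str) -> str:
            return f"choice-logit request for {prompt!r}"


    class Manifest:
        def __init__(self, client: Client) -> None:
            self.client = client

        def run(self, prompt: str) -> str:
            try:
                # cast() only narrows the static type so mypy allows the
                # subclass-only call; a non-HF client still raises
                # AttributeError here, which we surface as ValueError.
                return cast(
                    HuggingFaceClient, self.client
                ).get_choice_logit_request(prompt)
            except AttributeError:
                raise ValueError("`gold_choices` only supported for HF models.")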