Mirror of https://github.com/HazyResearch/manifest (synced 2024-11-02 09:40:58 +00:00)
Mypy bug fix
parent 55e0be83e2
commit 5d79281c4c
@@ -47,11 +47,12 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
+        config = model.config  # type: ignore
         # Used for GPT
-        self.max_length = getattr(model.config, "max_position_embeddings", None)
+        self.max_length = getattr(config, "max_position_embeddings", None)
         if self.max_length is None:
             # Used for T0
-            self.max_length = model.config.d_model
+            self.max_length = config.d_model
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
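The hunk above binds the untyped Hugging Face config to a local `config` once, so a single `# type: ignore` covers every later attribute lookup instead of one per access. A minimal sketch of the fallback logic in isolation (the function name and `Any` typing are mine, not from the repo):

from typing import Any

def infer_max_length(config: Any) -> int:
    # Sketch: GPT-style configs (e.g. GPT-2) expose max_position_embeddings.
    max_length = getattr(config, "max_position_embeddings", None)
    if max_length is None:
        # T0/T5-style configs expose d_model instead, which the pipeline reuses here.
        max_length = config.d_model
    return int(max_length)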
@@ -341,26 +342,26 @@ class HuggingFaceModel(Model):
             tokenized_targets[k]["attention_mask"] for k in range(len(gold_choices))
         ]
         # Convert to tensors
+        tensor_features = {}
         for k in features:
-            features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
-            print(k, features[k].shape)
+            tensor_features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
         # Reduce GPU memory by feeding one at a time
         logits = [
-            self.pipeline.model(
-                input_ids=features["input_ids"][bs].unsqueeze(0),
-                attention_mask=features["attention_mask"][bs].unsqueeze(0),
-                labels=features["labels"][bs].unsqueeze(0),
+            self.pipeline.model(  # type: ignore
+                input_ids=tensor_features["input_ids"][bs].unsqueeze(0),
+                attention_mask=tensor_features["attention_mask"][bs].unsqueeze(0),
+                labels=tensor_features["labels"][bs].unsqueeze(0),
             ).logits
-            for bs in range(len(features["input_ids"]))
+            for bs in range(len(tensor_features["input_ids"]))
         ]
-        logits = torch.vstack(logits)
+        stacked_logits = torch.vstack(logits)
         # Compute most likely option
-        masked_log_probs = features["labels_attention_mask"].unsqueeze(
+        masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
             -1
-        ) * torch.log_softmax(logits, dim=-1)
+        ) * torch.log_softmax(stacked_logits, dim=-1)
         seq_token_log_probs = torch.gather(
-            masked_log_probs, -1, features["labels"].unsqueeze(-1)
+            masked_log_probs, -1, tensor_features["labels"].unsqueeze(-1)
         )
         seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
         prediction = seq_log_prob.argmax(dim=-1).item()
-        return gold_choices[prediction]
+        return gold_choices[int(prediction)]
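The renames above appear to be what satisfies mypy: a separate `tensor_features` dict avoids re-assigning `features[k]` from a list of token ids to a `torch.Tensor`, `stacked_logits` avoids re-binding `logits` from a list of tensors to a single tensor, and `int(prediction)` narrows the plain number returned by `.item()` before it is used as a list index. Put together, the scoring path reads as the compact, hedged sketch below (variable names follow the diff; the tokenization scaffolding and the model's seq2seq call signature are assumed):

import torch

def pick_gold_choice(model, tensor_features, gold_choices):
    # Sketch: score each candidate answer by the log-likelihood of its label tokens.
    # Feed one candidate at a time to limit GPU memory use.
    logits = [
        model(
            input_ids=tensor_features["input_ids"][bs].unsqueeze(0),
            attention_mask=tensor_features["attention_mask"][bs].unsqueeze(0),
            labels=tensor_features["labels"][bs].unsqueeze(0),
        ).logits
        for bs in range(len(tensor_features["input_ids"]))
    ]
    stacked_logits = torch.vstack(logits)
    # Zero out padded label positions, then gather log-probs of the gold label tokens.
    masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
        -1
    ) * torch.log_softmax(stacked_logits, dim=-1)
    seq_token_log_probs = torch.gather(
        masked_log_probs, -1, tensor_features["labels"].unsqueeze(-1)
    )
    seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
    prediction = seq_log_prob.argmax(dim=-1).item()
    return gold_choices[int(prediction)]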
@@ -1,6 +1,6 @@
 """Manifest class."""
 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union, cast

 from tqdm.auto import tqdm

@@ -134,9 +134,9 @@ class Manifest:
             possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
         else:
             try:
-                possible_request, full_kwargs = self.client.get_choice_logit_request(
-                    prompt_str, gold_choices, kwargs
-                )
+                possible_request, full_kwargs = cast(
+                    HuggingFaceClient, self.client
+                ).get_choice_logit_request(prompt_str, gold_choices, kwargs)
             except AttributeError:
                 raise ValueError("`gold_choices` only supported for HF models.")
         if len(kwargs) > 0:
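The `cast` in the last hunk only narrows the static type: `get_choice_logit_request` is defined on the Hugging Face client alone, so mypy needs the cast, while the existing `try`/`except AttributeError` still guards the runtime case. A self-contained sketch of the pattern (the `Client`/`HFClient` names are placeholders, not the repo's client hierarchy):

from typing import List, cast

class Client:
    """Placeholder base client without choice-logit support."""

class HFClient(Client):
    def get_choice_logit_request(self, prompt: str, gold_choices: List[str]) -> dict:
        return {"prompt": prompt, "gold_choices": gold_choices}

def build_request(client: Client, prompt: str, gold_choices: List[str]) -> dict:
    try:
        # cast() has no runtime effect; a non-HF client still raises
        # AttributeError here, which is turned into a clearer error.
        return cast(HFClient, client).get_choice_logit_request(prompt, gold_choices)
    except AttributeError:
        raise ValueError("`gold_choices` only supported for HF models.")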