Mypy bug fix

commit 5d79281c4c (parent 55e0be83e2)
@@ -47,11 +47,12 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
+        config = model.config  # type: ignore
         # Used for GPT
-        self.max_length = getattr(model.config, "max_position_embeddings", None)
+        self.max_length = getattr(config, "max_position_embeddings", None)
         if self.max_length is None:
             # Used for T0
-            self.max_length = model.config.d_model
+            self.max_length = config.d_model
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
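
The hunk above binds `model.config` to a local variable so the `# type: ignore` is needed only once: HuggingFace config objects are too dynamic for mypy to type their attributes, and without the local each access to `model.config` would need its own ignore. A minimal sketch of the pattern (the helper name and `Any`-typed parameter are illustrative, not from the repo):

    from typing import Any, Optional

    def max_length_of(model: Any) -> Optional[int]:
        # Bind the dynamically typed attribute once; one ignore covers all uses.
        config = model.config  # type: ignore
        max_len = getattr(config, "max_position_embeddings", None)  # GPT-style
        if max_len is None:
            max_len = getattr(config, "d_model", None)  # T0/T5-style
        return max_len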
@@ -341,26 +342,26 @@ class HuggingFaceModel(Model):
             tokenized_targets[k]["attention_mask"] for k in range(len(gold_choices))
         ]
         # Convert to tensors
+        tensor_features = {}
         for k in features:
-            features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
-            print(k, features[k].shape)
+            tensor_features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
         # Reduce GPU memory by feeding one at a time
         logits = [
-            self.pipeline.model(
-                input_ids=features["input_ids"][bs].unsqueeze(0),
-                attention_mask=features["attention_mask"][bs].unsqueeze(0),
-                labels=features["labels"][bs].unsqueeze(0),
+            self.pipeline.model(  # type: ignore
+                input_ids=tensor_features["input_ids"][bs].unsqueeze(0),
+                attention_mask=tensor_features["attention_mask"][bs].unsqueeze(0),
+                labels=tensor_features["labels"][bs].unsqueeze(0),
             ).logits
-            for bs in range(len(features["input_ids"]))
+            for bs in range(len(tensor_features["input_ids"]))
         ]
-        logits = torch.vstack(logits)
+        stacked_logits = torch.vstack(logits)
         # Compute most likely option
-        masked_log_probs = features["labels_attention_mask"].unsqueeze(
+        masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
             -1
-        ) * torch.log_softmax(logits, dim=-1)
+        ) * torch.log_softmax(stacked_logits, dim=-1)
         seq_token_log_probs = torch.gather(
-            masked_log_probs, -1, features["labels"].unsqueeze(-1)
+            masked_log_probs, -1, tensor_features["labels"].unsqueeze(-1)
         )
         seq_log_prob = seq_token_log_probs.squeeze(dim=-1).sum(dim=-1)
         prediction = seq_log_prob.argmax(dim=-1).item()
-        return gold_choices[prediction]
+        return gold_choices[int(prediction)]
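
This hunk clears three mypy complaints at once: `features` starts life as a dict of Python lists, so overwriting its values with tensors would change the dict's value type; rebinding `logits` from a list of tensors to the stacked tensor similarly changes a variable's type mid-function; and `Tensor.item()` is not typed as `int`, so it cannot index a list directly. A stray debug `print` is dropped along the way. A minimal sketch of the same three patterns on toy data (assuming torch's bundled type stubs):

    from typing import Dict, List
    import torch

    features: Dict[str, List[int]] = {"input_ids": [1, 2, 3]}

    # New dict rather than in-place overwrite: the value type changes
    # from List[int] to Tensor, so mypy needs a separate binding.
    tensor_features: Dict[str, torch.Tensor] = {
        k: torch.LongTensor(v) for k, v in features.items()
    }

    # New name rather than rebinding: `logits` is List[Tensor], while
    # the stacked result is a single Tensor.
    logits = [torch.randn(1, 4) for _ in range(3)]
    stacked_logits = torch.vstack(logits)

    # Tensor.item() returns a plain number in the stubs; wrap it in
    # int() before using it as a list index.
    scores = stacked_logits.sum(dim=-1)
    choices = ["a", "b", "c"]
    prediction = scores.argmax(dim=-1).item()
    best = choices[int(prediction)]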
@@ -1,6 +1,6 @@
 """Manifest class."""
 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union, cast
 
 from tqdm.auto import tqdm
 
@@ -134,9 +134,9 @@ class Manifest:
             possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
         else:
             try:
-                possible_request, full_kwargs = self.client.get_choice_logit_request(
-                    prompt_str, gold_choices, kwargs
-                )
+                possible_request, full_kwargs = cast(
+                    HuggingFaceClient, self.client
+                ).get_choice_logit_request(prompt_str, gold_choices, kwargs)
             except AttributeError:
                 raise ValueError("`gold_choices` only supported for HF models.")
         if len(kwargs) > 0:
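
`get_choice_logit_request` exists only on `HuggingFaceClient`, not on the base type of `self.client`, so mypy flags the call; `typing.cast` asserts the subtype with no runtime effect, and the existing `except AttributeError` still guards the case where the client really is not an HF client. A minimal sketch with simplified signatures (the real method takes the prompt, gold choices, and kwargs):

    from typing import cast

    class Client:
        """Base client: no choice-logit support."""

    class HuggingFaceClient(Client):
        def get_choice_logit_request(self, prompt: str) -> str:
            return prompt

    def choose(client: Client, prompt: str) -> str:
        # cast() only informs the type checker; if client lacks the
        # method, the AttributeError path below still fires at runtime.
        try:
            return cast(HuggingFaceClient, client).get_choice_logit_request(prompt)
        except AttributeError:
            raise ValueError("`gold_choices` only supported for HF models.")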