mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-02 09:40:42 +00:00
fix: concat
This commit is contained in:
parent
1b14b1f723
commit
985da51fbc
19
inference.py
19
inference.py
@ -11,6 +11,7 @@ import torch.distributed as dist
|
||||
from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
|
||||
from transformers import DefaultDataCollator
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
|
||||
|
||||
def calc_cross_entropy_no_reduction(lm_logits, labels):
|
||||
@ -99,16 +100,16 @@ def inference(config):
|
||||
sequence_lengths = torch.tensor(sequence_lengths)
|
||||
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
|
||||
|
||||
train_outputs["embeddings"].extend(pooled_logits)
|
||||
train_outputs["embeddings"].append(pooled_logits)
|
||||
train_outputs["index"].extend(batch["index"].to(model.device))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
dist.barrier()
|
||||
gathered_train = nested_numpify(distributed_concat(train_outputs))
|
||||
|
||||
gathered_train["index"] = [t.item() for t in gathered_train["index"]]
|
||||
gathered_train["loss"] = [t.item() for t in gathered_train["loss"]]
|
||||
gathered_train["index"] = np.concatenate(gathered_train["index"])
|
||||
gathered_train["loss"] = np.concatenate(gathered_train["loss"])
|
||||
gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])
|
||||
|
||||
df_train = Dataset.from_dict(gathered_train)
|
||||
df_train = df_train.sort("index")
|
||||
@ -146,7 +147,7 @@ def inference(config):
|
||||
sequence_lengths = torch.tensor(sequence_lengths)
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
val_outputs["embeddings"].extend(pooled_logits)
|
||||
val_outputs["embeddings"].append(pooled_logits)
|
||||
val_outputs["index"].extend(batch["index"].to(model.device))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
@ -154,8 +155,9 @@ def inference(config):
|
||||
dist.barrier()
|
||||
gathered_val = nested_numpify(distributed_concat(val_outputs))
|
||||
|
||||
gathered_val["index"] = [t.item() for t in gathered_val["index"]]
|
||||
gathered_val["loss"] = [t.item() for t in gathered_val["loss"]]
|
||||
gathered_val["index"] = np.concatenate(gathered_val["index"])
|
||||
gathered_val["loss"] = np.concatenate(gathered_val["loss"])
|
||||
gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])
|
||||
|
||||
df_val = Dataset.from_dict(gathered_val)
|
||||
df_val = df_val.sort("index")
|
||||
@ -165,7 +167,8 @@ def inference(config):
|
||||
val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))
|
||||
|
||||
df = concatenate_datasets([train_dataset, val_dataset])
|
||||
df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
|
||||
if local_rank == 0:
|
||||
df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
Reference in New Issue
Block a user