fix: inference save shards

This commit is contained in:
Zach Nussbaum 2023-04-07 16:23:34 +00:00
parent 573272ad69
commit f974ca651c

View File

@ -6,11 +6,11 @@ from read import read_config
from accelerate.utils import set_seed
from data import load_data_for_inference
from tqdm import tqdm
from datasets import concatenate_datasets, Dataset
from datasets import Dataset
import torch.distributed as dist
from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
@ -46,7 +46,7 @@ def inference(config):
num_processes = dist.get_world_size()
local_rank = dist.get_rank()
train_sampler = ShardSampler(train_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
train_dataloader = DataLoader(
train_dataset,
collate_fn=DefaultDataCollator(),
@ -55,7 +55,7 @@ def inference(config):
drop_last=True
)
val_sampler = ShardSampler(val_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
@ -69,7 +69,6 @@ def inference(config):
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
model.to(f"cuda:{local_rank}")
with torch.no_grad():
@ -107,17 +106,23 @@ def inference(config):
torch.cuda.empty_cache()
dist.barrier()
gathered_train = nested_numpify(distributed_concat(train_outputs))
gathered_train["index"] = np.concatenate(gathered_train["index"])
gathered_train["loss"] = np.concatenate(gathered_train["loss"])
gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])
train_outputs = nested_numpify(train_outputs)
# stack since they're 0-dim arrays
train_outputs["index"] = np.stack(train_outputs["index"])
train_outputs["loss"] = np.stack(train_outputs["loss"])
train_outputs["embeddings"] = np.concatenate(train_outputs["embeddings"])
df_train = Dataset.from_dict(gathered_train)
df_train = Dataset.from_dict(train_outputs)
df_train = df_train.sort("index")
train_dataset = train_dataset.add_column("embeddings", df_train["embeddings"])
train_dataset = train_dataset.add_column("loss", df_train["loss"])
train_dataset = train_dataset.add_column("is_train", [True] * len(train_dataset))
curr_idx = df_train["index"]
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
val_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(val_dataloader, disable=local_rank != 0):
@ -153,25 +158,24 @@ def inference(config):
torch.cuda.empty_cache()
dist.barrier()
gathered_val = nested_numpify(distributed_concat(val_outputs))
val_outputs = nested_numpify(val_outputs)
val_outputs["index"] = np.stack(val_outputs["index"])
val_outputs["loss"] = np.stack(val_outputs["loss"])
val_outputs["embeddings"] = np.concatenate(val_outputs["embeddings"])
gathered_val["index"] = np.concatenate(gathered_val["index"])
gathered_val["loss"] = np.concatenate(gathered_val["loss"])
gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])
df_val = Dataset.from_dict(gathered_val)
df_val = Dataset.from_dict(val_outputs)
df_val = df_val.sort("index")
curr_idx = df_val["index"]
val_dataset = val_dataset.add_column("embeddings", df_val["embeddings"])
val_dataset = val_dataset.add_column("loss", df_val["loss"])
val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
df = concatenate_datasets([train_dataset, val_dataset])
if local_rank == 0:
df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"])
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
def main():
dist.init_process_group("nccl")
parser = ArgumentParser()