2023-04-07 01:40:39 +00:00
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
from read import read_config
|
|
|
|
from accelerate.utils import set_seed
|
|
|
|
from data import load_data_for_inference
|
|
|
|
from tqdm import tqdm
|
2023-04-07 16:23:34 +00:00
|
|
|
from datasets import Dataset
|
2023-04-07 01:40:39 +00:00
|
|
|
import torch.distributed as dist
|
2023-04-07 16:23:34 +00:00
|
|
|
from transformers.trainer_pt_utils import nested_numpify
|
2023-04-07 01:40:39 +00:00
|
|
|
from transformers import DefaultDataCollator
|
2023-04-07 16:23:34 +00:00
|
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
2023-04-07 04:33:34 +00:00
|
|
|
import numpy as np
|
2023-04-07 19:04:19 +00:00
|
|
|
import pyarrow as pa
|
|
|
|
from pyarrow import compute as pc
|
2023-04-07 01:40:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def calc_cross_entropy_no_reduction(lm_logits, labels, ignore_index=-100):
    """Return the mean next-token cross-entropy of every sequence in a batch.

    Logits at position t are scored against the label at position t + 1 (the
    usual causal-LM shift). Positions whose label equals ``ignore_index`` are
    excluded from both the sum and the token count of the per-sequence mean,
    so masked / padded positions do not dilute a sequence's loss.

    Args:
        lm_logits: tensor of shape (batch, seq_len, vocab).
        labels: integer tensor of shape (batch, seq_len).
        ignore_index: label value to skip; defaults to the
            ``nn.CrossEntropyLoss`` default of -100.

    Returns:
        float32 tensor of shape (batch,). A sequence with no scored tokens
        gets a loss of 0.
    """
    shift_logits = lm_logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
    # CrossEntropyLoss wants the class dim second: (batch, vocab, seq_len - 1).
    # Ignored positions come back as exactly 0.
    token_loss = loss_fct(shift_logits.transpose(1, 2), shift_labels)

    # Average only over scored tokens; clamp so an all-ignored row yields 0, not NaN.
    n_scored = (shift_labels != ignore_index).sum(dim=1).clamp(min=1)
    # Reduce in float32 so low-precision (e.g. bfloat16) per-token losses sum accurately.
    return token_loss.float().sum(dim=1) / n_scored
|
|
|
|
|
|
|
|
|
|
|
|
def rank0_print(msg):
    """Print ``msg`` from the rank-0 process only.

    When torch.distributed is unavailable or no process group has been
    initialised (e.g. single-process debugging), there is only one process,
    so the message is printed instead of ``dist.get_rank()`` raising.
    """
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        print(msg)
|
|
|
|
|
|
|
|
|
|
|
|
def _pooled_token_positions(input_ids, pad_token_id):
    """Return, for each row of ``input_ids``, the position whose hidden state is pooled.

    Multi-turn samples contain several ``<|endoftext|>`` tokens, and the pad token
    may be the EOS token, so the first pad is not necessarily where padding starts.
    Instead this locates the start of the trailing run of pad tokens:

    * the row ends in padding -> index of the first pad of that trailing run
      (a row made entirely of pads maps to 0);
    * the row does not end in a pad (it fills the max length) -> the last index.

    Args:
        input_ids: integer tensor of shape (batch, seq_len).
        pad_token_id: token id treated as padding.

    Returns:
        long tensor of shape (batch,) on ``input_ids.device``.
    """
    seq_len = input_ids.shape[-1]
    # (position + 1) at real tokens and 0 at pads; the row max is one past the last real token.
    one_past_last_real = (
        (input_ids != pad_token_id) * torch.arange(1, seq_len + 1, device=input_ids.device)
    ).max(dim=-1).values
    # No trailing padding means one_past_last_real == seq_len: fall back to the final position.
    return one_past_last_real.clamp(max=seq_len - 1)


def _make_dataloader(dataset, batch_size, num_replicas, rank):
    """Build a non-shuffling DataLoader over this rank's shard of ``dataset``.

    Both the sampler and the loader use ``drop_last=True``, so a few trailing
    examples may be skipped; only processed rows end up in the written files.
    """
    sampler = DistributedSampler(
        dataset, shuffle=False, drop_last=True, num_replicas=num_replicas, rank=rank
    )
    return DataLoader(
        dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,
    )


def _embed_split(model, dataloader, pad_token_id, device, show_progress):
    """Run ``model`` over ``dataloader`` and collect per-example outputs.

    Returns:
        dict of numpy arrays in dataloader order: ``"index"`` (n,), ``"loss"``
        (n,) float32 and ``"embeddings"`` (n, hidden) float32, where an
        embedding is the last-layer hidden state at the pooled position.

    Raises:
        RuntimeError: if the dataloader yields no batches.
    """
    losses, embeddings, indices = [], [], []
    for batch in tqdm(dataloader, disable=not show_progress):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        # Labels are deliberately not passed to the model: the unreduced loss is
        # computed below, so the model's own reduced loss would be wasted work.
        outputs = model(input_ids=input_ids, output_hidden_states=True)
        losses.append(calc_cross_entropy_no_reduction(outputs.logits, labels).float().cpu())

        last_hidden = outputs.hidden_states[-1]
        pooled_at = _pooled_token_positions(input_ids, pad_token_id).to(last_hidden.device)
        rows = torch.arange(input_ids.shape[0], device=last_hidden.device)
        # The model is loaded in bfloat16, which numpy cannot represent: upcast first.
        embeddings.append(last_hidden[rows, pooled_at].float().cpu())
        indices.append(batch["index"].cpu())

        del outputs, last_hidden
        torch.cuda.empty_cache()

    if not indices:
        raise RuntimeError(
            "dataloader produced no batches; the split is smaller than world_size * batch_size"
        )
    return {
        "index": torch.cat(indices).numpy(),
        "loss": torch.cat(losses).numpy(),
        "embeddings": torch.cat(embeddings).numpy(),
    }


def _write_split(dataset, outputs, is_train, path):
    """Attach model outputs to their source rows and write them as JSON lines.

    Rows of ``dataset`` whose ``"index"`` value appears in ``outputs["index"]``
    are kept (rows dropped by the sampler/loader are skipped). Each kept row
    receives the ``embeddings`` and ``loss`` computed for the same index value.
    Rows are matched by value, not by position, so the result does not depend
    on the underlying Arrow table's row order matching the dataloader order.

    Raises:
        ValueError: if the number of matched rows differs from the number of
            outputs (i.e. ``index`` values are not unique).
    """
    out_index = outputs["index"]

    # compute mask in pyarrow since it's super fast
    # ty @bmschmidt for showing me this!
    table = dataset.data
    mask = pc.is_in(table["index"], value_set=pa.array(out_index.tolist(), pa.int32()))
    # convert from pyarrow to Dataset
    kept = Dataset.from_dict(table.filter(mask).to_pydict())
    if len(kept) != len(out_index):
        raise ValueError(
            f"matched {len(kept)} rows for {len(out_index)} outputs; 'index' values must be unique"
        )

    # For each kept row, find the output position that holds the same index value.
    order = np.argsort(out_index, kind="stable")
    positions = order[np.searchsorted(out_index, np.asarray(kept["index"]), sorter=order)]

    kept = kept.add_column("embeddings", outputs["embeddings"][positions].tolist())
    kept = kept.add_column("loss", outputs["loss"][positions].tolist())
    kept = kept.add_column("is_train", [is_train] * len(kept))
    kept.to_json(path, lines=True, orient="records", num_proc=64)


def inference(config):
    """Compute per-example loss and pooled last-layer embeddings for the train and val splits.

    Each rank processes its DistributedSampler shard of both splits and writes
    ``inference/epoch_2_embeddings_train_shard_{rank}.jsonl`` and
    ``inference/epoch_2_embeddings_val_shard_{rank}.jsonl``. Each file holds the
    original columns plus ``embeddings``, ``loss`` and ``is_train``.

    Expects ``dist.init_process_group`` to have been called. Config keys read
    here: ``seed``, ``tokenizer_name``, ``max_length``, ``model_name`` and
    ``batch_size`` (plus whatever ``load_data_for_inference`` reads).
    """
    import os

    set_seed(config["seed"])

    rank0_print(f"World size: {dist.get_world_size()}")

    tokenizer = AutoTokenizer.from_pretrained(
        config["tokenizer_name"], model_max_length=config["max_length"]
    )
    # llama has no pad token; reuse EOS as the pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset, val_dataset = load_data_for_inference(config, tokenizer)

    num_processes = dist.get_world_size()
    rank = dist.get_rank()
    # dist.get_rank() is the *global* rank. The GPU index must be node-local,
    # otherwise ranks on a second node would request devices that don't exist.
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        local_rank = int(env_local_rank)
    else:
        local_rank = rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{local_rank}")

    train_dataloader = _make_dataloader(train_dataset, config["batch_size"], num_processes, rank)
    val_dataloader = _make_dataloader(val_dataset, config["batch_size"], num_processes, rank)

    model = AutoModelForCausalLM.from_pretrained(
        config["model_name"],
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model.to(device)
    model.eval()

    # Shard files are keyed by the global rank so names are unique across nodes.
    splits = (
        (train_dataset, train_dataloader, True, f"inference/epoch_2_embeddings_train_shard_{rank}.jsonl"),
        (val_dataset, val_dataloader, False, f"inference/epoch_2_embeddings_val_shard_{rank}.jsonl"),
    )
    with torch.no_grad():
        for dataset, dataloader, is_train, path in splits:
            outputs = _embed_split(
                model, dataloader, tokenizer.pad_token_id, device, show_progress=(rank == 0)
            )
            _write_split(dataset, outputs, is_train, path)
|
2023-04-07 01:40:39 +00:00
|
|
|
|
2023-04-07 16:23:34 +00:00
|
|
|
|
2023-04-07 01:40:39 +00:00
|
|
|
def main():
    """Parse ``--config``, set up the NCCL process group and run inference.

    Arguments and config are read before the process group is created, so
    ``--help`` or a bad flag exits immediately instead of entering rendezvous.
    The process group is torn down even if inference raises.
    """
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")
    args = parser.parse_args()

    config = read_config(args.config)

    dist.init_process_group("nccl")
    try:
        inference(config)
    finally:
        dist.destroy_process_group()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() parses the command line (including the config path).
    main()
|
|
|
|
|