EVAL/utils.py

93 lines
2.8 KiB
Python
Raw Normal View History

2023-03-17 15:55:15 +00:00
import os
import random
import uuid
2023-04-03 07:43:34 +00:00
import numpy as np
2023-03-17 15:55:15 +00:00
# Ensure the working directories exist. They are created relative to the
# current working directory at import time; existing ones are left untouched.
for _workdir in ("image", "audio", "video", "dataframe", "playground"):
    os.makedirs(_workdir, exist_ok=True)
del _workdir  # keep the loop variable out of the module namespace
2023-03-17 15:55:15 +00:00
def seed_everything(seed):
    """Seed Python's, NumPy's and, if installed, PyTorch's RNGs with *seed*.

    Args:
        seed: Integer seed passed unchanged to every generator.

    Returns:
        The same ``seed``, so the call can be used inline.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch is optional: skip it when it is not installed. Only the
    # missing-module case is tolerated; the previous bare ``except:`` also
    # hid KeyboardInterrupt/SystemExit and any genuine seeding failure.
    try:
        import torch
    except ImportError:
        return seed
    torch.manual_seed(seed)
    # Per the PyTorch docs this is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    return seed
def prompts(name, description):
    """Return a decorator that tags a function with ``name`` and ``description``.

    The decorated function is returned as-is (it is not wrapped); the
    decorator only sets the two attributes on it.
    """
    def _tag(target):
        target.name, target.description = name, description
        return target

    return _tag
def cut_dialogue_history(history_memory, keep_last_n_words=500):
    """Trim a dialogue history so it holds fewer than *keep_last_n_words* tokens.

    Tokens are whitespace-separated words (``str.split()`` with no argument).
    If the history already has fewer than ``keep_last_n_words`` tokens it is
    returned unchanged. Otherwise whole lines are dropped from the front
    until the remaining lines hold fewer than ``keep_last_n_words`` tokens,
    and the remainder is returned prefixed with a single newline.

    Args:
        history_memory: The full history text.
        keep_last_n_words: Exclusive upper bound on the number of tokens kept.
            A value <= 0 drops every line.

    Returns:
        ``history_memory`` itself, or a newline followed by the kept lines.
    """
    n_tokens = len(history_memory.split())
    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
    if n_tokens < keep_last_n_words:
        return history_memory

    paragraphs = history_memory.split("\n")
    remaining = n_tokens
    start = 0
    # Count each line with the same str.split() used for n_tokens, so the
    # per-line counts sum exactly to n_tokens. The previous split(" ")
    # counted blank lines and repeated spaces as tokens and missed tab-
    # separated words, which could over-trim or index past the last line.
    # Advancing an index (instead of re-slicing) keeps the loop linear.
    while start < len(paragraphs) and remaining >= keep_last_n_words:
        remaining -= len(paragraphs[start].split())
        start += 1
    return "\n" + "\n".join(paragraphs[start:])
def _derive_file_name(org_name, func_name, extension):
    """Build a new path next to *org_name* named ``{uuid}_{func}_{prev}_{orig}{ext}``.

    ``uuid`` is the first 4 characters of a fresh ``uuid4`` string, ``func``
    is *func_name*, ``prev`` is the first ``_``-separated field of the input
    stem and ``orig`` is its last field.

    Args:
        org_name: Existing file path; only its directory and stem are used.
        func_name: Label inserted as the second field of the new name.
        extension: Suffix (including the dot) for the new file name.

    Returns:
        ``os.path.join(<directory of org_name>, <new file name>)``.

    Raises:
        ValueError: If the stem has 2 or 3 ``_``-separated fields, which
            matches neither a plain name nor a derived name.
    """
    head, tail = os.path.split(org_name)
    # splitext strips only the final extension, so dots inside the stem are
    # preserved (the old ``split(".")[0]`` truncated at the first dot).
    name_split = os.path.splitext(tail)[0].split("_")
    this_new_uuid = str(uuid.uuid4())[0:4]
    if len(name_split) == 1:
        # A stem without underscores is both the previous and original name.
        recent_prev_file_name = most_org_file_name = name_split[0]
    elif len(name_split) >= 4:
        # Derived-name layout: first field = that file's uuid, last field =
        # original stem. Indexing [-1] (not [3]) stays correct when an earlier
        # func_name itself contained underscores.
        recent_prev_file_name = name_split[0]
        most_org_file_name = name_split[-1]
    else:
        # Explicit error instead of ``assert``, which is stripped under -O.
        raise ValueError(
            f"cannot derive a new name from {org_name!r}: expected 1 or at "
            f"least 4 '_'-separated fields, got {len(name_split)}"
        )
    new_file_name = (
        f"{this_new_uuid}_{func_name}_"
        f"{recent_prev_file_name}_{most_org_file_name}{extension}"
    )
    return os.path.join(head, new_file_name)


def get_new_image_name(org_img_name, func_name="update"):
    """Return a fresh ``.png`` path in the same directory as *org_img_name*.

    The name follows ``{uuid}_{func_name}_{prev}_{orig}.png``; see
    ``_derive_file_name`` for how ``prev`` and ``orig`` are chosen.

    Raises:
        ValueError: If the input stem has 2 or 3 ``_``-separated fields.
    """
    return _derive_file_name(org_img_name, func_name, ".png")


def get_new_dataframe_name(org_img_name, func_name="update"):
    """Return a fresh ``.csv`` path in the same directory as *org_img_name*.

    The name follows ``{uuid}_{func_name}_{prev}_{orig}.csv``; see
    ``_derive_file_name`` for how ``prev`` and ``orig`` are chosen.

    Raises:
        ValueError: If the input stem has 2 or 3 ``_``-separated fields.
    """
    return _derive_file_name(org_img_name, func_name, ".csv")