mirror of
https://github.com/corca-ai/EVAL
synced 2024-10-30 09:20:44 +00:00
93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
import os
|
|
import random
|
|
import uuid
|
|
|
|
import numpy as np
|
|
|
|
# Make sure every working directory exists (relative to the current working
# directory) as soon as this module is imported. exist_ok=True turns the call
# into a no-op when the directory is already present.
for _directory in ("image", "audio", "video", "dataframe", "playground"):
    os.makedirs(_directory, exist_ok=True)
del _directory
|
|
|
|
|
|
def seed_everything(seed):
    """Seed the Python, NumPy and (when installed) PyTorch random generators.

    Args:
        seed: Seed value passed unchanged to ``random.seed``,
            ``np.random.seed`` and, if PyTorch can be imported,
            ``torch.manual_seed`` / ``torch.cuda.manual_seed_all``.

    Returns:
        The ``seed`` that was passed in.
    """
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # PyTorch is an optional dependency: without it only the Python and
        # NumPy generators are seeded. Only the import is guarded so that
        # genuine seeding errors are no longer silently swallowed.
        return seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
|
|
|
|
|
|
def prompts(name, description):
    """Build a decorator that tags a callable with a name and a description.

    Args:
        name: Value stored on the decorated callable as its ``name`` attribute.
        description: Value stored as its ``description`` attribute.

    Returns:
        A decorator that sets both attributes on the callable it receives and
        returns that same callable (it is not wrapped).
    """

    def attach_metadata(target):
        setattr(target, "name", name)
        setattr(target, "description", description)
        return target

    return attach_metadata
|
|
|
|
|
|
def cut_dialogue_history(history_memory, keep_last_n_words=500):
    """Trim the oldest paragraphs of a dialogue history to fit a word budget.

    Words are whitespace-separated tokens. If the history has fewer than
    ``keep_last_n_words`` words it is returned unchanged. Otherwise whole
    paragraphs (lines separated by ``"\\n"``) are dropped from the start until
    the remaining word count is below the budget or nothing is left.

    Args:
        history_memory: The full dialogue history as a single string.
        keep_last_n_words: Word budget; trimming continues while the remaining
            word count is greater than or equal to this value.

    Returns:
        ``history_memory`` itself when it is under the budget, otherwise a
        string starting with ``"\\n"`` followed by the surviving paragraphs
        joined with ``"\\n"``.
    """
    n_tokens = len(history_memory.split())
    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
    if n_tokens < keep_last_n_words:
        return history_memory

    paragraphs = history_memory.split("\n")
    remaining_tokens = n_tokens
    dropped = 0
    # Count each paragraph with the same whitespace tokenisation used for the
    # total, so the running count matches what is actually left. The bounds
    # check stops the loop once every paragraph has been dropped instead of
    # indexing past the end of the list.
    while dropped < len(paragraphs) and remaining_tokens >= keep_last_n_words:
        remaining_tokens -= len(paragraphs[dropped].split())
        dropped += 1
    return "\n" + "\n".join(paragraphs[dropped:])
|
|
|
|
|
|
def get_new_image_name(org_img_name, func_name="update"):
    """Derive a fresh ``.png`` file name from an existing image path.

    The result is ``<id>_<func_name>_<previous>_<original>.png`` placed in the
    same directory as ``org_img_name``, where ``<id>`` is the first four
    characters of a new UUID4 and ``<previous>`` is the first ``_``-separated
    part of the input's base name (the text before its first ``.``). For a
    base name with no underscore ``<original>`` equals ``<previous>``; for a
    four-part base name it is the fourth part.

    Args:
        org_img_name: Path of the source image.
        func_name: Label embedded as the second part of the new name.

    Returns:
        The new path, built with ``os.path.join``.

    Raises:
        AssertionError: If the base name has underscores but does not consist
            of exactly four parts.
    """
    directory, filename = os.path.split(org_img_name)
    stem_parts = filename.split(".")[0].split("_")
    short_id = str(uuid.uuid4())[0:4]
    if len(stem_parts) != 1:
        assert len(stem_parts) == 4
        original_name = stem_parts[3]
    else:
        original_name = stem_parts[0]
    previous_name = stem_parts[0]
    new_filename = f"{short_id}_{func_name}_{previous_name}_{original_name}.png"
    return os.path.join(directory, new_filename)
|
|
|
|
|
|
def get_new_dataframe_name(org_img_name, func_name="update"):
    """Derive a fresh ``.csv`` file name from an existing file path.

    The new name has the form ``<id>_<func_name>_<previous>_<original>.csv``
    and stays in the directory of ``org_img_name``. ``<id>`` is the first four
    characters of a newly generated UUID4. ``<previous>`` is the first
    ``_``-separated part of the input's base name (everything before its first
    ``.``). ``<original>`` is that same part when the base name has no
    underscore, or the fourth part when it has exactly four parts.

    Args:
        org_img_name: Path of the source file.
        func_name: Label embedded as the second part of the new name.

    Returns:
        The new path, built with ``os.path.join``.

    Raises:
        AssertionError: If the base name has underscores but does not consist
            of exactly four parts.
    """
    folder, base_name = os.path.split(org_img_name)
    pieces = base_name.split(".")[0].split("_")
    token = str(uuid.uuid4())[0:4]
    is_plain_name = len(pieces) == 1
    if not is_plain_name:
        assert len(pieces) == 4
    root_name = pieces[0] if is_plain_name else pieces[3]
    name_fields = (token, func_name, pieces[0], root_name)
    return os.path.join(folder, "{}_{}_{}_{}.csv".format(*name_fields))
|