imaginAIry/imaginairy/utils/prompt_schedules.py

77 lines
2.3 KiB
Python

"""Functions for managing prompt schedules"""
import csv
import re
from copy import copy
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import frange
_SCHEDULE_PATTERN = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. -]+)\]")


def parse_schedule_str(schedule_str):
    """
    Parse a single kwarg schedule string into an argument name and its values.

    The accepted form is ``name[values]``. ``values`` takes one of two forms:

    - A range ``start:end:step``, expanded with ``frange``.
    - A comma-separated list whose items become floats when possible and stay
      strings otherwise.

    Dashes in ``name`` are converted to underscores, so ``prompt-strength``
    becomes ``prompt_strength``.

    Args:
        schedule_str: e.g. ``"prompt-strength[0:1:0.1]"`` or ``"name[a,b,c]"``.

    Returns:
        A ``(arg_name, arg_values)`` tuple.

    Raises:
        ValueError: if the string is malformed, if the name is not an
            attribute of ``ImaginePrompt``, or if a range is not exactly three
            numbers with a non-zero step.
    """
    # fullmatch rejects trailing text after the closing bracket instead of
    # silently ignoring it. strip() keeps surrounding whitespace and trailing
    # newlines harmless.
    match = _SCHEDULE_PATTERN.fullmatch(schedule_str.strip())
    if not match:
        msg = f"Invalid kwarg schedule: {schedule_str}"
        raise ValueError(msg)

    arg_name = match.group(1).replace("-", "_")
    if not hasattr(ImaginePrompt(), arg_name):
        msg = f"Invalid kwarg schedule. Not a valid argument name: {arg_name}"
        raise ValueError(msg)

    raw_values = match.group(2)
    if ":" not in raw_values:
        return arg_name, parse_csv_line(raw_values)

    parts = raw_values.split(":")
    if len(parts) != 3:
        msg = f"Invalid kwarg schedule range (expected start:end:step): {schedule_str}"
        raise ValueError(msg)
    try:
        start, end, step = (float(part) for part in parts)
    except ValueError as err:
        msg = f"Invalid kwarg schedule range values: {schedule_str}"
        raise ValueError(msg) from err
    if step == 0:
        # A zero step can never advance from start towards end.
        msg = f"Invalid kwarg schedule range (step must be non-zero): {schedule_str}"
        raise ValueError(msg)
    return arg_name, list(frange(start, end, step))
def parse_schedule_strs(schedule_strs):
    """
    Parse several kwarg schedule strings and check that they line up.

    Args:
        schedule_strs: iterable of strings in the form accepted by
            ``parse_schedule_str``.

    Returns:
        dict mapping argument name to its list of scheduled values. The dict
        is empty when no strings are given.

    Raises:
        ValueError: if any string is invalid, if the same argument is
            scheduled more than once, or if the schedules differ in length.
    """
    schedules = {}
    for schedule_str in schedule_strs:
        arg_name, arg_values = parse_schedule_str(schedule_str)
        # Reject repeats rather than letting a later schedule silently
        # replace an earlier one for the same argument.
        if arg_name in schedules:
            msg = f"Duplicate kwarg schedule for argument: {arg_name}"
            raise ValueError(msg)
        schedules[arg_name] = arg_values

    # The schedules are stepped through together, one value from each per
    # step, so they must all have the same number of values.
    if len({len(values) for values in schedules.values()}) > 1:
        raise ValueError("All schedules must have the same length")
    return schedules
def prompt_mutator(prompt, schedules):
    """
    Given a prompt and a mapping of kwarg schedules, yield one prompt per schedule step.

    For step ``i``, a shallow copy of ``prompt`` gets every attribute named
    in ``schedules`` set to ``schedules[name][i]``. ``validate()`` is then
    called on the copy before it is yielded. Attributes are set on the copy,
    not on ``prompt`` itself.

    kwarg_schedules example:
    {
        "prompt_strength": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    }

    Args:
        prompt: object to copy that exposes a ``validate()`` method. This is
            presumably an ``ImaginePrompt``; confirm against callers.
        schedules: mapping of attribute name to a sequence of values, e.g. as
            returned by ``parse_schedule_strs``.

    Yields:
        The mutated copies, in schedule order. Nothing is yielded for an
        empty mapping.

    Raises:
        ValueError: if the schedules have different lengths. Because this is
            a generator, the error surfaces on the first iteration.
    """
    if not schedules:
        # No schedules means no steps to produce.
        return

    attr_names = list(schedules)
    columns = list(schedules.values())
    if len({len(column) for column in columns}) > 1:
        raise ValueError("All schedules must have the same length")

    # zip(*columns) yields one tuple per step: the i-th value of every schedule.
    for step_values in zip(*columns):
        new_prompt = copy(prompt)
        for attr_name, value in zip(attr_names, step_values):
            setattr(new_prompt, attr_name, value)
        new_prompt.validate()
        yield new_prompt
def parse_csv_line(line):
    """
    Split one comma-separated line into values, converting numeric items to float.

    Parsing uses the ``csv`` module, so double-quoted items may contain
    commas. Whitespace right after each comma is skipped, so ``"cat, dog"``
    yields ``["cat", "dog"]`` rather than ``["cat", " dog"]``. Items that
    ``float()`` accepts become floats; this includes spellings such as
    ``inf`` and ``nan``. All other items are kept as strings.

    Args:
        line: a single line of text, with no embedded newlines expected.

    Returns:
        list of float/str values. An empty line gives ``[]``.
    """
    reader = csv.reader([line], skipinitialspace=True)
    row = next(reader, [])
    values = []
    for item in row:
        try:
            values.append(float(item))
        except ValueError:
            values.append(item)
    return values