reflexion-human-eval/webshop_runs/webshop_trial.py
2023-05-21 15:35:36 +02:00

312 lines
12 KiB
Python

import os
import sys
import openai
import requests
from bs4 import BeautifulSoup
from bs4.element import Comment
from env_history import EnvironmentHistory
from typing import Any, Dict, List, Tuple
openai.api_key = os.environ["OPENAI_API_KEY"]
# Address of the WebShop server. Set the WEBSHOP_URL environment variable to
# target a different deployment; the default keeps the original host.
WEBSHOP_URL = os.environ.get("WEBSHOP_URL", "http://3.83.245.205:3000")
# Item-page sub-page buttons. webshopEnv.step only tests membership in the
# keys; the values presumably name the server-side templates (unused here).
ACTION_TO_TEMPLATE = {
    'Description': 'description_page.html',
    'Features': 'features_page.html',
    'Reviews': 'review_page.html',
    'Attributes': 'attributes_page.html',
}
# Few-shot prompt prefix, resolved relative to the current working directory.
# Read explicitly as UTF-8 so the result does not depend on the platform locale.
with open("./base_prompt.txt", 'r', encoding='utf-8') as f:
    BASE_PROMPT = f.read()
def llm(prompt, stop=None):
    """Complete *prompt* with text-davinci-002, resampling near-empty output.

    Up to six attempts are made. The sampling temperature grows by 0.2 per
    attempt (0.0, 0.2, ..., 1.0), so a deterministic too-short completion is
    not simply repeated. An attempt is accepted when its text has at least
    five characters after stripping surrounding whitespace.

    Args:
        prompt: Full prompt text sent to the Completion API.
        stop: Stop sequences for the API; defaults to a single newline so
            that one line of output is produced.

    Returns:
        The raw completion text of the first accepted attempt, or "" if all
        six attempts were too short.

    Any exception raised while calling the API is reported (the prompt on
    stdout, the error on stderr) and the process exits with status 1.
    """
    if stop is None:
        stop = ["\n"]
    try:
        for cur_try in range(6):
            response = openai.Completion.create(
                model="text-davinci-002",
                prompt=prompt,
                temperature=cur_try * 0.2,
                max_tokens=100,
                top_p=1,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stop=stop,
            )
            text = response["choices"][0]["text"]
            # Heuristic: very short completions are treated as failures.
            if len(text.strip()) >= 5:
                return text
        return ""
    except Exception as e:
        print(prompt)
        # Surface the actual failure instead of exiting silently.
        print(f"llm: completion request failed: {e!r}", file=sys.stderr)
        sys.exit(1)
def clean_str(p):
    """Expand backslash escape sequences (e.g. a literal ``\\n``) in *p*.

    ``unicode_escape`` decodes bytes >= 0x80 as Latin-1, so re-encoding the
    result as Latin-1 and decoding it as UTF-8 restores the original
    non-ASCII characters.

    If *p* contains a backslash sequence that cannot be decoded (a trailing
    backslash, a truncated ``\\x`` escape, ...), or an escape that expands
    beyond Latin-1, *p* is returned unchanged. In that case the caller
    still gets the page text rather than an exception.
    """
    try:
        return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
    except UnicodeError:
        return p
def tag_visible(element):
    """Tell whether a parsed text node would actually be shown on the page.

    A node is hidden when its direct parent is one of style, script, head,
    title, meta or the '[document]' pseudo-tag, or when the node itself is
    an HTML comment. The parent check is performed first.
    """
    hidden_parents = frozenset(('style', 'script', 'head', 'title', 'meta', '[document]'))
    if element.parent.name in hidden_parents:
        return False
    return not isinstance(element, Comment)
def _webshop_url(session, page_type, query_string, page_num, asin, options, subpage):
    """Build the WebShop server URL for *page_type*; raise ValueError if unknown."""
    if page_type == 'init':
        return f'{WEBSHOP_URL}/{session}'
    if page_type == 'search':
        return f'{WEBSHOP_URL}/search_results/{session}/{query_string}/{page_num}'
    if page_type == 'item':
        return (
            f'{WEBSHOP_URL}/item_page/{session}/'
            f'{asin}/{query_string}/{page_num}/{options}'
        )
    if page_type == 'item_sub':
        return (
            f'{WEBSHOP_URL}/item_sub_page/{session}/'
            f'{asin}/{query_string}/{page_num}/{subpage}/{options}'
        )
    if page_type == 'end':
        return f'{WEBSHOP_URL}/done/{session}/{asin}/{options}'
    raise ValueError(f'unknown page_type: {page_type!r}')


def _render_observation(visible_texts, page_type, url):
    """Flatten visible text nodes into (observation, option_types, asins)."""
    observation = ''
    option_type = ''   # most recent plain-text node; labels options that follow
    option_types = {}  # option label -> option_type seen just before it
    asins = []
    cnt = 0            # plain-text nodes seen so far
    prod_cnt = 0       # product links seen so far
    just_prod = 0      # nodes seen since the last product link
    for t in visible_texts:
        if t == '\n':
            continue
        if t.replace('\n', '').replace('\\n', '').replace(' ', '') == '':
            continue
        if t.parent.name == 'button':
            processed_t = f'\n[{t}] '
        elif t.parent.name == 'label':
            # The options dict is formatted into the URL via its repr, so a
            # selected option appears there as a quoted string.
            processed_t = f'[[{t}]]' if f"'{t}'" in url else f'[{t}]'
            option_types[str(t)] = option_type
        elif t.parent.get('class') == ["product-link"]:
            # Only the first three products are shown; all ASINs are recorded.
            processed_t = '' if prod_cnt >= 3 else f'\n[{t}] '
            prod_cnt += 1
            asins.append(str(t))
            just_prod = 0
        else:
            processed_t = '\n' + str(t) + ' '
            # Drop the first two plain-text nodes outside the init page
            # (presumably page header text — TODO confirm).
            if cnt < 2 and page_type != 'init':
                processed_t = ''
            # Once past the third product, also hide the two nodes that
            # follow each product link (presumably that product's details).
            if just_prod <= 2 and prod_cnt >= 4:
                processed_t = ''
            option_type = str(t)
            cnt += 1
        just_prod += 1
        observation += processed_t
    return observation, option_types, asins


def webshop_text(session, page_type, query_string='', page_num=1, asin='', options=None,
                 subpage='', timeout=60, **kwargs):
    """Fetch one WebShop page and render it as a text observation.

    Args:
        session: Session id embedded in the URL.
        page_type: One of 'init', 'search', 'item', 'item_sub', 'end'.
        query_string, page_num, asin, subpage: URL components for the page.
        options: Selected options; formatted into the URL via its repr.
            Defaults to an empty dict.
        timeout: Seconds passed to ``requests.get``. A stalled server then
            raises ``requests.Timeout`` instead of hanging forever.
        **kwargs: Extra session-state keys (e.g. 'asins', 'option_types',
            'reward') that are accepted and ignored.

    Returns:
        (observation, info). observation is the rendered page text passed
        through ``clean_str``. info may contain:
        - 'option_types': option label -> preceding plain-text node;
        - 'asins': product ids listed on the page;
        - 'reward': float score shown on the final page.

    Raises:
        ValueError: if *page_type* is not recognised.
    """
    if options is None:
        options = {}
    url = _webshop_url(session, page_type, query_string, page_num, asin, options, subpage)
    html = requests.get(url, timeout=timeout).text
    html_obj = BeautifulSoup(html, 'html.parser')
    visible_texts = [t for t in html_obj.find_all(string=True) if tag_visible(t)]
    observation, option_types, asins = _render_observation(visible_texts, page_type, url)
    info = {}
    if option_types:
        info['option_types'] = option_types
    if asins:
        info['asins'] = asins
    score_label = 'Your score (min 0.0, max 1.0)'
    if score_label in visible_texts:
        # On the final page the score follows its label; it replaces the observation.
        idx = visible_texts.index(score_label)
        info['reward'] = float(visible_texts[idx + 1])
        observation = score_label + ': ' + visible_texts[idx + 1]
    return clean_str(observation), info
class webshopEnv:
    """Text-action interface to a remote WebShop server.

    Keeps one state dict per session id in ``self.sessions``. Each state is
    passed as keyword arguments to ``webshop_text`` to render the current
    page. Actions that are not valid on the current page raise
    ``AssertionError``, which callers treat as an invalid action.
    """

    def __init__(self):
        # session id -> page state (page_type, query_string, asin, options, ...)
        # merged with the info webshop_text last returned for that session.
        self.sessions = {}

    @staticmethod
    def _require(condition, message):
        """Raise AssertionError(message) unless *condition* is truthy.

        An explicit raise (rather than ``assert``) keeps validation active
        under ``python -O`` while preserving the exception type callers catch.
        """
        if not condition:
            raise AssertionError(message)

    def step(self, session, action):
        """Apply *action* to *session* and render the resulting page.

        Recognised actions: 'reset', 'think[...]', 'search[<query>]' and
        'click[<button>]'. Any other string leaves the state unchanged and
        re-renders the current page.

        Returns:
            (observation, reward, done). observation is the page text, or a
            short override: 'OK.' for thoughts, 'You have clicked X.' for
            option clicks. reward is the page's reported score (0.0 if
            none). done is True only after 'Buy Now'.

        Raises:
            AssertionError: if the action is not valid on the current page.
        """
        done = False
        observation_ = None  # when set, replaces the rendered page text
        if action == 'reset':
            self.sessions[session] = {'session': session, 'page_type': 'init'}
        elif action.startswith('think['):
            # A thought does not change the page; acknowledge it instead of
            # echoing the page text back.
            observation_ = 'OK.'
        elif action.startswith('search['):
            self._require(self.sessions[session]['page_type'] == 'init',
                          'search is only allowed from the init page')
            query = action[7:-1]
            self.sessions[session] = {'session': session, 'page_type': 'search',
                                      'query_string': query, 'page_num': 1}
        elif action.startswith('click['):
            button = action[6:-1]
            state = self.sessions[session]
            page_type = state['page_type']
            if button == 'Buy Now':
                self._require(page_type == 'item', 'Buy Now requires an item page')
                state['page_type'] = 'end'
                done = True
            elif button == 'Back to Search':
                self._require(page_type in ['search', 'item_sub', 'item'],
                              'Back to Search is not available on this page')
                self.sessions[session] = {'session': session, 'page_type': 'init'}
            elif button == 'Next >':
                # Ad hoc limitation: only the first results page is reachable.
                raise AssertionError('Next > is disabled')
            elif button == '< Prev':
                self._require(page_type in ['search', 'item_sub', 'item'],
                              '< Prev is not available on this page')
                if page_type == 'search':
                    # Same ad hoc limitation: no paging back through results.
                    raise AssertionError('< Prev is disabled on search results')
                if page_type == 'item_sub':
                    state['page_type'] = 'item'
                else:  # 'item' -> back to the results list, dropping selections
                    state['page_type'] = 'search'
                    state['options'] = {}
            elif button in ACTION_TO_TEMPLATE:
                self._require(page_type == 'item', 'sub-pages require an item page')
                state['page_type'] = 'item_sub'
                state['subpage'] = button
            elif page_type == 'search':
                # On a results page any other button must be a listed product.
                self._require(button in state.get('asins', []), 'unknown product')
                state['page_type'] = 'item'
                state['asin'] = button
            elif page_type == 'item':
                # On an item page any other button must be a product option.
                self._require('option_types' in state, 'item page has no options')
                self._require(button in state['option_types'],
                              (button, state['option_types']))
                option_type = state['option_types'][button]
                state.setdefault('options', {})[option_type] = button
                observation_ = f'You have clicked {button}.'
            else:
                raise AssertionError(f'cannot click {button!r} on a {page_type} page')
        observation, info = webshop_text(**self.sessions[session])
        if observation_:
            observation = observation_
        self.sessions[session].update(info)
        reward = info.get('reward', 0.0)
        return observation, reward, done
def webshop_run(idx, env, base_prompt, memory: List[str], to_print=True, *,
                max_steps: int = 15,
                max_prompt_chars: int = 6400) -> Tuple[EnvironmentHistory, bool]:
    """Roll out one WebShop episode with the LLM choosing each action.

    Args:
        idx: Session id passed to ``env.step``.
        env: Object with ``step(session, action) -> (observation, reward, done)``
            that raises AssertionError for invalid actions.
        base_prompt: Fixed prefix prepended to every LLM call.
        memory: Reflection strings; only the last three go into the history.
        to_print: Echo every action/observation and the final step result.
        max_steps: Maximum number of actions executed before giving up.
        max_prompt_chars: Character budget for prefix plus trajectory sent to
            the LLM. The oldest part of the trajectory is dropped first. If
            the prefix alone fills the budget, no trajectory is sent.

    Returns:
        (env_history, success). success is True only when the episode ended
        with reward exactly 1.0; running out of steps returns False.
    """
    action = 'reset'
    init_prompt = base_prompt
    prompt = ''
    observation = env.step(idx, action)[0]
    # Only the three most recent reflections are shown to the agent.
    recent_memory = memory[-3:] if len(memory) > 3 else memory
    env_history = EnvironmentHistory(base_prompt, observation, recent_memory, [])
    env_history.reset()
    # Characters left for the trajectory once the fixed prefix is counted.
    trajectory_budget = max(max_prompt_chars - len(init_prompt), 0)
    for i in range(max_steps):
        env_history.add("action", action)
        try:
            res = env.step(idx, action)
        except AssertionError:
            res = None
            observation = 'Invalid action!'
        else:
            observation = res[0]
        if action.startswith('think'):
            observation = 'OK.'
        if to_print:
            print(f'Action: {action}\nObservation: {observation}\n')
            sys.stdout.flush()
        if i:
            prompt += f' {action}\nObservation: {observation}\n\nAction:'
        else:
            prompt += f'{observation}\n\nAction:'
        env_history.add("observation", observation)
        # Episode finished: success means the final page reported a perfect score.
        if res is not None and res[2]:
            if to_print:
                print(res)
            return env_history, res[1] == 1.0
        # Keep only the newest part of the trajectory that fits the budget
        # (a plain negative slice would misbehave when the budget is 0).
        context = prompt[max(len(prompt) - trajectory_budget, 0):]
        action = llm(init_prompt + context, stop=['\n']).lstrip(' ')
    return env_history, False
def run_trial(
        trial_log_path: str,
        world_log_path: str,
        trial_idx: int,
        env_configs: List[Dict[str, Any]],
        use_memory: bool
    ) -> List[Dict[str, Any]]:
    """Run one trial over every environment config and log the results.

    Environments already marked ``is_success`` are skipped (logged as
    successes). The others are rolled out with ``webshop_run``, using the
    config's ``memory`` list when *use_memory* is set. A newly successful
    config gets ``is_success = True`` in place. Per-environment details and
    a summary are appended to *trial_log_path*; one status line per
    environment plus the summary are appended to *world_log_path*.

    Returns:
        The same ``env_configs`` list, updated in place.
    """
    env = webshopEnv()
    num_successes: int = 0
    num_additional_successes: int = 0
    num_envs: int = len(env_configs)
    status_str: str
    for z, env_config in enumerate(env_configs):
        if env_config["is_success"]:
            num_successes += 1
            # log to world log
            with open(world_log_path, 'a') as wf:
                wf.write(f'Environment #{z} Trial #{trial_idx}: SUCCESS\n')
            with open(trial_log_path, 'a') as wf:
                wf.write(f'\n#####\n\nEnvironment #{z}: Success\n\n#####\n')
            continue
        try:
            final_env_history, is_success = webshop_run(
                f'fixed_{z}', env, BASE_PROMPT,
                env_config["memory"] if use_memory else [], to_print=True)
            if is_success:
                status_str = f'Environment #{z} Trial #{trial_idx}: SUCCESS'
                env_configs[z]["is_success"] = True
                num_successes += 1
                num_additional_successes += 1
            else:
                status_str = f'Environment #{z} Trial #{trial_idx}: FAIL'
            # log env results to trial log
            with open(trial_log_path, 'a') as wf:
                wf.write(f'\n#####\n\nEnvironment #{z}:\n{str(final_env_history)}\n\nSTATUS: {"OK" if is_success else "FAIL"}\n\n#####\n')
        except AssertionError:
            status_str = f'Environment #{z} Trial #{trial_idx}: FAIL'
            # log env results to trial log
            with open(trial_log_path, 'a') as wf:
                wf.write(f'\n#####\n\nEnvironment #{z}:\nAssertion Error\n\nSTATUS: FAIL\n\n#####\n')
        # log to world log
        with open(world_log_path, 'a') as f:
            f.write(status_str + '\n')
    # Guard the empty case instead of raising ZeroDivisionError.
    accuracy = round(num_successes / num_envs, 2) if num_envs else 0.0
    # log trial results to trial and world logs
    log_str: str = f"""
-----
SUCCESS: {num_successes}
ADDITIONAL SUCCESS: {num_additional_successes}
FAIL: {num_envs - num_successes}
TOTAL: {num_envs}
ACCURACY: {accuracy}
-----"""
    with open(trial_log_path, 'a') as wf:
        wf.write(log_str)
    with open(world_log_path, 'a') as wf:
        wf.write(log_str + '\n')
    return env_configs