alfworld chat

main
Noah Shinn 10 months ago
parent ff7bbeb22b
commit 8a2aa8afb8

@ -8,6 +8,7 @@ import openai
import importlib
import alfworld
import alfworld.agents.environment
from utils import Model, get_chat, get_completion
from env_history import EnvironmentHistory
from typing import List, Dict, Any, Tuple
@ -20,24 +21,17 @@ with open(os.path.join(FOLDER, PROMPT_FILE), 'r') as f:
with open('./challenge_few_shot_examples.txt', 'r') as f:
challenge_examples = f.read()
def llm(prompt, stop=["\n"]):
def llm(prompt: str, model: Model, stop: List[str] = ["\n"]):
try:
cur_try = 0
while cur_try < 6:
response = openai.Completion.create(
model="text-davinci-002",
prompt=prompt,
temperature=cur_try * 0.2,
max_tokens=100,
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
stop=stop
)
text = response["choices"][0]["text"]
if model == "text-davinci-003":
text = get_completion(prompt=prompt, temperature=cur_try * 0.2)
else:
text = get_chat(prompt=prompt, model=model, temperature=cur_try * 0.2)
# dumb way to do this
if len(text.strip()) >= 5:
return response["choices"][0]["text"]
return text
cur_try += 1
return ""
except Exception as e:
@ -51,7 +45,7 @@ def process_ob(ob):
ob = ob[ob.find('. ')+2:]
return ob
def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='') -> Tuple[EnvironmentHistory, bool]:
def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='', model: Model = "text-davinci-003") -> Tuple[EnvironmentHistory, bool]:
if len(memory) > 3:
env_history = EnvironmentHistory(base_prompt, ob, memory[-3:], [])
else:
@ -97,7 +91,8 @@ def run_trial(
world_log_path: str,
trial_idx: int,
env_configs: List[Dict[str, Any]],
use_memory: bool
use_memory: bool,
model: Model,
) -> List[Dict[str, Any]]:
importlib.reload(alfworld)
importlib.reload(alfworld.agents.environment)
@ -133,7 +128,7 @@ def run_trial(
for i, (k, v) in enumerate(PREFIXES.items()):
if name.startswith(k):
base_prompt = 'Interact with a household to solve a task. Here are two examples.\n' + d[f'react_{v}_1'] + d[f'react_{v}_0']
final_env_history, is_success = alfworld_run(env, base_prompt, env_config["memory"] if use_memory else [], to_print=True, ob=ob)
final_env_history, is_success = alfworld_run(env, base_prompt, env_config["memory"] if use_memory else [], to_print=True, ob=ob, model=model)
# update env config
if is_success:

@ -16,6 +16,7 @@ def get_args():
parser.add_argument("--is_resume", action='store_true', help="To resume run")
parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
parser.add_argument("--start_trial_num", type=int, help="If resume, the start trial num", default=0)
parser.add_argument("--model", type=str, help="The model to use. One of `gpt-4`, `gpt-3.5-turbo`, or `text-davinci-003")
args = parser.parse_args()
@ -96,7 +97,7 @@ def main(args) -> None:
open(trial_env_configs_log_path, 'w').close()
# run trial
run_trial(trial_log_path, world_log_path, trial_idx, env_configs, args.use_memory)
run_trial(trial_log_path, world_log_path, trial_idx, env_configs, args.use_memory, args.model)
# update memory if needed
if args.use_memory:

@ -1,4 +0,0 @@
# Launch main.py for 10 trials over 134 environments, logging under the run
# name "base_run_logs". --use_memory and --model are not passed here.
# NOTE(review): the last argument line ends with a continuation backslash even
# though no further argument follows — confirm this is intentional.
python main.py \
--num_trials 10 \
--num_envs 134 \
--run_name "base_run_logs" \

@ -2,4 +2,5 @@ python main.py \
--num_trials 10 \
--num_envs 134 \
--run_name "reflexion_run_logs" \
--use_memory
--use_memory \
--model "gpt-3.5-turbo"

@ -0,0 +1,6 @@
# Launch main.py for 10 trials over 134 environments with the "gpt-3.5-turbo"
# model, logging under the run name "base_run_logs_gpt_35_turbo".
# --use_memory is not passed here.
python main.py \
--num_trials 10 \
--num_envs 134 \
--run_name "base_run_logs_gpt_35_turbo" \
--model "gpt-3.5-turbo"

@ -6,26 +6,39 @@ from tenacity import (
wait_random_exponential, # type: ignore
)
from typing import Optional, List, Union
from typing import Optional, List, Union, Literal
Model = Literal["gpt-4", "gpt-3.5-turbo", "text-davinci-003"]
openai.api_key = os.getenv('OPENAI_API_KEY')
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def get_completion(prompt: str, temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None) -> str:
    """Return the text of the first completion of ``prompt`` from ``text-davinci-003``.

    The call is retried (random exponential back-off between 1 and 60
    seconds, at most 6 attempts) by the ``retry`` decorator.

    Args:
        prompt: The prompt text sent to the completion endpoint.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on the number of generated tokens.
        stop_strs: Optional stop sequences forwarded as ``stop``.

    Returns:
        The ``text`` field of the first returned choice.
    """
    response = openai.Completion.create(
        model='text-davinci-003',
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
    )
    # Only a single (non-batched) prompt is supported, so the first choice
    # is the whole answer.
    return response.choices[0].text
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def get_chat(prompt: str, model: Model, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False, temperature: float = 0.0) -> str:
    """Return the assistant reply to ``prompt`` from a chat model.

    The prompt is sent as a single ``user`` message. The call is retried
    (random exponential back-off between 1 and 60 seconds, at most 6
    attempts) by the ``retry`` decorator.

    Args:
        prompt: Content of the single user message.
        model: Chat model name; ``"text-davinci-003"`` is rejected because it
            is a completion model, not a chat model.
        max_tokens: Upper bound on the number of generated tokens.
        stop_strs: Optional stop sequences forwarded as ``stop``.
        is_batched: Unused; kept so existing callers keep working.
        temperature: Sampling temperature forwarded to the API. Defaults to
            ``0.0``, matching ``get_completion``.

    Returns:
        The message content of the first returned choice.

    Raises:
        ValueError: If ``model`` is ``"text-davinci-003"``.
    """
    if model == "text-davinci-003":
        raise ValueError("get_chat requires a chat model, got 'text-davinci-003'")
    messages = [
        {
            "role": "user",
            "content": prompt,
        }
    ]
    # Chat models take `messages` through the ChatCompletion endpoint; the
    # plain Completion endpoint only accepts `prompt`.
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        stop=stop_strs,
        temperature=temperature,
    )
    return response.choices[0].message.content

@ -1,55 +0,0 @@
# Reworked version of the attempt previously labelled "Fails 2".
def minReverseOperations(self, n: int, p: int, banned: List[int], k: int) -> List[int]:
    """Return, per position, the fewest reversals needed to move the 1 there.

    Assumes the "Minimum Reverse Operations" contract: an array of length
    ``n`` holds a single 1 at index ``p``; one operation reverses a subarray
    of size exactly ``k``; the 1 may never land on an index in ``banned``.

    Only the position of the 1 matters, so the search runs over positions
    instead of whole arrays. Reversing the window that starts at ``start``
    moves the 1 from ``pos`` to ``2 * start + k - 1 - pos``.

    Args:
        n: Length of the array.
        p: Starting index of the 1.
        banned: Indices the 1 must never occupy.
        k: Size of the subarray reversed by one operation.

    Returns:
        A list of length ``n`` whose entry ``i`` is the minimum number of
        operations to bring the 1 to index ``i``, or ``-1`` if unreachable.
        Runs in O(n * k) time.
    """
    from collections import deque

    blocked = set(banned)
    ans = [-1] * n
    ans[p] = 0
    queue = deque([p])
    while queue:
        pos = queue.popleft()
        # A window [start, start + k) must contain pos and fit in the array.
        first_start = max(0, pos - k + 1)
        last_start = min(pos, n - k)
        for start in range(first_start, last_start + 1):
            target = 2 * start + k - 1 - pos
            if ans[target] == -1 and target not in blocked:
                ans[target] = ans[pos] + 1
                queue.append(target)
    return ans
# Reworked version of the attempt previously labelled "Fails 1".
def minReverseOperations(self, n: int, p: int, banned: List[int], k: int) -> List[int]:
    """Return, per position, the fewest reversals needed to move the 1 there.

    Assumes the "Minimum Reverse Operations" contract: an array of length
    ``n`` holds a single 1 at index ``p``; one operation reverses a subarray
    of size exactly ``k``; the 1 may never land on an index in ``banned``.

    From position ``pos`` the reachable targets form a range ``lo..hi`` in
    steps of 2 (all of one parity). A skip-pointer table jumps over indices
    that are banned or already visited, so each index is handed out once.

    Args:
        n: Length of the array.
        p: Starting index of the 1.
        banned: Indices the 1 must never occupy.
        k: Size of the subarray reversed by one operation.

    Returns:
        A list of length ``n`` whose entry ``i`` is the minimum number of
        operations to bring the 1 to index ``i``, or ``-1`` if unreachable.
    """
    from collections import deque

    ans = [-1] * n
    ans[p] = 0

    # skip[i] == i means index i is still available; otherwise it points to a
    # later index of the same parity. Indices n and n + 1 are sentinels.
    skip = list(range(n + 2))

    def next_free(idx: int) -> int:
        """Return the smallest available index >= idx with idx's parity."""
        root = idx
        while skip[root] != root:
            root = skip[root]
        # Path compression keeps later lookups short.
        while skip[idx] != root:
            skip[idx], idx = root, skip[idx]
        return root

    for b in banned:
        skip[b] = b + 2
    skip[p] = p + 2

    queue = deque([p])
    while queue:
        pos = queue.popleft()
        # Window starts range over [max(0, pos-k+1), min(pos, n-k)]; a start s
        # sends the 1 to 2*s + k - 1 - pos.
        lo = 2 * max(0, pos - k + 1) + k - 1 - pos
        hi = 2 * min(pos, n - k) + k - 1 - pos
        if lo > hi:
            # No window of size k contains pos (e.g. k > n).
            continue
        target = next_free(lo)
        while target <= hi:
            ans[target] = ans[pos] + 1
            queue.append(target)
            skip[target] = target + 2
            target = next_free(target)
    return ans
Loading…
Cancel
Save