{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b089493",
   "metadata": {},
   "source": [
    "# Simulated Environment: Gymnasium\n",
    "\n",
    "For many applications of LLM agents, the environment is real (internet, database, REPL, etc.). However, we can also define agents to interact in simulated environments like text-based games. This is an example of how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f36427cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install gymnasium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "f9bd38b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import tenacity\n",
    "\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.schema import (\n",
    "    AIMessage,\n",
    "    HumanMessage,\n",
    "    SystemMessage,\n",
    "    BaseMessage,\n",
    ")\n",
    "from langchain.output_parsers import RegexParser"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e222e811",
   "metadata": {},
   "source": [
    "## Define the agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "870c24bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_docs(env):\n",
    "    # Unwrap nested wrappers so we can use the base environment's docstring\n",
    "    # as a task description for the agent.\n",
    "    while 'env' in dir(env):\n",
    "        env = env.env\n",
    "    return env.__doc__\n",
    "\n",
    "\n",
    "class Agent:\n",
    "    def __init__(self, model, env):\n",
    "        self.model = model\n",
    "        self.docs = get_docs(env)\n",
    "\n",
    "        # Instructions telling the model the message format it will receive\n",
    "        # and the exact format it must reply with.\n",
    "        self.instructions = \"\"\"\n",
    "Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
    "I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
    "\n",
    "Observation: <observation>\n",
    "Reward: <reward>\n",
    "Termination: <termination>\n",
    "Truncation: <truncation>\n",
    "Return: <sum_of_rewards>\n",
    "\n",
    "You will respond with an action, formatted as:\n",
    "\n",
    "Action: <action>\n",
    "\n",
    "where you replace <action> with your actual action.\n",
    "Do nothing else but return the action.\n",
    "\"\"\"\n",
    "        self.action_parser = RegexParser(\n",
    "            regex=r\"Action: (.*)\",\n",
    "            output_keys=['action'],\n",
    "            default_output_key='action')\n",
    "\n",
    "        self.message_history = []\n",
    "        self.ret = 0\n",
    "\n",
    "    def reset(self, obs):\n",
    "        # Start a new episode: seed the chat with the env docs and instructions.\n",
    "        self.message_history = [\n",
    "            SystemMessage(content=self.docs),\n",
    "            SystemMessage(content=self.instructions),\n",
    "        ]\n",
    "        obs_message = f\"\"\"\n",
    "Observation: {obs}\n",
    "Reward: 0\n",
    "Termination: False\n",
    "Truncation: False\n",
    "Return: 0\n",
    "        \"\"\"\n",
    "        self.message_history.append(HumanMessage(content=obs_message))\n",
    "        return obs_message\n",
    "\n",
    "    def observe(self, obs, rew, term, trunc, info):\n",
    "        # Record the latest transition and append it to the conversation.\n",
    "        self.ret += rew\n",
    "\n",
    "        obs_message = f\"\"\"\n",
    "Observation: {obs}\n",
    "Reward: {rew}\n",
    "Termination: {term}\n",
    "Truncation: {trunc}\n",
    "Return: {self.ret}\n",
    "        \"\"\"\n",
    "        self.message_history.append(HumanMessage(content=obs_message))\n",
    "        return obs_message\n",
    "\n",
    "    @tenacity.retry(stop=tenacity.stop_after_attempt(2),\n",
    "                    wait=tenacity.wait_none(),  # No waiting time between retries\n",
    "                    retry=tenacity.retry_if_exception_type(ValueError),\n",
    "                    before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
    "                    retry_error_callback=lambda retry_state: 0)  # Default value when all retries are exhausted\n",
    "    def act(self):\n",
    "        # Query the chat model and parse the integer that follows 'Action:'.\n",
    "        act_message = self.model(self.message_history)\n",
    "        self.message_history.append(act_message)\n",
    "        action = int(self.action_parser.parse(act_message.content)['action'])\n",
    "        return action"
   ]
  },
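  {
   "cell_type": "markdown",
   "id": "b3a1f5d2",
   "metadata": {},
   "source": [
    "`RegexParser` simply extracts whatever follows `Action:` in the model's reply. The optional, purely illustrative cell below shows what the parser returns for a sample reply (it assumes the imports above have been run; `sample_parser` is only used here)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7c2e9f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: RegexParser.parse should return a dict keyed by\n",
    "# 'action', e.g. {'action': '1'} for the sample reply below.\n",
    "sample_parser = RegexParser(\n",
    "    regex=r\"Action: (.*)\",\n",
    "    output_keys=['action'],\n",
    "    default_output_key='action')\n",
    "sample_parser.parse(\"Action: 1\")"
   ]
  },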
  {
   "cell_type": "markdown",
   "id": "2e76d22c",
   "metadata": {},
   "source": [
    "## Initialize the simulated environment and agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "9e902cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"Blackjack-v1\")\n",
    "agent = Agent(model=ChatOpenAI(temperature=0.2), env=env)"
   ]
  },
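  {
   "cell_type": "markdown",
   "id": "f1a8c3e7",
   "metadata": {},
   "source": [
    "It can help to peek at the spaces before running the loop: in `Blackjack-v1` the observation should be a tuple of (player sum, dealer's showing card, usable ace) and the action space should be `Discrete(2)`, with `0` meaning stick and `1` meaning hit. The inspection cell below is optional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d4b6c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect the spaces so the agent's numeric observations and\n",
    "# actions printed below are easier to interpret.\n",
    "print(env.observation_space)\n",
    "print(env.action_space)"
   ]
  },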
  {
   "cell_type": "markdown",
   "id": "e2c12b15",
   "metadata": {},
   "source": [
    "## Main loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "ad361210",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Observation: (10, 3, 0)\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: (18, 3, 0)\n",
      "Reward: 0.0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0.0\n",
      "        \n",
      "Action: 0\n",
      "\n",
      "Observation: (18, 3, 0)\n",
      "Reward: 1.0\n",
      "Termination: True\n",
      "Truncation: False\n",
      "Return: 1.0\n",
      "        \n",
      "break True False\n"
     ]
    }
   ],
   "source": [
    "# Run one episode: the agent keeps choosing actions until the environment\n",
    "# reports termination or truncation.\n",
    "observation, info = env.reset()\n",
    "obs_message = agent.reset(observation)\n",
    "print(obs_message)\n",
    "while True:\n",
    "    action = agent.act()\n",
    "    observation, reward, termination, truncation, info = env.step(action)\n",
    "    obs_message = agent.observe(observation, reward, termination, truncation, info)\n",
    "    print(f'Action: {action}')\n",
    "    print(obs_message)\n",
    "    if termination or truncation:\n",
    "        print('break', termination, truncation)\n",
    "        break\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a13e9c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}