{
"cells": [
{
"cell_type": "markdown",
"id": "4b089493",
"metadata": {},
"source": [
"# Simulated Environment: Gymnasium\n",
"\n",
"In many LLM agent applications, the environment is real (the internet, a database, a REPL, etc.). However, we can also define agents that interact with simulated environments, such as text-based games. This example shows how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym))."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f36427cf",
"metadata": {},
"outputs": [],
"source": [
"!pip install gymnasium"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "f9bd38b4",
"metadata": {},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"import tenacity\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema import (\n",
"    AIMessage,\n",
"    HumanMessage,\n",
"    SystemMessage,\n",
"    BaseMessage,\n",
")\n",
"from langchain.output_parsers import RegexParser"
]
},
{
"cell_type": "markdown",
"id": "e222e811",
"metadata": {},
"source": [
"## Define the agent"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "870c24bc",
"metadata": {},
"outputs": [],
"source": [
"def get_docs(env):\n",
"    # Unwrap nested wrappers to reach the base environment, whose docstring describes the task.\n",
"    while 'env' in dir(env):\n",
"        env = env.env\n",
"    return env.__doc__\n",
"\n",
"\n",
"class Agent:\n",
"    def __init__(self, model, env):\n",
"        self.model = model\n",
"        self.docs = get_docs(env)\n",
"\n",
"        self.instructions = \"\"\"\n",
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
"I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
"\n",
"Observation: <observation>\n",
"Reward: <reward>\n",
"Termination: <termination>\n",
"Truncation: <truncation>\n",
"Return: <sum_of_rewards>\n",
"\n",
"You will respond with an action, formatted as:\n",
"\n",
"Action: <action>\n",
"\n",
"where you replace <action> with your actual action.\n",
"Do nothing else but return the action.\n",
"\"\"\"\n",
"        # Extracts the text after \"Action: \" from the model's reply.\n",
"        self.action_parser = RegexParser(\n",
"            regex=r\"Action: (.*)\",\n",
"            output_keys=['action'],\n",
"            default_output_key='action')\n",
"\n",
"        self.message_history = []\n",
"        self.ret = 0\n",
"\n",
"    def reset(self, obs):\n",
"        # Start a new episode: clear the running return and the conversation history.\n",
"        self.ret = 0\n",
"        self.message_history = [\n",
"            SystemMessage(content=self.docs),\n",
"            SystemMessage(content=self.instructions),\n",
"        ]\n",
"        obs_message = f\"\"\"\n",
"Observation: {obs}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
"        \"\"\"\n",
"        self.message_history.append(HumanMessage(content=obs_message))\n",
"        return obs_message\n",
"\n",
"    def observe(self, obs, rew, term, trunc, info):\n",
"        # Accumulate the reward and report the latest step to the model.\n",
"        self.ret += rew\n",
"\n",
"        obs_message = f\"\"\"\n",
"Observation: {obs}\n",
"Reward: {rew}\n",
"Termination: {term}\n",
"Truncation: {trunc}\n",
"Return: {self.ret}\n",
"        \"\"\"\n",
"        self.message_history.append(HumanMessage(content=obs_message))\n",
"        return obs_message\n",
"\n",
"    @tenacity.retry(stop=tenacity.stop_after_attempt(2),\n",
"                    wait=tenacity.wait_none(),  # No waiting time between retries\n",
"                    retry=tenacity.retry_if_exception_type(ValueError),\n",
"                    before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
"                    retry_error_callback=lambda retry_state: 0)  # Default value when all retries are exhausted\n",
"    def act(self):\n",
"        act_message = self.model(self.message_history)\n",
"        self.message_history.append(act_message)\n",
"        # int() raises ValueError if the reply is not an integer action, which triggers a retry.\n",
"        action = int(self.action_parser.parse(act_message.content)['action'])\n",
"        return action"
]
},
{
"cell_type": "markdown",
"id": "2e76d22c",
"metadata": {},
"source": [
"## Initialize the simulated environment and agent"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "9e902cfd",
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\"Blackjack-v1\")\n",
"agent = Agent(model=ChatOpenAI(temperature=0.2), env=env)"
]
},
{
"cell_type": "markdown",
"id": "e2c12b15",
"metadata": {},
"source": [
"## Main loop"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "ad361210",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: (10, 3, 0)\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: (18, 3, 0)\n",
"Reward: 0.0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0.0\n",
" \n",
"Action: 0\n",
"\n",
"Observation: (18, 3, 0)\n",
"Reward: 1.0\n",
"Termination: True\n",
"Truncation: False\n",
"Return: 1.0\n",
" \n",
"break True False\n"
]
}
],
"source": [
"observation, info = env.reset()\n",
"obs_message = agent.reset(observation)\n",
"print(obs_message)\n",
"while True:\n",
"    action = agent.act()\n",
"    observation, reward, termination, truncation, info = env.step(action)\n",
"    obs_message = agent.observe(observation, reward, termination, truncation, info)\n",
"    print(f'Action: {action}')\n",
"    print(obs_message)\n",
"    if termination or truncation:\n",
"        print('break', termination, truncation)\n",
"        break\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58a13e9c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}