mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
246 lines
6.8 KiB
Plaintext
246 lines
6.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4b089493",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Simulated Environment: Gymnasium\n",
|
|
"\n",
|
|
"For many applications of LLM agents, the environment is real (internet, database, REPL, etc.). However, we can also define agents that interact with simulated environments, such as text-based games. This is an example of how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym))."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "f36427cf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install gymnasium"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "f9bd38b4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import gymnasium as gym\n",
|
|
"import inspect\n",
|
|
"import tenacity\n",
|
|
"\n",
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"from langchain.schema import (\n",
|
|
" AIMessage,\n",
|
|
" HumanMessage,\n",
|
|
" SystemMessage,\n",
|
|
" BaseMessage,\n",
|
|
")\n",
|
|
"from langchain.output_parsers import RegexParser"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e222e811",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Define the agent"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "870c24bc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class GymnasiumAgent:\n",
|
|
" @classmethod\n",
|
|
" def get_docs(cls, env):\n",
|
|
" return env.unwrapped.__doc__\n",
|
|
"\n",
|
|
" def __init__(self, model, env):\n",
|
|
" self.model = model\n",
|
|
" self.env = env\n",
|
|
" self.docs = self.get_docs(env)\n",
|
|
"\n",
|
|
" self.instructions = \"\"\"\n",
|
|
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
|
|
"I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
|
|
"\n",
|
|
"Observation: <observation>\n",
|
|
"Reward: <reward>\n",
|
|
"Termination: <termination>\n",
|
|
"Truncation: <truncation>\n",
|
|
"Return: <sum_of_rewards>\n",
|
|
"\n",
|
|
"You will respond with an action, formatted as:\n",
|
|
"\n",
|
|
"Action: <action>\n",
|
|
"\n",
|
|
"where you replace <action> with your actual action.\n",
|
|
"Do nothing else but return the action.\n",
|
|
"\"\"\"\n",
|
|
" self.action_parser = RegexParser(\n",
|
|
" regex=r\"Action: (.*)\", output_keys=[\"action\"], default_output_key=\"action\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" self.message_history = []\n",
|
|
" self.ret = 0\n",
|
|
"\n",
|
|
" def random_action(self):\n",
|
|
" action = self.env.action_space.sample()\n",
|
|
" return action\n",
|
|
"\n",
|
|
" def reset(self):\n",
|
|
" self.message_history = [\n",
|
|
" SystemMessage(content=self.docs),\n",
|
|
" SystemMessage(content=self.instructions),\n",
|
|
" ]\n",
|
|
"\n",
|
|
" def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
|
|
" self.ret += rew\n",
|
|
"\n",
|
|
" obs_message = f\"\"\"\n",
|
|
"Observation: {obs}\n",
|
|
"Reward: {rew}\n",
|
|
"Termination: {term}\n",
|
|
"Truncation: {trunc}\n",
|
|
"Return: {self.ret}\n",
|
|
" \"\"\"\n",
|
|
" self.message_history.append(HumanMessage(content=obs_message))\n",
|
|
" return obs_message\n",
|
|
"\n",
|
|
" def _act(self):\n",
|
|
" act_message = self.model(self.message_history)\n",
|
|
" self.message_history.append(act_message)\n",
|
|
" action = int(self.action_parser.parse(act_message.content)[\"action\"])\n",
|
|
" return action\n",
|
|
"\n",
|
|
" def act(self):\n",
|
|
" try:\n",
|
|
" for attempt in tenacity.Retrying(\n",
|
|
" stop=tenacity.stop_after_attempt(2),\n",
|
|
" wait=tenacity.wait_none(), # No waiting time between retries\n",
|
|
" retry=tenacity.retry_if_exception_type(ValueError),\n",
|
|
" before_sleep=lambda retry_state: print(\n",
|
|
" f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"\n",
|
|
" ),\n",
|
|
" ):\n",
|
|
" with attempt:\n",
|
|
" action = self._act()\n",
|
|
" except tenacity.RetryError as e:\n",
|
|
" action = self.random_action()\n",
|
|
" return action"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2e76d22c",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Initialize the simulated environment and agent"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9e902cfd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"env = gym.make(\"Blackjack-v1\")\n",
|
|
"agent = GymnasiumAgent(model=ChatOpenAI(temperature=0.2), env=env)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e2c12b15",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Main loop"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ad361210",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"Observation: (15, 4, 0)\n",
|
|
"Reward: 0\n",
|
|
"Termination: False\n",
|
|
"Truncation: False\n",
|
|
"Return: 0\n",
|
|
" \n",
|
|
"Action: 1\n",
|
|
"\n",
|
|
"Observation: (25, 4, 0)\n",
|
|
"Reward: -1.0\n",
|
|
"Termination: True\n",
|
|
"Truncation: False\n",
|
|
"Return: -1.0\n",
|
|
" \n",
|
|
"break True False\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"observation, info = env.reset()\n",
|
|
"agent.reset()\n",
|
|
"\n",
|
|
"obs_message = agent.observe(observation)\n",
|
|
"print(obs_message)\n",
|
|
"\n",
|
|
"while True:\n",
|
|
" action = agent.act()\n",
|
|
" observation, reward, termination, truncation, info = env.step(action)\n",
|
|
" obs_message = agent.observe(observation, reward, termination, truncation, info)\n",
|
|
" print(f\"Action: {action}\")\n",
|
|
" print(obs_message)\n",
|
|
"\n",
|
|
" if termination or truncation:\n",
|
|
" print(\"break\", termination, truncation)\n",
|
|
" break\n",
|
|
"env.close()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "58a13e9c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|