mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
838 lines
29 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "4b089493",
"metadata": {},
"source": [
"# Multi-Agent Simulated Environment: PettingZoo\n",
"\n",
"In this example, we show how to define multi-agent simulations with simulated environments. Like [our single-agent example with Gymnasium](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html), we create an agent-environment loop with an externally defined environment. The main difference is that the loop now drives multiple agents that take turns acting in a shared environment. We will use the [PettingZoo](https://pettingzoo.farama.org/) library, which is the multi-agent counterpart to [Gymnasium](https://gymnasium.farama.org/)."
]
},
{
"cell_type": "markdown",
"id": "10091333",
"metadata": {},
"source": [
"## Install `pettingzoo` and other dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a3fde66",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pettingzoo in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (1.22.3)\n",
"Requirement already satisfied: pygame in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (2.4.0)\n",
"Requirement already satisfied: rlcard in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (1.2.0)\n",
"Requirement already satisfied: gymnasium>=0.26.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from pettingzoo) (0.28.1)\n",
"Requirement already satisfied: numpy>=1.18.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from pettingzoo) (1.24.3)\n",
"Requirement already satisfied: termcolor in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from rlcard) (2.2.0)\n",
"Requirement already satisfied: typing-extensions>=4.3.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (4.5.0)\n",
"Requirement already satisfied: jax-jumpy>=1.0.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (1.0.0)\n",
"Requirement already satisfied: farama-notifications>=0.0.1 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (0.0.4)\n",
"Requirement already satisfied: cloudpickle>=1.2.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (2.2.1)\n",
"Requirement already satisfied: importlib-metadata>=4.8.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (6.0.1)\n",
"Requirement already satisfied: zipp>=0.5 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from importlib-metadata>=4.8.0->gymnasium>=0.26.0->pettingzoo) (3.15.0)\n"
]
}
],
"source": [
"!pip install pettingzoo pygame rlcard"
]
},
{
"cell_type": "markdown",
"id": "5fbe130c",
"metadata": {},
"source": [
"## Import modules"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "42cd2e5d",
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import inspect\n",
"import tenacity\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema import (\n",
"    HumanMessage,\n",
"    SystemMessage,\n",
")\n",
"from langchain.output_parsers import RegexParser"
]
},
{
"cell_type": "markdown",
"id": "e222e811",
"metadata": {},
"source": [
"## `GymnasiumAgent`\n",
"Here we reproduce the same `GymnasiumAgent` defined in [our Gymnasium example](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html). If the model's reply cannot be parsed into an integer action after two attempts, the agent falls back to a random action sampled from the environment's action space."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "72df0b59",
"metadata": {},
"outputs": [],
"source": [
"class GymnasiumAgent:\n",
"    @classmethod\n",
"    def get_docs(cls, env):\n",
"        return env.unwrapped.__doc__\n",
"\n",
"    def __init__(self, model, env):\n",
"        self.model = model\n",
"        self.env = env\n",
"        self.docs = self.get_docs(env)\n",
"\n",
"        self.instructions = \"\"\"\n",
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
"I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
"\n",
"Observation: <observation>\n",
"Reward: <reward>\n",
"Termination: <termination>\n",
"Truncation: <truncation>\n",
"Return: <sum_of_rewards>\n",
"\n",
"You will respond with an action, formatted as:\n",
"\n",
"Action: <action>\n",
"\n",
"where you replace <action> with your actual action.\n",
"Do nothing else but return the action.\n",
"\"\"\"\n",
"        self.action_parser = RegexParser(\n",
"            regex=r\"Action: (.*)\",\n",
"            output_keys=['action'],\n",
"            default_output_key='action',\n",
"        )\n",
"\n",
"        self.message_history = []\n",
"        self.ret = 0\n",
"\n",
"    def random_action(self):\n",
"        action = self.env.action_space.sample()\n",
"        return action\n",
"\n",
"    def reset(self):\n",
"        self.message_history = [\n",
"            SystemMessage(content=self.docs),\n",
"            SystemMessage(content=self.instructions),\n",
"        ]\n",
"        # Start each episode with a fresh return.\n",
"        self.ret = 0\n",
"\n",
"    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
"        self.ret += rew\n",
"\n",
"        obs_message = f\"\"\"\n",
"Observation: {obs}\n",
"Reward: {rew}\n",
"Termination: {term}\n",
"Truncation: {trunc}\n",
"Return: {self.ret}\n",
"        \"\"\"\n",
"        self.message_history.append(HumanMessage(content=obs_message))\n",
"        return obs_message\n",
"\n",
"    def _act(self):\n",
"        act_message = self.model(self.message_history)\n",
"        self.message_history.append(act_message)\n",
"        # int() raises ValueError (triggering a retry) if the parsed action is not an integer.\n",
"        action = int(self.action_parser.parse(act_message.content)['action'])\n",
"        return action\n",
"\n",
"    def act(self):\n",
"        try:\n",
"            for attempt in tenacity.Retrying(\n",
"                stop=tenacity.stop_after_attempt(2),\n",
"                wait=tenacity.wait_none(),  # No waiting time between retries\n",
"                retry=tenacity.retry_if_exception_type(ValueError),\n",
"                before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
"            ):\n",
"                with attempt:\n",
"                    action = self._act()\n",
"        except tenacity.RetryError:\n",
"            # Both attempts failed: fall back to a random action.\n",
"            action = self.random_action()\n",
"        return action"
]
},
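{
"cell_type": "markdown",
"id": "3d7a9c10",
"metadata": {},
"source": [
"The control flow of `act` is: parse the reply, retry on `ValueError`, then fall back to a random action. It can be sketched in plain Python without LangChain or `tenacity`. This is a minimal sketch under stated assumptions. `parse_action` is a hypothetical stand-in for the `RegexParser`-plus-`int()` step; like `RegexParser` with a `default_output_key`, it passes an unmatched reply through whole. The model is any zero-argument callable that returns a string, and `max_attempts=2` mirrors `stop_after_attempt(2)`.\n",
"\n",
"```python\n",
"import random\n",
"import re\n",
"\n",
"ACTION_RE = re.compile(r'Action: (.*)')\n",
"\n",
"\n",
"def parse_action(reply):\n",
"    # Hypothetical stand-in for RegexParser + int(): use the text after\n",
"    # 'Action: ' if present, otherwise the whole reply.\n",
"    match = ACTION_RE.search(reply)\n",
"    text = match.group(1) if match else reply\n",
"    return int(text)  # raises ValueError on non-integer text\n",
"\n",
"\n",
"def act(model, n_actions, max_attempts=2, rng=random):\n",
"    # Query the model up to max_attempts times, then fall back to a random action.\n",
"    for attempt in range(1, max_attempts + 1):\n",
"        try:\n",
"            return parse_action(model())\n",
"        except ValueError as err:\n",
"            print(f'ValueError occurred on attempt {attempt}: {err}')\n",
"    return rng.randrange(n_actions)\n",
"\n",
"\n",
"# Usage: the first reply cannot be parsed, the second one can.\n",
"replies = iter(['I will move left', 'Action: 1'])\n",
"print(act(lambda: next(replies), n_actions=4))  # prints 1 after one failed attempt\n",
"```\n",
"\n",
"The notebook's `act` delegates this loop to `tenacity.Retrying` instead, which keeps the retry policy declarative and makes it simple to swap in a wait strategy between attempts."
]
},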
{
"cell_type": "markdown",
"id": "df51e302",
"metadata": {},
"source": [
"## Main loop\n",
"`main` resets the environment and every agent, then steps through the agents in turn. On each turn, the current agent observes the output of `env.last()` and acts, or passes `None` once it has terminated or been truncated."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f07d7cf",
"metadata": {},
"outputs": [],
"source": [
"def main(agents, env):\n",
"    env.reset()\n",
"\n",
"    for agent in agents.values():\n",
"        agent.reset()\n",
"\n",
"    for agent_name in env.agent_iter():\n",
"        observation, reward, termination, truncation, info = env.last()\n",
"        obs_message = agents[agent_name].observe(\n",
"            observation, reward, termination, truncation, info)\n",
"        print(obs_message)\n",
"        if termination or truncation:\n",
"            # Agents that are done must step with a None action.\n",
"            action = None\n",
"        else:\n",
"            action = agents[agent_name].act()\n",
"        print(f'Action: {action}')\n",
"        env.step(action)\n",
"    env.close()"
]
},
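{
"cell_type": "markdown",
"id": "5e8b2f41",
"metadata": {},
"source": [
"The loop in `main` relies on three calls from PettingZoo's turn-based (AEC) interface. `agent_iter()` yields the agent whose turn it is. `last()` returns that agent's observation, reward, termination flag, truncation flag and info. `step()` applies its action, with `None` passed for agents that are already done. To trace that turn-taking without installing anything, the sketch below runs the same loop against `CountdownEnv`. This is a hypothetical two-player toy environment, not part of PettingZoo, that imitates only those calls: each move adds the action to a shared total, and the episode ends for both players after `max_moves` moves.\n",
"\n",
"```python\n",
"class CountdownEnv:\n",
"    # Hypothetical toy environment that imitates PettingZoo's AEC calls.\n",
"\n",
"    def __init__(self, max_moves=4):\n",
"        self.possible_agents = ['player_0', 'player_1']\n",
"        self.max_moves = max_moves\n",
"\n",
"    def reset(self):\n",
"        self.agents = list(self.possible_agents)\n",
"        self.done = {agent: False for agent in self.agents}\n",
"        self.moves = 0\n",
"        self.total = 0\n",
"        self.idx = 0\n",
"\n",
"    def _current(self):\n",
"        pos = self.idx % len(self.agents)\n",
"        return pos, self.agents[pos]\n",
"\n",
"    def agent_iter(self):\n",
"        # Yield whose turn it is until every agent has been removed.\n",
"        while self.agents:\n",
"            yield self._current()[1]\n",
"\n",
"    def last(self):\n",
"        _, agent = self._current()\n",
"        # (observation, reward, termination, truncation, info)\n",
"        return self.total, 0, self.done[agent], False, {}\n",
"\n",
"    def step(self, action):\n",
"        pos, agent = self._current()\n",
"        if self.done[agent]:\n",
"            # A finished agent must step with None; it is then removed.\n",
"            assert action is None, 'finished agents must step with None'\n",
"            self.agents.pop(pos)\n",
"            self.idx = pos  # the next agent has shifted into this position\n",
"            return\n",
"        self.total += action\n",
"        self.moves += 1\n",
"        if self.moves >= self.max_moves:\n",
"            for other in self.agents:\n",
"                self.done[other] = True\n",
"        self.idx += 1\n",
"\n",
"    def close(self):\n",
"        pass\n",
"\n",
"\n",
"def run(env, policy):\n",
"    # Same structure as main() above, with a plain function as the policy.\n",
"    env.reset()\n",
"    log = []\n",
"    for agent_name in env.agent_iter():\n",
"        obs, rew, term, trunc, info = env.last()\n",
"        action = None if (term or trunc) else policy(agent_name, obs)\n",
"        log.append((agent_name, action))\n",
"        env.step(action)\n",
"    env.close()\n",
"    return log\n",
"\n",
"\n",
"# Four real moves, then each player steps once more with None.\n",
"for entry in run(CountdownEnv(max_moves=4), lambda name, obs: 1):\n",
"    print(entry)\n",
"```\n",
"\n",
"The last two entries show each agent being handed one final turn with `None` as its action. This matches the two `Action: None` lines at the end of the rock, paper, scissors run."
]
},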
{
"cell_type": "markdown",
"id": "b4b0e921",
"metadata": {},
"source": [
"## `PettingZooAgent`\n",
"\n",
"The `PettingZooAgent` extends the `GymnasiumAgent` to the multi-agent setting. The main differences are:\n",
"- `PettingZooAgent` takes in a `name` argument to identify it among multiple agents\n",
"- `get_docs` is implemented differently because `PettingZoo` environments are documented differently from `Gymnasium` environments: it reads the docstring of the module that defines the environment rather than the docstring of the environment class\n",
"- `random_action` calls `env.action_space(self.name)`, because in `PettingZoo` the action space is looked up per agent"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f132c92a",
"metadata": {},
"outputs": [],
"source": [
"class PettingZooAgent(GymnasiumAgent):\n",
"    @classmethod\n",
"    def get_docs(cls, env):\n",
"        # Read the docstring of the module that defines the environment.\n",
"        return inspect.getmodule(env.unwrapped).__doc__\n",
"\n",
"    def __init__(self, name, model, env):\n",
"        super().__init__(model, env)\n",
"        self.name = name\n",
"\n",
"    def random_action(self):\n",
"        # In PettingZoo, action spaces are looked up per agent.\n",
"        action = self.env.action_space(self.name).sample()\n",
"        return action"
]
},
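{
"cell_type": "markdown",
"id": "7c4d1a92",
"metadata": {},
"source": [
"The two `get_docs` implementations differ only in where they look for text. `GymnasiumAgent` reads `env.unwrapped.__doc__`, which for an instance resolves to the docstring of its class. `PettingZooAgent` uses `inspect.getmodule` to find the module that defines the environment and reads that module's docstring. The standard-library sketch below shows the difference on `fractions.Fraction`. It was picked only because it is a pure-Python class whose module and class docstrings differ, and it has nothing to do with PettingZoo.\n",
"\n",
"```python\n",
"import fractions\n",
"import inspect\n",
"\n",
"\n",
"def class_docs(obj):\n",
"    # What GymnasiumAgent.get_docs reads: the docstring found on the object,\n",
"    # which for an instance is its class's docstring.\n",
"    return obj.__doc__\n",
"\n",
"\n",
"def module_docs(obj):\n",
"    # What PettingZooAgent.get_docs reads: the docstring of the defining module.\n",
"    return inspect.getmodule(obj).__doc__\n",
"\n",
"\n",
"x = fractions.Fraction(1, 3)\n",
"print('class docstring: ', class_docs(x).splitlines()[0])\n",
"print('module docstring:', module_docs(x).splitlines()[0])\n",
"```"
]
},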
{
"cell_type": "markdown",
"id": "a27f8a5d",
"metadata": {},
"source": [
"## Rock, Paper, Scissors\n",
"We can now run a simulation of a multi-agent rock, paper, scissors game using the `PettingZooAgent`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bd1256c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 2\n",
"Reward: -1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: 1\n",
" \n",
"Action: None\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: -1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import rps_v2\n",
"env = rps_v2.env(max_cycles=3, render_mode=\"human\")\n",
"agents = {name: PettingZooAgent(name=name, model=ChatOpenAI(temperature=1), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
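{
"cell_type": "markdown",
"id": "9a2e6b53",
"metadata": {},
"source": [
"To read the trace above, it helps to know how moves turn into rewards. The sketch below assumes the encoding 0 = rock, 1 = paper, 2 = scissors, with observation 3 apparently meaning that the opponent has not moved yet. This encoding is inferred from the trace rather than taken from the PettingZoo documentation. It is consistent with the second round, where the first agent plays 2 and the second agent plays 1, and the rewards that follow are 1 and -1. The modular rule in `rps_reward` is only a compact way to write the usual rules, not PettingZoo's implementation.\n",
"\n",
"```python\n",
"ROCK, PAPER, SCISSORS = 0, 1, 2  # assumed encoding, inferred from the trace\n",
"NAMES = {ROCK: 'rock', PAPER: 'paper', SCISSORS: 'scissors'}\n",
"\n",
"\n",
"def rps_reward(mine, theirs):\n",
"    # +1 if my move beats theirs, -1 if it loses, 0 on a tie.\n",
"    # Each move beats the one just before it, cyclically:\n",
"    # paper (1) beats rock (0), scissors (2) beats paper (1), rock (0) beats scissors (2).\n",
"    if mine == theirs:\n",
"        return 0\n",
"    return 1 if (mine - theirs) % 3 == 1 else -1\n",
"\n",
"\n",
"# Replay the three rounds from the trace: (first agent's move, second agent's move).\n",
"rounds = [(PAPER, PAPER), (SCISSORS, PAPER), (ROCK, ROCK)]\n",
"returns = [0, 0]\n",
"for first, second in rounds:\n",
"    returns[0] += rps_reward(first, second)\n",
"    returns[1] += rps_reward(second, first)\n",
"    print(NAMES[first], 'vs', NAMES[second], '->', rps_reward(first, second))\n",
"print('returns:', returns)  # [1, -1], matching the final Return lines\n",
"```"
]
},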
{
"cell_type": "markdown",
"id": "fbcee258",
"metadata": {},
"source": [
"## `ActionMaskAgent`\n",
"\n",
"Some `PettingZoo` environments provide an `action_mask` to tell the agent which actions are valid. The `ActionMaskAgent` subclasses `PettingZooAgent` to use the `action_mask` in two ways. Before each action, it asks the model to pick an index where the mask is nonzero, and its random fallback samples only from the actions the mask allows. The model's reply itself is not checked against the mask, so an invalid action can still reach the environment."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bd33250a",
"metadata": {},
"outputs": [],
"source": [
"class ActionMaskAgent(PettingZooAgent):\n",
"    def __init__(self, name, model, env):\n",
"        super().__init__(name, model, env)\n",
"        # Keep only the most recent observation, which holds the current action mask.\n",
"        self.obs_buffer = collections.deque(maxlen=1)\n",
"\n",
"    def random_action(self):\n",
"        obs = self.obs_buffer[-1]\n",
"        # Sample only among actions whose mask entry is nonzero.\n",
"        action = self.env.action_space(self.name).sample(obs[\"action_mask\"])\n",
"        return action\n",
"\n",
"    def reset(self):\n",
"        super().reset()\n",
"        self.obs_buffer.clear()\n",
"\n",
"    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
"        self.obs_buffer.append(obs)\n",
"        return super().observe(obs, rew, term, trunc, info)\n",
"\n",
"    def _act(self):\n",
"        valid_action_instruction = \"Generate a valid action given by the indices of the `action_mask` that are not 0, according to the action formatting rules.\"\n",
"        self.message_history.append(HumanMessage(content=valid_action_instruction))\n",
"        return super()._act()"
]
},
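{
"cell_type": "markdown",
"id": "b6f0c8d4",
"metadata": {},
"source": [
"`ActionMaskAgent` only asks the model for a legal index; it never checks the reply against the mask. In the Texas Hold'em run at the end of this notebook, an agent answers 4 while the mask is `[1, 1, 0, 0, 0]`, which ends the game. One possible hardening, sketched below in plain Python, is to check the parsed action and raise `ValueError` when it is not allowed. The retry-then-fallback logic in `act` would then treat it like an unparseable reply. `check_action` and `masked_random_action` are hypothetical helpers, not LangChain or PettingZoo APIs.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def check_action(action, action_mask):\n",
"    # Hypothetical helper: return the action if the mask allows it, else raise ValueError.\n",
"    if not 0 <= action < len(action_mask) or not action_mask[action]:\n",
"        valid = [i for i, allowed in enumerate(action_mask) if allowed]\n",
"        raise ValueError(f'action {action} is not allowed; valid actions are {valid}')\n",
"    return action\n",
"\n",
"\n",
"def masked_random_action(action_mask, rng=random):\n",
"    # Hypothetical helper: choose uniformly among indices whose mask entry is nonzero.\n",
"    valid = [i for i, allowed in enumerate(action_mask) if allowed]\n",
"    return rng.choice(valid)\n",
"\n",
"\n",
"mask = [1, 1, 0, 0, 0]  # the mask in force when the illegal move was made\n",
"try:\n",
"    check_action(4, mask)\n",
"except ValueError as err:\n",
"    print(err)  # action 4 is not allowed; valid actions are [0, 1]\n",
"print(masked_random_action(mask, random.Random(0)))  # either 0 or 1\n",
"```\n",
"\n",
"Wiring this in would mean validating the result of `super()._act()` inside `ActionMaskAgent._act` against `self.obs_buffer[-1]['action_mask']`. The outputs shown in this notebook do not reflect that change."
]
},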
{
"cell_type": "markdown",
"id": "2e76d22c",
"metadata": {},
"source": [
"## Tic-Tac-Toe\n",
"Here is an example of a Tic-Tac-Toe game that uses the `ActionMaskAgent`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9e902cfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([[[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 5\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 6\n",
" | | \n",
" X | O | X \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: -1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: 1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import tictactoe_v3\n",
"env = tictactoe_v3.env(render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
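{
"cell_type": "markdown",
"id": "c8e1f2a5",
"metadata": {},
"source": [
"The raw tic-tac-toe observation is hard to read. Judging from the trace above, and not from the PettingZoo documentation, the layout appears to be as follows. `observation` is a 3 x 3 x 2 array in which entry `[a // 3][a % 3]` describes the cell chosen by action `a`. Channel 0 marks the observing player's own pieces and channel 1 marks the opponent's. The rendered board draws action `a` at row `a % 3`, column `a // 3`. The sketch below rests entirely on that inferred layout and turns an observation into a compact text board, which could be passed to the model instead of the raw arrays. `describe_board` is a hypothetical helper.\n",
"\n",
"```python\n",
"def describe_board(observation, me='X', them='O'):\n",
"    # Hypothetical helper. Assumes observation[a // 3][a % 3] == [mine, theirs]\n",
"    # for the cell of action a, drawn at row a % 3, column a // 3.\n",
"    grid = [['-'] * 3 for _ in range(3)]\n",
"    for a in range(9):\n",
"        mine, theirs = observation[a // 3][a % 3]\n",
"        if mine:\n",
"            grid[a % 3][a // 3] = me\n",
"        elif theirs:\n",
"            grid[a % 3][a // 3] = them\n",
"    return [' '.join(row) for row in grid]\n",
"\n",
"\n",
"# Final observation of the winning player (X) in the trace above.\n",
"final_obs = [\n",
"    [[1, 0], [0, 1], [1, 0]],\n",
"    [[0, 1], [1, 0], [0, 1]],\n",
"    [[1, 0], [0, 0], [0, 0]],\n",
"]\n",
"for row in describe_board(final_obs):\n",
"    print(row)\n",
"# X O X\n",
"# O X -\n",
"# X O -\n",
"```\n",
"\n",
"The printed rows match the last board rendered in the trace, where X completes the diagonal from the bottom-left to the top-right corner."
]
},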
{
"cell_type": "markdown",
"id": "8728ac2a",
"metadata": {},
"source": [
"## Texas Hold'em No Limit\n",
"Here is an example of a Texas Hold'em No Limit game that uses the `ActionMaskAgent`. In the run below, one agent answers with action 4 although its `action_mask` only allows actions 0 and 1. The environment reports an illegal move and ends the game with that player losing."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e350c62b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 2., 6.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 2., 8.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
" 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 6., 20.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"[WARNING]: Illegal move made, game terminating with current player losing. \n",
"obs['action_mask'] contains a mask of all legal moves that can be chosen.\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: -1.0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: -1.0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 20., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 100., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 2., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import texas_holdem_no_limit_v6\n",
"env = texas_holdem_no_limit_v6.env(num_players=4, render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}