{
"cells": [
{
"cell_type": "markdown",
"id": "4b089493",
"metadata": {},
"source": [
"# Multi-Agent Simulated Environment: PettingZoo\n",
"\n",
"This example shows how to define multi-agent simulations with simulated environments. As in [our single-agent example with Gymnasium](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html), we create an agent-environment loop with an externally defined environment. The main difference is that the interaction loop now involves multiple agents. We use the [PettingZoo](https://pettingzoo.farama.org/) library, which is the multi-agent counterpart to [Gymnasium](https://gymnasium.farama.org/)."
]
},
{
"cell_type": "markdown",
"id": "10091333",
"metadata": {},
"source": [
"## Install `pettingzoo` and other dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a3fde66",
"metadata": {},
"outputs": [],
"source": [
"!pip install pettingzoo pygame rlcard"
]
},
{
"cell_type": "markdown",
"id": "5fbe130c",
"metadata": {},
"source": [
"## Import modules"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "42cd2e5d",
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import inspect\n",
"import tenacity\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema import (\n",
" HumanMessage,\n",
" SystemMessage,\n",
")\n",
"from langchain.output_parsers import RegexParser"
]
},
{
"cell_type": "markdown",
"id": "e222e811",
"metadata": {},
"source": [
"## `GymnasiumAgent`\n",
"Here we reproduce the `GymnasiumAgent` defined in [our Gymnasium example](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html). If the agent still has not produced a valid action after retrying, it falls back to a random action."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "72df0b59",
"metadata": {},
"outputs": [],
"source": [
"class GymnasiumAgent():\n",
"    @classmethod\n",
"    def get_docs(cls, env):\n",
"        return env.unwrapped.__doc__\n",
"\n",
"    def __init__(self, model, env):\n",
"        self.model = model\n",
"        self.env = env\n",
"        self.docs = self.get_docs(env)\n",
"\n",
"        self.instructions = \"\"\"\n",
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
"I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:\n",
"\n",
"Observation: <observation>\n",
"Reward: <reward>\n",
"Termination: <termination>\n",
"Truncation: <truncation>\n",
"Return: <sum_of_rewards>\n",
"\n",
"You will respond with an action, formatted as:\n",
"\n",
"Action: <action>\n",
"\n",
"where you replace <action> with your actual action.\n",
"Do nothing else but return the action.\n",
"\"\"\"\n",
"        self.action_parser = RegexParser(\n",
"            regex=r\"Action: (.*)\",\n",
"            output_keys=['action'],\n",
"            default_output_key='action')\n",
"\n",
"        self.message_history = []\n",
"        self.ret = 0\n",
"\n",
"    def random_action(self):\n",
"        action = self.env.action_space.sample()\n",
"        return action\n",
"\n",
"    def reset(self):\n",
"        self.message_history = [\n",
"            SystemMessage(content=self.docs),\n",
"            SystemMessage(content=self.instructions),\n",
"        ]\n",
"\n",
"    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
"        # Accumulate the return so it can be reported back to the model\n",
"        self.ret += rew\n",
"\n",
"        obs_message = f\"\"\"\n",
"Observation: {obs}\n",
"Reward: {rew}\n",
"Termination: {term}\n",
"Truncation: {trunc}\n",
"Return: {self.ret}\n",
"        \"\"\"\n",
"        self.message_history.append(HumanMessage(content=obs_message))\n",
"        return obs_message\n",
"\n",
"    def _act(self):\n",
"        act_message = self.model(self.message_history)\n",
"        self.message_history.append(act_message)\n",
"        # int() raises ValueError if the parsed action is not an integer\n",
"        action = int(self.action_parser.parse(act_message.content)['action'])\n",
"        return action\n",
"\n",
"    def act(self):\n",
"        try:\n",
"            for attempt in tenacity.Retrying(\n",
"                stop=tenacity.stop_after_attempt(2),\n",
"                wait=tenacity.wait_none(),  # No waiting time between retries\n",
"                retry=tenacity.retry_if_exception_type(ValueError),\n",
"                before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
"            ):\n",
"                with attempt:\n",
"                    action = self._act()\n",
"        except tenacity.RetryError:\n",
"            # Fall back to a random action once the retries are exhausted\n",
"            action = self.random_action()\n",
"        return action"
]
},
{
"cell_type": "markdown",
"id": "df51e302",
"metadata": {},
"source": [
"## Main loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f07d7cf",
"metadata": {},
"outputs": [],
"source": [
"def main(agents, env):\n",
"    env.reset()\n",
"\n",
"    for name, agent in agents.items():\n",
"        agent.reset()\n",
"\n",
"    for agent_name in env.agent_iter():\n",
"        observation, reward, termination, truncation, info = env.last()\n",
"        obs_message = agents[agent_name].observe(\n",
"            observation, reward, termination, truncation, info)\n",
"        print(obs_message)\n",
"        if termination or truncation:\n",
"            action = None\n",
"        else:\n",
"            action = agents[agent_name].act()\n",
"        print(f'Action: {action}')\n",
"        env.step(action)\n",
"    env.close()"
]
},
{
"cell_type": "markdown",
"id": "b4b0e921",
"metadata": {},
"source": [
"## `PettingZooAgent`\n",
"\n",
"The `PettingZooAgent` extends the `GymnasiumAgent` to the multi-agent setting. The main differences are:\n",
"- `PettingZooAgent` takes in a `name` argument to identify it among multiple agents\n",
"- `get_docs` is implemented differently because the `PettingZoo` repository is organized differently from the `Gymnasium` repository: it reads the docstring of the environment's module rather than that of the environment class"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f132c92a",
"metadata": {},
"outputs": [],
"source": [
"class PettingZooAgent(GymnasiumAgent):\n",
"    @classmethod\n",
"    def get_docs(cls, env):\n",
"        return inspect.getmodule(env.unwrapped).__doc__\n",
"\n",
"    def __init__(self, name, model, env):\n",
"        super().__init__(model, env)\n",
"        self.name = name\n",
"\n",
"    def random_action(self):\n",
"        action = self.env.action_space(self.name).sample()\n",
"        return action"
]
},
{
"cell_type": "markdown",
"id": "a27f8a5d",
"metadata": {},
"source": [
"## Rock, Paper, Scissors\n",
"We can now run a simulation of a multi-agent rock, paper, scissors game using the `PettingZooAgent`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bd1256c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 2\n",
"Reward: -1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: 1\n",
" \n",
"Action: None\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: -1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import rps_v2\n",
"env = rps_v2.env(max_cycles=3, render_mode=\"human\")\n",
"agents = {name: PettingZooAgent(name=name, model=ChatOpenAI(temperature=1), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
{
"cell_type": "markdown",
"id": "fbcee258",
"metadata": {},
"source": [
"## `ActionMaskAgent`\n",
"\n",
"Some `PettingZoo` environments provide an `action_mask` to tell the agent which actions are valid. The `ActionMaskAgent` subclasses `PettingZooAgent` to use information from the `action_mask` to select actions."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bd33250a",
"metadata": {},
"outputs": [],
"source": [
"class ActionMaskAgent(PettingZooAgent):\n",
"    def __init__(self, name, model, env):\n",
"        super().__init__(name, model, env)\n",
"        self.obs_buffer = collections.deque(maxlen=1)\n",
"\n",
"    def random_action(self):\n",
"        obs = self.obs_buffer[-1]\n",
"        action = self.env.action_space(self.name).sample(obs[\"action_mask\"])\n",
"        return action\n",
"\n",
"    def reset(self):\n",
"        self.message_history = [\n",
"            SystemMessage(content=self.docs),\n",
"            SystemMessage(content=self.instructions),\n",
"        ]\n",
"\n",
"    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
"        self.obs_buffer.append(obs)\n",
"        return super().observe(obs, rew, term, trunc, info)\n",
"\n",
"    def _act(self):\n",
"        valid_action_instruction = \"Generate a valid action given by the indices of the `action_mask` that are not 0, according to the action formatting rules.\"\n",
"        self.message_history.append(HumanMessage(content=valid_action_instruction))\n",
"        return super()._act()"
]
},
{
"cell_type": "markdown",
"id": "2e76d22c",
"metadata": {},
"source": [
"## Tic-Tac-Toe\n",
"Here is an example of a Tic-Tac-Toe game that uses the `ActionMaskAgent`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9e902cfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([[[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 5\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 6\n",
" | | \n",
" X | O | X \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: -1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: 1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import tictactoe_v3\n",
"env = tictactoe_v3.env(render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
{
"cell_type": "markdown",
"id": "8728ac2a",
"metadata": {},
"source": [
"## Texas Hold'em No Limit\n",
"Here is an example of a Texas Hold'em No Limit game that uses the `ActionMaskAgent`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e350c62b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 2., 6.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 2., 8.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
" 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 6., 20.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"[WARNING]: Illegal move made, game terminating with current player losing. \n",
"obs['action_mask'] contains a mask of all legal moves that can be chosen.\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: -1.0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: -1.0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 20., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 100., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 2., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import texas_holdem_no_limit_v6\n",
"env = texas_holdem_no_limit_v6.env(num_players=4, render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}