new example: multi-agent simulations with environment (#3928)

fix_agent_callbacks
mbchang 1 year ago committed by GitHub
parent f7a828685d
commit 81601d886c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,7 +9,7 @@ Agent simulations generally involve two main components:
Specific implementations of agent simulations (or parts of agent simulations) include
## Simulations with One Agent
- [Simulated Environment: Gymnasium](agent_simulations/gymnasium.ipynb): an example of how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym)).
- [Simulated Environment: Gymnasium](agent_simulations/gymnasium.ipynb): an example of how to create a simple agent-environment interaction loop with [Gymnasium](https://gymnasium.farama.org/) (formerly [OpenAI Gym](https://github.com/openai/gym)).
## Simulations with Two Agents
- [CAMEL](agent_simulations/camel_role_playing.ipynb): an implementation of the CAMEL (Communicative Agents for “Mind” Exploration of Large Scale Language Model Society) paper, where two agents communicate with each other.
@ -19,4 +19,5 @@ Specific implementations of agent simulations (or parts of agent simulations) in
- [Multi-Player D&D](agent_simulations/multi_player_dnd.ipynb): an example of how to use a generic dialogue simulator for multiple dialogue agents with a custom speaker-ordering, illustrated with a variant of the popular Dungeons & Dragons role playing game.
- [Decentralized Speaker Selection](agent_simulations/multiagent_bidding.ipynb): an example of how to implement a multi-agent dialogue without a fixed schedule for who speaks when. Instead the agents decide for themselves who speaks by outputting bids to speak. This example shows how to do this in the context of a fictitious presidential debate.
- [Authoritarian Speaker Selection](agent_simulations/multiagent_authoritarian.ipynb): an example of how to implement a multi-agent dialogue, where a privileged agent directs who speaks what. This example also showcases how to enable the privileged agent to determine when the conversation terminates. This example shows how to do this in the context of a fictitious news show.
- [Simulated Environment: PettingZoo](agent_simulations/petting_zoo.ipynb): an example of how to create a agent-environment interaction loop for multiple agents with [PettingZoo](https://pettingzoo.farama.org/) (a multi-agent version of [Gymnasium](https://gymnasium.farama.org/)).
- [Generative Agents](agent_simulations/characters.ipynb): This notebook implements a generative agent based on the paper [Generative Agents: Interactive Simulacra of Human Behavior](https://arxiv.org/abs/2304.03442) by Park, et. al.

@ -0,0 +1,837 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4b089493",
"metadata": {},
"source": [
"# Multi-Agent Simulated Environment: Petting Zoo\n",
"\n",
"In this example, we show how to define multi-agent simulations with simulated environments. Like [ours single-agent example with Gymnasium](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html), we create an agent-environment loop with an externally defined environment. The main difference is that we now implement this kind of interaction loop with multiple agents instead. We will use the [Petting Zoo](https://pettingzoo.farama.org/) library, which is the multi-agent counterpart to [Gymnasium](https://gymnasium.farama.org/)."
]
},
{
"cell_type": "markdown",
"id": "10091333",
"metadata": {},
"source": [
"## Install `pettingzoo` and other dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a3fde66",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pettingzoo in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (1.22.3)\n",
"Requirement already satisfied: pygame in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (2.4.0)\n",
"Requirement already satisfied: rlcard in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (1.2.0)\n",
"Requirement already satisfied: gymnasium>=0.26.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from pettingzoo) (0.28.1)\n",
"Requirement already satisfied: numpy>=1.18.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from pettingzoo) (1.24.3)\n",
"Requirement already satisfied: termcolor in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from rlcard) (2.2.0)\n",
"Requirement already satisfied: typing-extensions>=4.3.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (4.5.0)\n",
"Requirement already satisfied: jax-jumpy>=1.0.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (1.0.0)\n",
"Requirement already satisfied: farama-notifications>=0.0.1 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (0.0.4)\n",
"Requirement already satisfied: cloudpickle>=1.2.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (2.2.1)\n",
"Requirement already satisfied: importlib-metadata>=4.8.0 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from gymnasium>=0.26.0->pettingzoo) (6.0.1)\n",
"Requirement already satisfied: zipp>=0.5 in /Users/michaelchang/.miniconda3/envs/langchain/lib/python3.9/site-packages (from importlib-metadata>=4.8.0->gymnasium>=0.26.0->pettingzoo) (3.15.0)\n"
]
}
],
"source": [
"!pip install pettingzoo pygame rlcard"
]
},
{
"cell_type": "markdown",
"id": "5fbe130c",
"metadata": {},
"source": [
"## Import modules"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "42cd2e5d",
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import inspect\n",
"import tenacity\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema import (\n",
" HumanMessage,\n",
" SystemMessage,\n",
")\n",
"from langchain.output_parsers import RegexParser"
]
},
{
"cell_type": "markdown",
"id": "e222e811",
"metadata": {},
"source": [
"## `GymnasiumAgent`\n",
"Here we reproduce the same `GymnasiumAgent` defined from [our Gymnasium example](https://python.langchain.com/en/latest/use_cases/agent_simulations/gymnasium.html). If after multiple retries it does not take a valid action, it simply takes a random action. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "72df0b59",
"metadata": {},
"outputs": [],
"source": [
"class GymnasiumAgent():\n",
" @classmethod\n",
" def get_docs(cls, env):\n",
" return env.unwrapped.__doc__\n",
" \n",
" def __init__(self, model, env):\n",
" self.model = model\n",
" self.env = env\n",
" self.docs = self.get_docs(env)\n",
" \n",
" self.instructions = \"\"\"\n",
"Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
"I will give you an observation, reward, terminiation flag, truncation flag, and the return so far, formatted as:\n",
"\n",
"Observation: <observation>\n",
"Reward: <reward>\n",
"Termination: <termination>\n",
"Truncation: <truncation>\n",
"Return: <sum_of_rewards>\n",
"\n",
"You will respond with an action, formatted as:\n",
"\n",
"Action: <action>\n",
"\n",
"where you replace <action> with your actual action.\n",
"Do nothing else but return the action.\n",
"\"\"\"\n",
" self.action_parser = RegexParser(\n",
" regex=r\"Action: (.*)\", \n",
" output_keys=['action'], \n",
" default_output_key='action')\n",
" \n",
" self.message_history = []\n",
" self.ret = 0\n",
" \n",
" def random_action(self):\n",
" action = self.env.action_space.sample()\n",
" return action\n",
" \n",
" def reset(self):\n",
" self.message_history = [\n",
" SystemMessage(content=self.docs),\n",
" SystemMessage(content=self.instructions),\n",
" ]\n",
" \n",
" def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
" self.ret += rew\n",
" \n",
" obs_message = f\"\"\"\n",
"Observation: {obs}\n",
"Reward: {rew}\n",
"Termination: {term}\n",
"Truncation: {trunc}\n",
"Return: {self.ret}\n",
" \"\"\"\n",
" self.message_history.append(HumanMessage(content=obs_message))\n",
" return obs_message\n",
" \n",
" def _act(self):\n",
" act_message = self.model(self.message_history)\n",
" self.message_history.append(act_message)\n",
" action = int(self.action_parser.parse(act_message.content)['action'])\n",
" return action\n",
" \n",
" def act(self):\n",
" try:\n",
" for attempt in tenacity.Retrying(\n",
" stop=tenacity.stop_after_attempt(2),\n",
" wait=tenacity.wait_none(), # No waiting time between retries\n",
" retry=tenacity.retry_if_exception_type(ValueError),\n",
" before_sleep=lambda retry_state: print(f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"),\n",
" ):\n",
" with attempt:\n",
" action = self._act()\n",
" except tenacity.RetryError as e:\n",
" action = self.random_action()\n",
" return action"
]
},
{
"cell_type": "markdown",
"id": "df51e302",
"metadata": {},
"source": [
"## Main loop"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f07d7cf",
"metadata": {},
"outputs": [],
"source": [
"def main(agents, env):\n",
" env.reset()\n",
"\n",
" for name, agent in agents.items():\n",
" agent.reset()\n",
"\n",
" for agent_name in env.agent_iter():\n",
" observation, reward, termination, truncation, info = env.last()\n",
" obs_message = agents[agent_name].observe(\n",
" observation, reward, termination, truncation, info)\n",
" print(obs_message)\n",
" if termination or truncation:\n",
" action = None\n",
" else:\n",
" action = agents[agent_name].act()\n",
" print(f'Action: {action}')\n",
" env.step(action)\n",
" env.close()"
]
},
{
"cell_type": "markdown",
"id": "b4b0e921",
"metadata": {},
"source": [
"## `PettingZooAgent`\n",
"\n",
"The `PettingZooAgent` extends the `GymnasiumAgent` to the multi-agent setting. The main differences are:\n",
"- `PettingZooAgent` takes in a `name` argument to identify it among multiple agents\n",
"- the function `get_docs` is implemented differently because the `PettingZoo` repo structure is structured differently from the `Gymnasium` repo"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f132c92a",
"metadata": {},
"outputs": [],
"source": [
"class PettingZooAgent(GymnasiumAgent):\n",
" @classmethod\n",
" def get_docs(cls, env):\n",
" return inspect.getmodule(env.unwrapped).__doc__\n",
" \n",
" def __init__(self, name, model, env):\n",
" super().__init__(model, env)\n",
" self.name = name\n",
" \n",
" def random_action(self):\n",
" action = self.env.action_space(self.name).sample()\n",
" return action"
]
},
{
"cell_type": "markdown",
"id": "a27f8a5d",
"metadata": {},
"source": [
"## Rock, Paper, Scissors\n",
"We can now run a simulation of a multi-agent rock, paper, scissors game using the `PettingZooAgent`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bd1256c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 3\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: 1\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: 1\n",
"Reward: 1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 2\n",
"Reward: -1\n",
"Termination: False\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: 0\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: 1\n",
" \n",
"Action: None\n",
"\n",
"Observation: 0\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: True\n",
"Return: -1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import rps_v2\n",
"env = rps_v2.env(max_cycles=3, render_mode=\"human\")\n",
"agents = {name: PettingZooAgent(name=name, model=ChatOpenAI(temperature=1), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
{
"cell_type": "markdown",
"id": "fbcee258",
"metadata": {},
"source": [
"## `ActionMaskAgent`\n",
"\n",
"Some `PettingZoo` environments provide an `action_mask` to tell the agent which actions are valid. The `ActionMaskAgent` subclasses `PettingZooAgent` to use information from the `action_mask` to select actions."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bd33250a",
"metadata": {},
"outputs": [],
"source": [
"class ActionMaskAgent(PettingZooAgent):\n",
" def __init__(self, name, model, env):\n",
" super().__init__(name, model, env)\n",
" self.obs_buffer = collections.deque(maxlen=1)\n",
" \n",
" def random_action(self):\n",
" obs = self.obs_buffer[-1]\n",
" action = self.env.action_space(self.name).sample(obs[\"action_mask\"])\n",
" return action\n",
" \n",
" def reset(self):\n",
" self.message_history = [\n",
" SystemMessage(content=self.docs),\n",
" SystemMessage(content=self.instructions),\n",
" ]\n",
" \n",
" def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
" self.obs_buffer.append(obs)\n",
" return super().observe(obs, rew, term, trunc, info)\n",
" \n",
" def _act(self):\n",
" valid_action_instruction = \"Generate a valid action given by the indices of the `action_mask` that are not 0, according to the action formatting rules.\"\n",
" self.message_history.append(HumanMessage(content=valid_action_instruction))\n",
" return super()._act()"
]
},
{
"cell_type": "markdown",
"id": "2e76d22c",
"metadata": {},
"source": [
"## Tic-Tac-Toe\n",
"Here is an example of a Tic-Tac-Toe game that uses the `ActionMaskAgent`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9e902cfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([[[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" - | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
" | | \n",
" X | - | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | - | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | - | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [0, 0]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 5\n",
" | | \n",
" X | O | - \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[0, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 6\n",
" | | \n",
" X | O | X \n",
"_____|_____|_____\n",
" | | \n",
" O | X | - \n",
"_____|_____|_____\n",
" | | \n",
" X | O | - \n",
" | | \n",
"\n",
"Observation: {'observation': array([[[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: -1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: -1\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([[[1, 0],\n",
" [0, 1],\n",
" [1, 0]],\n",
"\n",
" [[0, 1],\n",
" [1, 0],\n",
" [0, 1]],\n",
"\n",
" [[1, 0],\n",
" [0, 0],\n",
" [0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}\n",
"Reward: 1\n",
"Termination: True\n",
"Truncation: False\n",
"Return: 1\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import tictactoe_v3\n",
"env = tictactoe_v3.env(render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
},
{
"cell_type": "markdown",
"id": "8728ac2a",
"metadata": {},
"source": [
"## Texas Hold'em No Limit\n",
"Here is an example of a Texas Hold'em No Limit game that uses the `ActionMaskAgent`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e350c62b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 1\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 0\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,\n",
" 0., 2., 6.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 2\n",
"\n",
"Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 2., 8.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 3\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
" 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 6., 20.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: False\n",
"Truncation: False\n",
"Return: 0\n",
" \n",
"Action: 4\n",
"[WARNING]: Illegal move made, game terminating with current player losing. \n",
"obs['action_mask'] contains a mask of all legal moves that can be chosen.\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,\n",
" 0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: -1.0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: -1.0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 20., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 100., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n",
"\n",
"Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 2., 100.],\n",
" dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}\n",
"Reward: 0\n",
"Termination: True\n",
"Truncation: True\n",
"Return: 0\n",
" \n",
"Action: None\n"
]
}
],
"source": [
"from pettingzoo.classic import texas_holdem_no_limit_v6\n",
"env = texas_holdem_no_limit_v6.env(num_players=4, render_mode=\"human\")\n",
"agents = {name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env) for name in env.possible_agents}\n",
"main(agents, env)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save