
Experiment

PPO agents trained in a self-play setting. This repo includes checkpoints collected during training for 4 experiments:

  • Shared weights for actor and critic
  • No shared weights
  • Resume training for extra steps, for both the shared and non-shared setups

Please check our wandb report for more details, and the training code on our GitHub.

Environment

Multiplayer pong_v3 from PettingZoo with:

  • 4 stacked frames
  • The agent is trained to predict the left agent's policy (the observation is mirrored for the right agent)
import importlib

import gym
import supersuit as ss


def pong_obs_modification(obs, _space, player_id):
    # Zero out the top rows of the frame so the on-screen score is hidden
    obs[:9, :, :] = 0
    if "second" in player_id:
        # Mirror the image so the right-hand player also sees a left-side view
        obs = obs[:, ::-1, :]
    return obs


def get_env(args, run_name):
    env = importlib.import_module(f"pettingzoo.atari.{args.env_id}").parallel_env()
    env = ss.max_observation_v0(env, 2)
    env = ss.frame_skip_v0(env, 4)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.frame_stack_v1(env, 4)
    # Remove the score from the observation
    if "pong" in args.env_id:
        env = ss.lambda_wrappers.observation_lambda_v0(
            env,
            pong_obs_modification,
        )
    # env = ss.agent_indicator_v0(env, type_only=False)
    # Each two-player Pong env contributes two agent slots, so num_envs // 2 copies give num_envs total
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym")
    envs.single_observation_space = envs.observation_space
    envs.single_action_space = envs.action_space
    envs.is_vector_env = True
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    if args.capture_video:
        envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
    assert isinstance(
        envs.single_action_space, gym.spaces.Discrete
    ), "only discrete action space is supported"
    return envs
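
For reference, building the vectorized self-play environment from this helper can be done roughly as below. The args namespace is an illustrative stand-in for the training script's parsed arguments; only the fields that get_env actually reads (env_id, num_envs, capture_video) are shown, and the concrete values and run_name are assumptions.

from types import SimpleNamespace

# Hypothetical arguments; the real training script builds these from its CLI parser
args = SimpleNamespace(
    env_id="pong_v3",    # PettingZoo Atari environment used on this card
    num_envs=16,         # total agent slots; get_env spawns num_envs // 2 Pong instances
    capture_video=False,
)

envs = get_env(args, run_name="ppo_selfplay_pong_demo")
obs = envs.reset()
print(envs.single_observation_space, envs.single_action_space)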

Model architecture

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical


def atari_network(orth_init=False):
    init = layer_init if orth_init else lambda m: m
    return nn.Sequential(
        init(nn.Conv2d(4, 32, 8, stride=4)),
        nn.ReLU(),
        init(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU(),
        init(nn.Conv2d(64, 64, 3, stride=1)),
        nn.ReLU(),
        nn.Flatten(),
        init(nn.Linear(64 * 7 * 7, 512)),
        nn.ReLU(),
    )
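
layer_init is referenced above but not defined on this card. A common definition, as used in CleanRL-style PPO implementations, is the orthogonal initializer below; this is an assumption about the training code, not an exact copy of it.

import numpy as np
import torch


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Assumed CleanRL-style init: orthogonal weights with gain `std`, constant bias
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer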

class Agent(nn.Module):
    def __init__(self, envs, share_network=False):
        super().__init__()
        self.actor_network = atari_network(orth_init=True)
        self.share_network = share_network
        if share_network:
            self.critic_network = self.actor_network
        else:
            self.critic_network = atari_network(orth_init=True)
        self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x):
        x = x.clone()
        # Scale the 4 stacked frames to [0, 1], then switch to channels-first layout
        x[:, :, :, [0, 1, 2, 3]] /= 255.0
        return self.critic(self.critic_network(x.permute((0, 3, 1, 2))))

    def get_action_and_value(self, x, action=None):
        x = x.clone()
        x[:, :, :, [0, 1, 2, 3]] /= 255.0
        logits = self.actor(self.actor_network(x.permute((0, 3, 1, 2))))
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return (
            action,
            probs.log_prob(action),
            probs.entropy(),
            self.critic(self.critic_network(x.permute((0, 3, 1, 2)))),
        )
    
    def load(self, path):
        self.load_state_dict(torch.load(path))
        if self.share_network:
            # Re-tie the critic to the actor so shared weights stay shared after loading
            self.critic_network = self.actor_network
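
Putting the pieces together, loading one of the checkpoints from this repo for inference might look like the sketch below. The checkpoint filename is a placeholder, share_network should match the experiment the checkpoint was trained with, and args is the illustrative namespace from the environment sketch above.

envs = get_env(args, run_name="eval_run")
agent = Agent(envs, share_network=True)  # True for the shared-weights experiments
agent.load("checkpoint.pt")              # placeholder path to a downloaded checkpoint

obs = torch.tensor(envs.reset(), dtype=torch.float32)
with torch.no_grad():
    action, _logprob, _entropy, value = agent.get_action_and_value(obs)
envs.step(action.cpu().numpy())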