SeAIPalette/Palette/algos/multi_wildfire.py

22 lines
705 B
Python

from Palette.algos.wildfire import WildFire
class MultiWildFire:
    """Container of independent ``WildFire`` agents, one per ship.

    Ship ``idx`` owns its own ``WildFire`` instance. :meth:`step` narrows
    the shared ``state``/``info`` down to that ship before delegating.
    """

    # Keys of ``info`` whose values are per-ship sequences indexed by ship
    # number; every other key is shared by all ships and forwarded as-is.
    # Kept in one place so the "exclude" step and the "slice" step in
    # ``step`` cannot drift apart.
    _PER_AGENT_KEYS = ('fields', 'redecide_direction')

    def __init__(self, n_ships: int):
        """Create one ``WildFire`` agent per ship.

        Args:
            n_ships: Number of ships (and agents) to manage; must be >= 0.

        Raises:
            ValueError: If ``n_ships`` is negative (such an object would
                otherwise only fail later, when ``len()`` is called on it).
        """
        super().__init__()
        if n_ships < 0:
            raise ValueError(f"n_ships must be non-negative, got {n_ships}")
        self._n_ships = n_ships
        self._wildfires = [WildFire() for _ in range(self._n_ships)]

    def __len__(self) -> int:
        """Return the number of managed ships."""
        return self._n_ships

    def __getitem__(self, idx):
        """Return the ``WildFire`` agent for ship ``idx``."""
        return self._wildfires[idx]

    def step(self, state, info: dict, idx: int):
        """Run one ``step`` of the agent belonging to ship ``idx``.

        Args:
            state: Per-ship sequence; ``state[idx]`` is passed to the agent.
            info: Mapping of extra observations. The entries named in
                ``_PER_AGENT_KEYS`` must be indexable by ship number and are
                reduced to their ``[idx]`` element; every other entry is
                forwarded unchanged. ``info`` itself is not modified.
            idx: Ship index selecting both the agent and the per-ship slices.

        Returns:
            Whatever the selected ``WildFire`` agent's ``step`` returns.

        Raises:
            KeyError: If ``info`` lacks one of the per-ship keys.
        """
        per_agent = self._PER_AGENT_KEYS
        # Shared entries first, in info's own order (deterministic, unlike
        # iterating a set of keys).
        agent_info = {k: v for k, v in info.items() if k not in per_agent}
        for key in per_agent:
            agent_info[key] = info[key][idx]
        return self._wildfires[idx].step(state[idx], agent_info)