# Forked from SeAIPalette/SeAIPalette.
import numpy as np
|
|
from Palette.algos.boustrophedon import Boustrophedon
|
|
from Palette.algos.multi_wildfire import MultiWildFire
|
|
from Palette.algos.multi_charge import MultiCharge
|
|
from Palette.algos.utils import Node
|
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN, \
|
|
FILLING_UP, FILLING_DOWN, FILLING_LEFT, FILLING_RIGHT
|
|
|
|
from typing import List
|
|
|
|
|
|
class MultiGreedy(Boustrophedon):
    """Greedy multi-ship planner that chooses one action per ship per step.

    For each ship ``i``, ``step`` tries the following in order:

    1. report ``FINISHED`` if the ship was already marked finished;
    2. keep stepping the per-ship charging planner ``charges[i]`` while it is
       non-empty;
    3. ask ``charges.step(state, info, i)`` for a charging action and take it
       when the battery is below the returned ``charge_len`` plus 2;
    4. keep stepping the per-ship fallback planner ``wildfires[i]`` while it is
       non-empty;
    5. report ``FINISHED`` when ``info['fields'][i]`` contains no zero cell;
    6. otherwise move in the direction whose ``_count_valid_steps`` value is
       the smallest positive one, or fall back to
       ``wildfires.step(state, info, i)`` when no direction has a positive
       count.
    """

    def __init__(
        self,
        ship_num: int,
        wildfires: MultiWildFire,
        charges: MultiCharge
    ):
        """Create a planner for ``ship_num`` ships.

        Args:
            ship_num: Number of ships handled by ``step``.
            wildfires: Fallback planner. It is indexed per ship
                (``wildfires[i]``) and also queried as
                ``wildfires.step(state, info, i)``.
            charges: Charging planner. It is indexed per ship (``charges[i]``)
                and also queried as ``charges.step(state, info, i)``.
        """
        super().__init__()
        self.ship_num = ship_num
        # NOTE(review): not read anywhere in this class; kept in case external
        # code relies on it.
        self.dest_direction = [None] * ship_num
        self.wildfires = wildfires
        self.charges = charges
        # finished_idx[i] becomes True once ship i is done; it then only ever
        # reports FINISHED.
        self.finished_idx = [False] * ship_num

        # Four axis-aligned unit offsets (no diagonals). These are not used in
        # this class; presumably the Boustrophedon base consumes them -- confirm.
        self.delta_x = (0, 0, 1, -1)
        self.delta_y = (1, -1, 0, 0)

    def step(self, state, info):
        """Choose the next action for every ship.

        Args:
            state: Per-ship ``(x, y, direction)`` triples, in ship order.
            info: Mapping from which this method reads ``info['batteries'][i]``
                and ``info['fields'][i]`` for each ship ``i``.

        Returns:
            ``(actions, redecide)``. ``actions`` is a NumPy array with one
            action per ship. ``redecide`` is a list of bools, of length
            ``ship_num``. ``redecide[i]`` is True when ship ``i`` did one of
            the following this step:
            finished its charging route, started a charging route, finished
            its fallback route, or was stuck and got its action from
            ``wildfires.step``.
        """
        actions = []
        redecide = [False] * self.ship_num

        for i, (x, y, _direction) in enumerate(state):
            if self.finished_idx[i]:
                actions.append(FINISHED)
                continue

            # Keep following an in-progress charging route.
            if not self.charges[i].empty():
                actions.append(self.charges[i].step(state, info)[0])
                if self.charges[i].empty():
                    redecide[i] = True
                continue

            # Decide whether the ship has to go charging now.
            first_charge_act, charge_len = self.charges.step(state, info, i)
            if first_charge_act == FINISHED:
                self.finished_idx[i] = True
                actions.append(FINISHED)
                continue

            if info['batteries'][i] < charge_len + 2:
                # Battery would not cover charge_len plus a 2-step margin:
                # start charging. Presumably the planned route is kept in
                # charges[i] and resumed above on later steps -- confirm.
                actions.append(first_charge_act)
                redecide[i] = True
                continue

            # Battery is sufficient: drop what charges.step just planned, and
            # any pending fallback route.
            self.charges[i].clear()
            self.wildfires[i].clear()

            # NOTE(review): wildfires[i] was cleared just above, so this branch
            # only runs if clear() can leave it non-empty. It looks unreachable,
            # which means fallback routes are re-planned every step. Confirm
            # whether that is intended.
            if not self.wildfires[i].empty():
                actions.append(self.wildfires[i].step(state, info)[0])
                if self.wildfires[i].empty():
                    redecide[i] = True
                continue

            fields = info['fields'][i]

            # No zero cell left in this ship's field -> mark it finished.
            # (Presumably 0 marks a not-yet-covered cell -- confirm.)
            if not np.any(fields == 0):
                self.finished_idx[i] = True
                actions.append(FINISHED)
                continue

            # Greedy move: among directions with a positive valid-step count,
            # take the smallest count. min() keeps the first of equal
            # candidates, so ties resolve in UP, DOWN, LEFT, RIGHT order.
            position = (x, y)
            candidates = []
            for move in (UP, DOWN, LEFT, RIGHT):
                steps = self._count_valid_steps(position, move, fields)
                if steps > 0:
                    candidates.append((steps, move))

            if not candidates:
                # Stuck (no direction has a positive count): ask the fallback
                # planner. This also avoids argmin over an empty sequence if a
                # count were ever negative.
                actions.append(self.wildfires.step(state, info, i)[0])
                redecide[i] = True
            else:
                actions.append(min(candidates, key=lambda c: c[0])[1])

        return np.asarray(actions), redecide
|