# SeAIPalette/Palette/algos/multi_greedy.py
#
# 97 lines
# 3.5 KiB
# Python

import numpy as np
from Palette.algos.boustrophedon import Boustrophedon
from Palette.algos.multi_wildfire import MultiWildFire
from Palette.algos.multi_charge import MultiCharge
from Palette.algos.utils import Node
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN, \
FILLING_UP, FILLING_DOWN, FILLING_LEFT, FILLING_RIGHT
from typing import List
class MultiGreedy(Boustrophedon):
    """Greedy per-ship action selector with charge and wildfire fallbacks.

    For every ship, ``step`` chooses one action per call using this priority:

    1. keep following a charge plan that is already in progress;
    2. start a charge plan when the battery is too low;
    3. keep following a wildfire plan that is already in progress;
    4. otherwise move greedily in the direction with the fewest (but at
       least one) valid steps, falling back to the wildfire planner when
       no direction has any valid step.
    """

    # Battery reserve required on top of the charge-plan length before the
    # ship is allowed to keep working instead of going to charge.
    # NOTE(review): the value 2 comes from the original hard-coded
    # ``charge_len + 2``; its rationale is not visible here.
    BATTERY_MARGIN = 2

    def __init__(
        self,
        ship_num: int,
        wildfires: MultiWildFire,
        charges: MultiCharge
    ):
        """Store the planners and initialise per-ship bookkeeping.

        Args:
            ship_num: Number of ships controlled by this instance.
            wildfires: Multi-ship wildfire planner; indexed per ship and
                also stepped as a whole with a ship index.
            charges: Multi-ship charge planner; indexed per ship and also
                stepped as a whole with a ship index.
        """
        super().__init__()
        self.ship_num = ship_num
        self.dest_direction = [None] * ship_num
        self.wildfires = wildfires
        self.charges = charges
        # finished_idx[i] becomes True once ship i has reported FINISHED.
        self.finished_idx = [False] * ship_num
        # 4-neighbourhood offsets (diagonal moves are intentionally disabled).
        self.delta_x = (0, 0, 1, -1)
        self.delta_y = (1, -1, 0, 0)

    def step(self, state, info):
        """Choose the next action for every ship.

        Args:
            state: Iterable with one ``(x, y, direction)`` triple per ship.
            info: Mapping providing at least ``info['batteries'][i]`` and
                ``info['fields'][i]`` for every ship index ``i``.

        Returns:
            A pair ``(actions, redecide)`` where ``actions`` is a NumPy
            array with one action per ship and ``redecide`` is a list of
            booleans, one per ship, set to ``True`` when the action came
            from a plan that has just started or just run out.
        """
        ret_directions = []
        redecide_directions = [False] * self.ship_num

        for i, (x, y, _direction) in enumerate(state):
            if self.finished_idx[i]:
                ret_directions.append(FINISHED)
                continue

            # 1) A charge plan is already queued for this ship: follow it.
            if not self.charges[i].empty():
                ret_directions.append(self.charges[i].step(state, info)[0])
                if self.charges[i].empty():
                    # Plan exhausted by this step.
                    redecide_directions[i] = True
                continue

            # 2) Ask the charge planner for a fresh plan and decide whether
            #    the ship has to charge now.
            first_charge_act, charge_len = self.charges.step(state, info, i)
            if first_charge_act == FINISHED:
                self.finished_idx[i] = True
                ret_directions.append(FINISHED)
                continue
            if info['batteries'][i] < charge_len + self.BATTERY_MARGIN:
                ret_directions.append(first_charge_act)
                redecide_directions[i] = True
                continue
            # Battery is sufficient: drop the tentative charge plan and any
            # leftover wildfire plan before deciding what to do instead.
            self.charges[i].clear()
            self.wildfires[i].clear()

            # 3) Continue a queued wildfire plan.
            # NOTE(review): wildfires[i] was cleared just above, so this
            # branch looks unreachable unless clear() can leave steps
            # queued -- confirm against MultiWildFire.
            if not self.wildfires[i].empty():
                ret_directions.append(self.wildfires[i].step(state, info)[0])
                if self.wildfires[i].empty():
                    redecide_directions[i] = True
                continue

            fields = info['fields'][i]
            if not np.any(fields == 0):
                # No cell equal to 0 is left in this ship's field.
                self.finished_idx[i] = True
                ret_directions.append(FINISHED)
                continue

            # 4) Greedy move: among directions with at least one valid
            #    step, take the one with the fewest steps (ties resolved
            #    in UP, DOWN, LEFT, RIGHT order).
            candidates, step_counts = [], []
            for candidate in (UP, DOWN, LEFT, RIGHT):
                count = self._count_valid_steps((x, y), candidate, fields)
                if count > 0:
                    candidates.append(candidate)
                    step_counts.append(count)

            if candidates:
                ret_directions.append(candidates[np.argmin(step_counts)])
            else:
                # Stuck: no direction is immediately usable, so hand over
                # to the wildfire planner and take its first action.
                ret_directions.append(self.wildfires.step(state, info, i)[0])
                redecide_directions[i] = True

        return np.asarray(ret_directions), redecide_directions