SeAIPalette/Palette/algos/charge.py

60 lines
2.2 KiB
Python

from collections import deque

import numpy as np

from Palette.algos.wildfire import WildFire
from Palette.algos.utils import Node
from Palette.constants import FILLING_UP, FILLING_DOWN, FILLING_LEFT, FILLING_RIGHT, FINISHED, BROKEN, FILLING, FILLING_BACK
class Charge(WildFire):
    """Planner that routes the agent to the nearest ``-1`` cell and back.

    When no actions are queued, ``step`` runs a breadth-first search over
    ``info['fields']`` starting at the agent's position, stops at the first
    cell whose value is ``-1``, and queues: the moves to reach that cell,
    ``filling_time`` ``FILLING`` actions, and one ``FILLING_BACK`` action per
    outbound move.

    NOTE(review): ``-1`` presumably marks a charging/filling cell and ``0`` a
    cell that still has to be covered -- confirm against the environment.
    """

    def __init__(self, filling_time: int):
        """Initialise the planner.

        Args:
            filling_time: Number of ``FILLING`` actions queued after the
                target cell has been reached.
        """
        super().__init__()
        self.filling_time = filling_time

    def step(self, state, info):
        """Return the next action, planning a new trip if nothing is queued.

        Args:
            state: Three-item sequence; the first two items are used as the
                agent's ``x`` and ``y`` indices into ``info['fields']``, the
                third is ignored.
            info: Mapping providing the grid under the key ``'fields'``
                (indexed as ``fields[x, y]``).

        Returns:
            A ``(action, charge_len)`` tuple. ``charge_len`` is the number of
            moves to the target cell on the call that plans a new trip, and
            ``None`` otherwise (queued action replayed, or ``FINISHED``).

        Raises:
            RuntimeError: If the search exhausts every reachable cell without
                finding one equal to ``-1``.
        """
        # Replay previously planned actions before planning anything new.
        if not self.empty():
            action = self.actions_to_do[0]
            del self.actions_to_do[0]
            return action, None

        x, y, _ = state
        fields = info['fields']

        # No cell equal to 0 is left: report completion.
        if not np.any(fields == 0):
            return FINISHED, None

        # Breadth-first search; a deque gives O(1) pops from the front.
        # NOTE(review): the meaning of ``energy=True`` is defined by ``Node``
        # and is not visible from this file.
        frontier = deque([Node(x, y, None, energy=True)])
        visited = np.zeros_like(fields, dtype=bool)
        visited[x, y] = True

        while frontier:
            cur_node = frontier.popleft()
            for dx, dy in zip(self.delta_x, self.delta_y):
                next_x, next_y = cur_node.x + dx, cur_node.y + dy
                if not self._is_valid_index((next_x, next_y), fields):
                    continue
                if visited[next_x, next_y]:
                    continue
                visited[next_x, next_y] = True
                next_node = Node(next_x, next_y, cur_node, energy=True)

                if fields[next_x, next_y] != -1:
                    frontier.append(next_node)
                    continue

                # Target found (presumably the charging cell -- TODO confirm).
                # Walk the parent links back to the root; the directions are
                # collected target-first, hence the reverse afterwards.
                node = next_node
                while node.direction is not None:
                    self.actions_to_do.append(node.direction)
                    node = node.father
                self.actions_to_do.reverse()

                charge_len = len(self.actions_to_do)
                self.actions_to_do.extend([FILLING] * self.filling_time)
                self.actions_to_do.extend([FILLING_BACK] * charge_len)

                action = self.actions_to_do[0]
                del self.actions_to_do[0]
                return action, charge_len

        # Every reachable cell was explored without meeting a -1 cell.
        raise RuntimeError('Charge runtime error!')