forked from SeAIPalette/SeAIPalette
60 lines
2.2 KiB
Python
60 lines
2.2 KiB
Python
from collections import deque

import numpy as np

from Palette.algos.wildfire import WildFire
from Palette.algos.utils import Node
from Palette.constants import FILLING_UP, FILLING_DOWN, FILLING_LEFT, FILLING_RIGHT, FINISHED, BROKEN, FILLING, FILLING_BACK
|
|
|
|
|
|
class Charge(WildFire):
    """WildFire-based planner that travels to a ``-1`` cell and fills up there.

    Whenever its action queue is empty, :meth:`step` runs a breadth-first
    search from the agent's position to the first reachable cell whose
    ``fields`` value is ``-1``. It then queues the moves to that cell,
    ``filling_time`` ``FILLING`` actions and as many ``FILLING_BACK`` actions
    as there were moves.

    NOTE(review): ``-1`` presumably marks a charging/refuelling station — the
    class name and ``energy=True`` suggest so, but confirm against the
    environment's field encoding.
    """

    def __init__(self, filling_time: int):
        """Create the planner.

        Args:
            filling_time: number of ``FILLING`` actions queued after the
                target cell has been reached.
        """
        super().__init__()
        self.filling_time = filling_time

    def _pop_front_action(self):
        """Remove and return the first queued action."""
        # Index + ``del`` rather than ``pop(0)``: the container type of
        # ``actions_to_do`` comes from the base class (list or deque — both
        # support this form, only a list supports ``pop(0)``).
        action = self.actions_to_do[0]
        del self.actions_to_do[0]
        return action

    def _queue_charge_trip(self, target):
        """Queue the trip to ``target`` and return ``(first_action, charge_len)``.

        The queue gets the moves from the search root to ``target``, then
        ``filling_time`` ``FILLING`` actions, then ``charge_len``
        ``FILLING_BACK`` actions. Here ``charge_len`` is the number of moves
        (presumably the return trip retraces the path — TODO confirm).
        """
        # Walk father links back to the root (whose direction is None) and
        # collect moves in a local list. Reversing a local list avoids touching
        # anything that might already be in ``actions_to_do``.
        path = []
        node = target
        while node.direction is not None:
            path.append(node.direction)
            node = node.father
        path.reverse()
        charge_len = len(path)

        self.actions_to_do.extend(path)
        self.actions_to_do.extend([FILLING] * self.filling_time)
        self.actions_to_do.extend([FILLING_BACK] * charge_len)
        return self._pop_front_action(), charge_len

    def step(self, state, info):
        """Return the next ``(action, charge_len)`` pair.

        Args:
            state: sequence ``(x, y, _)``; only the grid position is used.
            info: mapping whose ``'fields'`` entry is a 2-D indexable array
                (numpy, given the ``fields == 0`` comparison and ``[x, y]``
                indexing).

        Returns:
            * ``(queued_action, None)`` if actions from an earlier plan remain.
            * ``(FINISHED, None)`` if no cell of ``fields`` equals 0.
            * otherwise ``(first_action, charge_len)`` for a newly planned
              trip, where ``charge_len`` is the number of moves to the
              nearest ``-1`` cell.

        Raises:
            RuntimeError: if no ``-1`` cell is reachable from ``(x, y)``.
        """
        # ``empty()`` comes from the base class; presumably it reports whether
        # ``actions_to_do`` is exhausted — TODO confirm.
        if not self.empty():
            return self._pop_front_action(), None

        x, y, _ = state
        fields = info['fields']

        # No 0-valued cell remains, so the overall task is done.
        if not np.any(fields == 0):
            return FINISHED, None

        # Breadth-first search. With unit moves, the first ``-1`` cell dequeued
        # is one with the fewest moves from (x, y).
        visited = np.zeros(np.shape(fields), dtype=bool)
        visited[x, y] = True
        frontier = deque([Node(x, y, None, energy=True)])

        while frontier:
            cur = frontier.popleft()  # O(1), unlike ``del list[0]``
            cur_x, cur_y = cur.x, cur.y
            for dx, dy in zip(self.delta_x, self.delta_y):
                nxt_x, nxt_y = cur_x + dx, cur_y + dy
                # Validity is checked first so ``visited`` is never indexed
                # out of range.
                if not self._is_valid_index((nxt_x, nxt_y), fields) or visited[nxt_x, nxt_y]:
                    continue
                visited[nxt_x, nxt_y] = True
                nxt = Node(nxt_x, nxt_y, cur, energy=True)
                # A cell valued -1 is the search target (see class docstring).
                if fields[nxt_x, nxt_y] == -1:
                    return self._queue_charge_trip(nxt)
                frontier.append(nxt)

        raise RuntimeError('Charge runtime error!')
|