SeAIPalette/Palette/algos/wildfire.py

65 lines
2.2 KiB
Python

from collections import deque

import numpy as np

from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos.utils import Node

class WildFire(object):
    """Greedy coverage planner that always heads for the nearest unvisited cell.

    Each call to :meth:`step` yields exactly one action. When no actions are
    queued, a breadth-first ("wildfire") expansion from the current position
    finds the closest reachable cell whose ``fields`` value is ``0`` and
    queues the moves that lead there.
    """

    def __init__(self):
        super().__init__()
        # Pending moves of the current path, consumed front-to-back by step().
        self.actions_to_do = []
        # 4-connected neighbourhood; offsets are paired element-wise (dx, dy).
        self.delta_x = (0, 0, 1, -1)
        self.delta_y = (1, -1, 0, 0)

    @staticmethod
    def _is_valid_index(index, an_array):
        """Return True if the (x, y) pair ``index`` lies inside ``an_array``'s first two dims."""
        x, y = index
        return 0 <= x < an_array.shape[0] and 0 <= y < an_array.shape[1]

    def empty(self):
        """Return True when no actions are queued."""
        return len(self.actions_to_do) == 0

    def clear(self):
        """Drop every queued action (in place, keeping the same list object)."""
        self.actions_to_do.clear()

    def _pop_action(self):
        """Remove and return the first queued action."""
        return self.actions_to_do.pop(0)

    def step(self, state, info):
        """Return the next action as a one-element list.

        Args:
            state: Sequence unpacked as ``(x, y, _)``; ``x`` and ``y`` index
                ``info['fields']`` and the third element is ignored.
            info: Mapping whose ``'fields'`` entry is a 2-D numpy array where
                ``0`` marks a cell not yet visited and negative values mark
                cells that may not be entered.

        Returns:
            ``[action]``: a previously queued action if one is pending;
            ``[FINISHED]`` when no cell of ``fields`` equals ``0``; otherwise
            the first move of the shortest path to the nearest ``0`` cell
            (the remaining moves are queued for later calls).

        Raises:
            RuntimeError: if ``0`` cells remain but none is reachable from
                ``(x, y)`` through non-negative cells.
        """
        if not self.empty():
            return [self._pop_action()]

        x, y, _ = state
        fields = info['fields']
        if not np.any(fields == 0):
            return [FINISHED]

        target = self._nearest_unvisited(x, y, fields)
        if target is None:
            raise RuntimeError(
                f'WildFire runtime error! No reachable unvisited cell from ({x}, {y}).'
            )
        # The queue is empty here, so the new path becomes the whole queue.
        self.actions_to_do.extend(self._trace_actions(target))
        return [self._pop_action()]

    def _nearest_unvisited(self, x, y, fields):
        """Breadth-first search from (x, y) for the closest cell equal to 0.

        Returns the ``Node`` of that cell (its ``father`` chain leads back to
        the root at ``(x, y)``), or None if no such cell is reachable.
        """
        visited = np.zeros_like(fields, dtype=bool)
        visited[x, y] = True
        # deque gives O(1) popleft; a list with ``del buf[0]`` is O(n) per pop.
        frontier = deque([Node(x, y, None)])
        while frontier:
            cur_node = frontier.popleft()
            for dx, dy in zip(self.delta_x, self.delta_y):
                next_x, next_y = cur_node.x + dx, cur_node.y + dy
                if (
                    self._is_valid_index((next_x, next_y), fields)
                    and fields[next_x, next_y] >= 0
                    and not visited[next_x, next_y]
                ):
                    visited[next_x, next_y] = True
                    next_node = Node(next_x, next_y, cur_node)
                    if fields[next_x, next_y] == 0:
                        # First 0 cell popped by BFS is a nearest one.
                        return next_node
                    frontier.append(next_node)
        return None

    @staticmethod
    def _trace_actions(node):
        """Return the actions leading from the search root to ``node``, in travel order.

        Walks ``father`` links until a node whose ``direction`` is None.
        NOTE(review): relies on ``Node`` (Palette.algos.utils) giving the
        root node a ``direction`` of None; confirm there.
        """
        actions = []
        while node.direction is not None:
            actions.append(node.direction)
            node = node.father
        actions.reverse()
        return actions