forked from SeAIPalette/SeAIPalette
65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
from collections import deque

import numpy as np

from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos.utils import Node


class WildFire:
    """Grid planner that walks the agent to the nearest cell whose field value is 0.

    Each call to :meth:`step` emits exactly one move. When no moves are queued,
    a breadth-first search is run from the agent's position over in-bounds,
    non-negative cells. It stops at the first cell whose value is ``0``. The
    moves along that shortest path are queued and emitted one per call.
    """

    def __init__(self):
        super().__init__()
        # Moves still to be emitted, in execution order (index 0 = next move).
        self.actions_to_do = []
        # 4-connected neighbourhood offsets; delta_x[i] pairs with delta_y[i].
        self.delta_x = (0, 0, 1, -1)
        self.delta_y = (1, -1, 0, 0)

    @staticmethod
    def _is_valid_index(index, an_array):
        """Return True if the ``(x, y)`` index lies within ``an_array``'s first two axes."""
        x, y = index
        return 0 <= x < an_array.shape[0] and 0 <= y < an_array.shape[1]

    def empty(self):
        """Return True when no planned moves are queued."""
        return not self.actions_to_do

    def clear(self):
        """Discard all queued moves so the next :meth:`step` plans afresh."""
        self.actions_to_do.clear()

    def step(self, state, info):
        """Return the next move as a one-element list.

        Args:
            state: ``(x, y, _)``, the agent's grid position. The third element
                is ignored.
            info: mapping whose ``'fields'`` entry is a 2-D array indexed as
                ``fields[x, y]``. Cells equal to ``0`` are search targets.
                Cells with negative values are never entered.

        Returns:
            ``[action]``: the next queued move, or the first move of a shortest
            path to the nearest ``0`` cell, or ``[FINISHED]`` when no ``0``
            cell remains.

        Raises:
            RuntimeError: if ``0`` cells remain but none can be reached from
                the agent's position through non-negative cells.
        """
        if self.actions_to_do:
            # Keep following the path planned on an earlier call.
            return [self.actions_to_do.pop(0)]

        x, y, _ = state
        fields = info['fields']

        if not np.any(fields == 0):
            return [FINISHED]

        path = self._plan_path_to_nearest_zero(x, y, fields)
        if path is None:
            raise RuntimeError('WildFire runtime error!')
        # Emit the first move now; queue the rest for subsequent calls.
        self.actions_to_do.extend(path[1:])
        return [path[0]]

    def _plan_path_to_nearest_zero(self, x, y, fields):
        """BFS from ``(x, y)``; return the moves to the closest 0-valued cell, or None.

        Only in-bounds cells with ``fields >= 0`` are expanded. The start cell
        itself is marked visited up front and is never returned as a target.
        """
        root_node = Node(x, y, None)
        visited = np.zeros(fields.shape, dtype=bool)
        visited[x, y] = True
        # deque.popleft() is O(1); deleting index 0 of a list is O(n).
        frontier = deque([root_node])
        while frontier:
            cur_node = frontier.popleft()
            cur_x, cur_y = cur_node.x, cur_node.y
            for dx, dy in zip(self.delta_x, self.delta_y):
                next_x, next_y = cur_x + dx, cur_y + dy
                # Written as `not (... >= 0 ...)` to keep the original test
                # exactly (e.g. NaN cells are skipped, as before).
                if not (self._is_valid_index((next_x, next_y), fields)
                        and fields[next_x, next_y] >= 0
                        and not visited[next_x, next_y]):
                    continue
                visited[next_x, next_y] = True
                next_node = Node(next_x, next_y, cur_node)
                if fields[next_x, next_y] == 0:
                    return self._backtrack(next_node)
                frontier.append(next_node)
        return None

    @staticmethod
    def _backtrack(node):
        """Collect the ``direction`` of each node from the root to ``node``, in execution order."""
        moves = []
        # NOTE(review): assumes Node leaves ``direction`` as None only on the
        # root (father=None) — confirm in Palette.algos.utils.
        while node.direction is not None:
            moves.append(node.direction)
            node = node.father
        moves.reverse()
        return moves