# Forked from SeAIPalette/SeAIPalette (173 lines, 6.1 KiB, Python).
import numpy as np
|
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
|
|
|
|
|
class Spiral(object):
    """Rule-based agent that chooses a heading on a 2-D ``fields`` grid.

    Cells are addressed as ``fields[x, y]``: ``x`` runs along axis 0 and
    ``y`` along axis 1. A cell equal to ``0`` is treated as free; any other
    value blocks movement. "Up" decrements ``y``, "down" increments ``y``,
    "left" decrements ``x`` and "right" increments ``x``.
    """

    class MyTuple:
        """Ranking key ``(x1, x2)`` that carries the direction it describes.

        Instances compare by ``data`` only, so the largest instance of a
        collection identifies the preferred ``direction``.
        """

        def __init__(self, x1, x2, direction):
            self.data = (x1, x2)
            self.direction = direction

        def __repr__(self):
            return (f'MyTuple({self.data[0]!r}, {self.data[1]!r}, '
                    f'{self.direction!r})')

        # Each comparison returns NotImplemented for foreign types so that
        # Python can fall back to its default handling instead of failing
        # with an AttributeError on ``other.data``.
        def __lt__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data < other.data

        def __le__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data <= other.data

        def __gt__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data > other.data

        def __ge__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data >= other.data

        def __eq__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data == other.data

        def __ne__(self, other):
            if not isinstance(other, Spiral.MyTuple):
                return NotImplemented
            return self.data != other.data

    def __init__(self):
        super().__init__()
        # Offsets of the eight neighbours of a cell, index by index:
        # 0 down, 1 up, 2 right, 3 left, then the four diagonals.
        self.delta_x = [0, 0, 1, -1, 1, 1, -1, -1]
        self.delta_y = [1, -1, 0, 0, 1, -1, 1, -1]

    @staticmethod
    def _right_index(x, y):
        """Return the index one cell to the right of ``(x, y)``."""
        return x + 1, y

    @staticmethod
    def _left_index(x, y):
        """Return the index one cell to the left of ``(x, y)``."""
        return x - 1, y

    @staticmethod
    def _up_index(x, y):
        """Return the index one cell above ``(x, y)``."""
        return x, y - 1

    @staticmethod
    def _down_index(x, y):
        """Return the index one cell below ``(x, y)``."""
        return x, y + 1

    @staticmethod
    def _is_valid_index(index, an_array):
        """Return True when ``index`` lies inside the first two axes of ``an_array``."""
        x, y = index
        return 0 <= x < an_array.shape[0] and 0 <= y < an_array.shape[1]

    def _update_fn(self, direction):
        """Return the index-update function that moves one cell in ``direction``.

        Raises:
            ValueError: if ``direction`` is not UP, DOWN, LEFT or RIGHT.
        """
        if direction == UP:
            return self._up_index
        if direction == DOWN:
            return self._down_index
        if direction == LEFT:
            return self._left_index
        if direction == RIGHT:
            return self._right_index
        raise ValueError(f'invalid direction: {direction}')

    def _count_valid_steps(self, index, direction, fields):
        """Count consecutive free cells reachable from ``index`` in ``direction``.

        The starting cell itself is not counted. Counting stops at the grid
        border or at the first cell whose value is not ``0``.

        Raises:
            ValueError: if ``direction`` is not UP, DOWN, LEFT or RIGHT.
        """
        advance = self._update_fn(direction)

        valid_steps = 0
        cell = advance(*index)
        while self._is_valid_index(cell, fields) and fields[cell] == 0:
            valid_steps += 1
            cell = advance(*cell)
        return valid_steps

    def _is_edge(self, index, fields, direction=None):
        """Return True when a considered neighbour of ``index`` is non-zero.

        With ``direction=None`` all eight neighbours are considered. With a
        direction, some neighbour offsets are left out (see below).
        Neighbours outside the grid are ignored.

        Raises:
            ValueError: if ``direction`` is neither None nor one of UP, DOWN,
                LEFT, RIGHT.
        """
        # Positions in delta_x/delta_y that are NOT inspected.
        # NOTE(review): UP and DOWN skip only the offset pointing in the
        # travel direction, while RIGHT also skips the "up" offset and LEFT
        # skips both horizontal offsets. This mirrors the original slices
        # exactly; confirm whether the asymmetry is intended.
        if direction is None:
            skipped = ()
        elif direction == UP:
            skipped = (1,)
        elif direction == DOWN:
            skipped = (0,)
        elif direction == RIGHT:
            skipped = (1, 2)
        elif direction == LEFT:
            skipped = (2, 3)
        else:
            # The offsets used to be left unbound here, which surfaced as an
            # UnboundLocalError; fail explicitly like the sibling methods do.
            raise ValueError(f'invalid direction: {direction}')

        x, y = index
        for position, (dx, dy) in enumerate(zip(self.delta_x, self.delta_y)):
            if position in skipped:
                continue
            neighbour = (x + dx, y + dy)
            if self._is_valid_index(neighbour, fields) and fields[neighbour] != 0:
                return True
        return False

    def _rank_candidates(self, index, fields, heading, candidates):
        """Build one ``MyTuple`` ranking key per candidate direction.

        Args:
            index: current ``(x, y)`` position.
            fields: grid that is inspected.
            heading: direction forwarded to ``_is_edge`` for every candidate.
            candidates: iterable of ``(direction, update_fn)`` pairs.

        Returns:
            A list of ``MyTuple(is_edge, valid_steps, direction)`` in the
            order of ``candidates``; ``is_edge`` is evaluated on the cell
            reached by ``update_fn`` and ``valid_steps`` from ``index``.
        """
        x, y = index
        return [
            Spiral.MyTuple(
                self._is_edge(update_fn(x, y), fields, heading),
                self._count_valid_steps((x, y), candidate, fields),
                candidate,
            )
            for candidate, update_fn in candidates
        ]

    def _act(self, direction, fields, index):
        """Return the direction to move in from ``index``.

        ``direction`` is kept while the next cell along it is inside the grid
        and free. Otherwise the two perpendicular directions are ranked by
        ``(is_edge, valid_steps)`` and the best one is returned, or ``BROKEN``
        when neither of them offers a free step.

        Raises:
            ValueError: if ``direction`` is not UP, DOWN, LEFT or RIGHT.
        """
        x, y = index
        ahead = self._update_fn(direction)(x, y)
        if self._is_valid_index(ahead, fields) and fields[ahead] == 0:
            return direction

        # Blocked straight ahead: only perpendicular turns are considered.
        if direction == UP or direction == DOWN:
            candidates = ((LEFT, self._left_index), (RIGHT, self._right_index))
        else:
            candidates = ((UP, self._up_index), (DOWN, self._down_index))

        ranked = self._rank_candidates((x, y), fields, direction, candidates)
        if any(option.data[1] > 0 for option in ranked):
            return np.max(ranked).direction
        return BROKEN

    def step(self, state, info):
        """Decide the next action for the agent.

        Args:
            state: sequence whose first item unpacks to ``(x, y, direction)``.
            info: mapping that provides ``'fields'`` (the grid) and
                ``'redecide_direction'`` (truthy to re-pick the heading among
                all four directions before acting).

        Returns:
            A one-element list holding ``FINISHED`` when no cell of
            ``fields`` equals ``0``, otherwise the result of ``_act``.
        """
        x, y, direction = state[0]
        fields = info['fields']

        if np.sum(fields == 0) == 0:
            return [FINISHED]

        if info['redecide_direction']:
            candidates = (
                (UP, self._up_index),
                (DOWN, self._down_index),
                (LEFT, self._left_index),
                (RIGHT, self._right_index),
            )
            direction = np.max(self._rank_candidates(
                (x, y), fields, direction, candidates)).direction
        return [self._act(direction, fields, (x, y))]
|