SeAIPalette/Palette/algos/spiral.py

173 lines
6.1 KiB
Python

import numpy as np
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
class Spiral(object):
    """Rule-based policy that picks the next move direction on a 2-D grid.

    The grid (``fields``) is indexed as ``fields[x, y]``; a cell equal to 0
    is treated as free and any other value as blocked.  Directions are the
    ``UP``/``DOWN``/``LEFT``/``RIGHT`` constants imported at module level:
    ``UP`` decrements ``y``, ``DOWN`` increments ``y``, ``LEFT`` decrements
    ``x`` and ``RIGHT`` increments ``x``.
    """

    class MyTuple:
        """Candidate direction ordered by its ``(x1, x2)`` score pair.

        Only ``data`` takes part in comparisons; ``direction`` is carried
        along so the winner of a ``max`` can be mapped back to a move.
        """

        def __init__(self, x1, x2, direction):
            self.data = (x1, x2)
            self.direction = direction

        def __repr__(self):
            return f'MyTuple(data={self.data!r}, direction={self.direction!r})'

        def __lt__(self, other):
            return self.data < other.data

        def __le__(self, other):
            return self.data <= other.data

        def __gt__(self, other):
            return self.data > other.data

        def __ge__(self, other):
            return self.data >= other.data

        def __eq__(self, other):
            return self.data == other.data

        def __ne__(self, other):
            return self.data != other.data

    def __init__(self,):
        super().__init__()
        # Offsets of the eight neighbours of a cell, as parallel lists:
        # the four axis-aligned neighbours first, then the four diagonals.
        self.delta_x = [0, 0, 1, -1, 1, 1, -1, -1]
        self.delta_y = [1, -1, 0, 0, 1, -1, 1, -1]

    @staticmethod
    def _right_index(x, y):
        """Return the index one cell to the right of ``(x, y)``."""
        return x + 1, y

    @staticmethod
    def _left_index(x, y):
        """Return the index one cell to the left of ``(x, y)``."""
        return x - 1, y

    @staticmethod
    def _up_index(x, y):
        """Return the index one cell above ``(x, y)`` (smaller ``y``)."""
        return x, y - 1

    @staticmethod
    def _down_index(x, y):
        """Return the index one cell below ``(x, y)`` (larger ``y``)."""
        return x, y + 1

    @staticmethod
    def _is_valid_index(index, an_array):
        """Return whether ``index`` lies inside the first two axes of ``an_array``."""
        x, y = index
        return 0 <= x < an_array.shape[0] and 0 <= y < an_array.shape[1]

    def _advance_fn(self, direction):
        """Return the index-update function for ``direction``.

        Returns ``None`` when ``direction`` is not one of the four move
        constants, so each caller can decide how to treat that case.
        """
        if direction == UP:
            return self._up_index
        if direction == DOWN:
            return self._down_index
        if direction == LEFT:
            return self._left_index
        if direction == RIGHT:
            return self._right_index
        return None

    def _count_valid_steps(self, index, direction, fields):
        """Count consecutive free cells reachable from ``index`` along ``direction``.

        The starting cell itself is not counted.  Counting stops at the
        first cell that is outside ``fields`` or not equal to 0.

        Raises:
            ValueError: if ``direction`` is not a known move constant.
        """
        advance = self._advance_fn(direction)
        if advance is None:
            raise ValueError(f'invalid direction: {direction}')
        steps = 0
        cell = advance(*index)
        while self._is_valid_index(cell, fields) and fields[cell] == 0:
            steps += 1
            cell = advance(*cell)
        return steps

    def _is_edge(self, index, fields, direction=None):
        """Return whether ``index`` has a blocked cell among its neighbours.

        All eight neighbours are inspected, except that when ``direction``
        is one of the four move constants the single neighbour lying in
        that direction is ignored.  ``None`` or an unrecognised
        ``direction`` checks all eight neighbours.  Neighbours outside
        ``fields`` never count as blocked.
        """
        offsets = list(zip(self.delta_x, self.delta_y))
        advance = self._advance_fn(direction) if direction is not None else None
        if advance is not None:
            # Drop exactly one offset -- the forward one -- for every
            # direction, so all four directions are treated symmetrically.
            forward = advance(0, 0)
            offsets = [offset for offset in offsets if offset != forward]
        x, y = index
        return any(
            self._is_valid_index((x + dx, y + dy), fields)
            and fields[x + dx, y + dy] != 0
            for dx, dy in offsets
        )

    def _choose_direction(self, index, fields, travel_direction, candidates):
        """Pick the best of ``candidates`` to move in from ``index``.

        Only candidates with at least one free step are considered; a
        blocked direction must never outrank a free one.  Among those,
        candidates whose first cell is an edge cell (see ``_is_edge``,
        evaluated relative to ``travel_direction``) are preferred, then
        the larger number of free steps; ties keep the earlier candidate.

        Returns:
            The chosen direction, or ``None`` if every candidate is blocked.
        """
        x, y = index
        ranked = []
        for candidate in candidates:
            steps = self._count_valid_steps(index, candidate, fields)
            if steps <= 0:
                continue
            first_cell = self._advance_fn(candidate)(x, y)
            on_edge = self._is_edge(first_cell, fields, travel_direction)
            ranked.append(Spiral.MyTuple(on_edge, steps, candidate))
        if not ranked:
            return None
        # Builtin max returns the first maximal element, matching the
        # tie-break of a sequential reduction over the candidates.
        return max(ranked).direction

    def _act(self, direction, fields, index):
        """Return the direction to move in from ``index``.

        Keeps ``direction`` while the next cell along it is inside
        ``fields`` and free.  Otherwise turns to one of the two
        perpendicular directions, or returns ``BROKEN`` when neither
        perpendicular direction has a free cell.

        Raises:
            ValueError: if ``direction`` is not a known move constant.
        """
        advance = self._advance_fn(direction)
        if advance is None:
            raise ValueError(f'invalid direction: {direction}')

        ahead = advance(*index)
        if self._is_valid_index(ahead, fields) and fields[ahead] == 0:
            return direction

        if direction == UP or direction == DOWN:
            turns = (LEFT, RIGHT)
        else:
            turns = (UP, DOWN)
        chosen = self._choose_direction(index, fields, direction, turns)
        return BROKEN if chosen is None else chosen

    def step(self, state, info):
        """Compute the next action for the agent described by ``state[0]``.

        Args:
            state: sequence whose first item unpacks to ``(x, y, direction)``.
            info: mapping providing ``'fields'`` (the grid) and
                ``'redecide_direction'`` (truthy to re-pick the heading
                from all four directions before acting).

        Returns:
            A one-element list: ``FINISHED`` when ``fields`` contains no 0
            cell, otherwise the direction to move in or ``BROKEN`` when no
            free move was found.
        """
        x, y, direction = state[0]
        fields = info['fields']
        if not np.any(fields == 0):
            return [FINISHED]
        if info['redecide_direction']:
            direction = self._choose_direction(
                (x, y), fields, direction, (UP, DOWN, LEFT, RIGHT))
            if direction is None:
                # No adjacent free cell in any direction.
                return [BROKEN]
        return [self._act(direction, fields, (x, y))]