forked from SeAIPalette/SeAIPalette
140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
|
import numpy as np
|
||
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
||
|
|
||
|
|
||
|
class Boustrophedon:
    """Greedy boustrophedon (back-and-forth sweep) policy for a single ship.

    The policy works on a 2-D numpy grid ``fields`` indexed as
    ``fields[x, y]``. A cell counts as free when it lies inside the grid and
    its value is exactly 0. Anything else (out of bounds or non-zero) blocks
    movement. NOTE(review): 0 presumably means "not yet covered"; confirm
    against the environment that produces ``fields``.

    Direction conventions (from the index helpers below):
    RIGHT -> x + 1, LEFT -> x - 1, DOWN -> y + 1, UP -> y - 1.

    Sweep behaviour: keep going straight while the cell ahead is free. When
    blocked, turn to the perpendicular side with more free cells and record
    the reverse of the old heading in ``dest_direction``, so the following
    ``step`` call heads back the other way.
    """

    def __init__(self):
        super().__init__()
        # Direction to take on the next ``step`` call, one slot per ship.
        # ``step`` only ever uses slot 0; ``None`` means nothing is pending.
        self.dest_direction = [None]

    @staticmethod
    def _right_index(x, y):
        """Return the cell one step RIGHT of ``(x, y)`` (larger ``x``)."""
        return x + 1, y

    @staticmethod
    def _left_index(x, y):
        """Return the cell one step LEFT of ``(x, y)`` (smaller ``x``)."""
        return x - 1, y

    @staticmethod
    def _up_index(x, y):
        """Return the cell one step UP of ``(x, y)`` (smaller ``y``)."""
        return x, y - 1

    @staticmethod
    def _down_index(x, y):
        """Return the cell one step DOWN of ``(x, y)`` (larger ``y``)."""
        return x, y + 1

    @staticmethod
    def _is_valid_index(index, an_array):
        """Return True if the ``(x, y)`` pair lies inside ``an_array``'s first two axes."""
        x, y = index
        return 0 <= x < an_array.shape[0] and 0 <= y < an_array.shape[1]

    def _is_free(self, index, fields):
        """Return True if ``index`` is inside ``fields`` and that cell equals 0.

        The bounds check must run first. A negative coordinate would
        otherwise reach numpy indexing, where it silently wraps around to the
        opposite edge of the grid.
        """
        if not self._is_valid_index(index, fields):
            return False
        x, y = index
        return fields[x, y] == 0

    def _update_index_for(self, direction):
        """Return the one-cell move function for ``direction``.

        Raises:
            ValueError: if ``direction`` is not one of UP, DOWN, LEFT, RIGHT.
        """
        if direction == UP:
            return self._up_index
        if direction == DOWN:
            return self._down_index
        if direction == LEFT:
            return self._left_index
        if direction == RIGHT:
            return self._right_index
        raise ValueError(f'invalid direction: {direction}')

    def _count_valid_steps(self, index, direction, fields):
        """Count consecutive free cells from ``index`` in ``direction``.

        The start cell itself is not counted. Counting stops at the first
        cell that is out of bounds or non-zero.

        Raises:
            ValueError: if ``direction`` is not a recognised direction.
        """
        update_index = self._update_index_for(direction)
        valid_steps = 0
        cur_x, cur_y = index
        while True:
            cur_x, cur_y = update_index(cur_x, cur_y)
            if not self._is_free((cur_x, cur_y), fields):
                return valid_steps
            valid_steps += 1

    def _act(self, direction, fields, index, ship_id):
        """Choose the next action for a ship heading ``direction`` at ``index``.

        1. If the cell straight ahead is free, keep ``direction``.
        2. Otherwise, if either perpendicular side has a free cell, turn
           towards the side with more free cells. Ties go to LEFT for a
           vertical heading and to UP for a horizontal one. Store the reverse
           of ``direction`` in ``self.dest_direction[ship_id]`` for the next
           ``step`` call.
        3. Otherwise, walk backwards over consecutive free cells. If any of
           them has a free perpendicular neighbour, return the reverse
           direction; if none does, return BROKEN.

        Raises:
            ValueError: if ``direction`` is not a recognised direction.
        """
        x, y = index
        if direction == UP:
            reverse_direction, possible_turn = DOWN, (LEFT, RIGHT)
        elif direction == DOWN:
            reverse_direction, possible_turn = UP, (LEFT, RIGHT)
        elif direction == LEFT:
            reverse_direction, possible_turn = RIGHT, (UP, DOWN)
        elif direction == RIGHT:
            reverse_direction, possible_turn = LEFT, (UP, DOWN)
        else:
            raise ValueError(f'invalid direction: {direction}')
        continue_update_index = self._update_index_for(direction)
        reverse_update_index = self._update_index_for(reverse_direction)

        # Case 1: the way ahead is open.
        if self._is_free(continue_update_index(x, y), fields):
            return direction

        # Case 2: blocked ahead; turn towards the roomier perpendicular side.
        turn0, turn1 = possible_turn
        steps0 = self._count_valid_steps((x, y), turn0, fields)
        steps1 = self._count_valid_steps((x, y), turn1, fields)
        if steps0 > 0 or steps1 > 0:
            self.dest_direction[ship_id] = reverse_direction
            return turn1 if steps1 > steps0 else turn0

        # Case 3: boxed in on three sides; look back along the free cells
        # behind us for one that still has sideways room.
        search_x, search_y = reverse_update_index(x, y)
        while self._is_free((search_x, search_y), fields):
            if (self._count_valid_steps((search_x, search_y), turn0, fields)
                    or self._count_valid_steps((search_x, search_y), turn1, fields)):
                return reverse_direction
            search_x, search_y = reverse_update_index(search_x, search_y)
        return BROKEN

    def step(self, state, info):
        """Return the next action for the ship, wrapped in a one-element list.

        Args:
            state: ``(x, y, direction)``, the current cell and heading.
            info: mapping with ``'fields'`` (2-D numpy grid indexed
                ``fields[x, y]``; 0 = free) and ``'redecide_direction'``.
                When the latter is truthy, the heading is re-picked as the
                direction with the longest free run. Ties go to the earliest
                of UP, DOWN, LEFT, RIGHT.

        Returns:
            ``[FINISHED]`` when no cell in ``fields`` equals 0. Otherwise, a
            list holding a direction constant or ``BROKEN``.
        """
        x, y, direction = state
        fields = info['fields']

        if not np.any(fields == 0):
            return [FINISHED]  # nothing left to cover

        if info['redecide_direction']:
            candidates = [UP, DOWN, LEFT, RIGHT]
            runs = [self._count_valid_steps((x, y), d, fields) for d in candidates]
            direction = candidates[np.argmax(runs)]

        # A turn issued by ``_act`` on the previous call recorded the reverse
        # heading here. It takes precedence over ``direction`` (even a
        # re-decided one) and is consumed exactly once. If that heading has
        # no free cell ahead, give up with BROKEN.
        pending = self.dest_direction[0]
        if pending is not None:
            self.dest_direction = [None]
            if self._count_valid_steps((x, y), pending, fields) > 0:
                return [pending]
            return [BROKEN]

        return [self._act(direction, fields, (x, y), 0)]