SeAIPalette/Palette/algos/multi_split_area.py

55 lines
2.0 KiB
Python
Raw Permalink Normal View History

2022-05-30 21:36:26 +08:00
import numpy as np
def multi_split_area(
    fields: np.ndarray,
    field_size: int,
    n_areas: int,
    axis: str
):
    """Split a 2-D grid into ``n_areas`` consecutive bands with about equal blank cells.

    The grid is swept one whole line at a time along ``axis``: 'x' sweeps rows
    (first index), 'y' sweeps columns (second index). Every band except the last
    takes lines until the cumulative number of blank cells (value ``0``) reaches
    ``blank_total * area_id / n_areas``; it always leaves at least one line for
    each band still to come. The last band takes every remaining line, so every
    cell of the grid ends up labelled.

    Args:
        fields: 2-D grid. Cells equal to ``0`` count as blank, cells equal to
            ``-1`` stay ``-1`` in the output, every other cell gets its band id.
        field_size: side length of one cell; used to turn cell indices into
            cell-centre coordinates (``index * field_size + field_size / 2``).
        n_areas: number of bands, between 1 and the number of lines along ``axis``.
        axis: ``'x'`` or ``'y'``.

    Returns:
        ``(labels, start_points)``. ``labels`` has the shape and dtype of
        ``fields`` and holds band ids ``1..n_areas`` (``-1`` where ``fields`` is
        ``-1``). ``start_points`` has shape ``(k, 2)``: for each band that contains
        a blank cell, the first blank cell met in sweep order, with its second
        index mirrored (``fields.shape[1] - 1 - j``) and then scaled to cell-centre
        coordinates. ``k < n_areas`` when some band contains no blank cell.

    Raises:
        ValueError: if ``axis`` is not 'x'/'y' or ``n_areas`` is out of range.
    """
    if axis not in ('x', 'y'):
        raise ValueError(f"Invalid axis: {axis}")

    splited_fields = np.zeros_like(fields)
    # Always iterate over the first index of `grid_lines`; for 'y' these are
    # transposed *views*, so writing into `label_lines` fills `splited_fields`.
    if axis == 'x':
        grid_lines, label_lines = fields, splited_fields
    else:
        grid_lines, label_lines = fields.T, splited_fields.T
    n_lines = grid_lines.shape[0]
    if not 1 <= n_areas <= n_lines:
        raise ValueError(
            f"n_areas must be between 1 and {n_lines} for axis {axis!r}, got {n_areas}")

    blank_spaces_num = np.sum(fields == 0)
    assigned_spaces = 0  # blank cells claimed by already finished areas
    start_points = []
    p = 0  # index of the next unassigned line
    for area_id in range(1, n_areas + 1):
        is_last = area_id == n_areas
        target = blank_spaces_num * area_id / n_areas
        cur_area_spaces = 0
        start_point_assigned = False
        while True:
            line = grid_lines[p]
            cur_area_spaces += np.sum(line == 0)
            label_lines[p] = np.where(line != -1, area_id, -1)
            if not start_point_assigned:
                blank_idx = np.flatnonzero(line == 0)
                if blank_idx.size:
                    i = int(blank_idx[0])
                    # Report the cell in (row, column) order of `fields`.
                    start_points.append((p, i) if axis == 'x' else (i, p))
                    start_point_assigned = True
            p += 1
            if p >= n_lines:
                break
            if is_last:
                # The last area absorbs every remaining line (including lines
                # without blank cells) so that no cell is left unlabelled.
                continue
            # Keep at least one line for each area that still has to be built;
            # otherwise a later area would start past the end of the grid.
            if n_lines - p <= n_areas - area_id:
                break
            if assigned_spaces + cur_area_spaces >= target:
                break
        assigned_spaces += cur_area_spaces

    # reshape keeps a (0, 2) shape when no area contained a blank cell.
    start_points = np.asarray(start_points).reshape(-1, 2)
    start_points[:, 1] = fields.shape[1] - start_points[:, 1] - 1
    start_points = start_points * field_size + field_size / 2
    assert np.sum(splited_fields == 0) == 0
    return splited_fields, start_points