# Forked from SeAIPalette/SeAIPalette (original listing: 55 lines, 2.0 KiB, Python).
import numpy as np


def multi_split_area(
    fields: np.ndarray,
    field_size: int,
    n_areas: int,
    axis: str
):
    """Split a 2-D grid into ``n_areas`` contiguous bands along one axis.

    The grid is scanned line by line (rows for ``axis='x'``, columns for
    ``axis='y'``). Consecutive lines are grouped into areas so that the
    cumulative number of blank cells (cells equal to 0) handed out after
    area ``k`` first reaches ``k / n_areas`` of all blank cells. Every
    non-final area is still guaranteed to leave at least one line for each
    later area. The last area absorbs all remaining lines, so every cell of
    the grid is labelled.

    Args:
        fields: 2-D grid. Cells equal to 0 count as blank. Cells equal to -1
            are copied to the output as -1 (presumably obstacles; confirm
            with the caller). Any other value is labelled with its area id
            but does not count towards the balance.
        field_size: Edge length of one grid cell. Start points are scaled
            by it.
        n_areas: Number of areas to create. Must be >= 1 and no larger than
            the number of lines along ``axis``.
        axis: ``'x'`` to split by rows, ``'y'`` to split by columns.

    Returns:
        A tuple ``(splited_fields, start_points)``:

        - ``splited_fields`` has the shape and dtype of ``fields``. It holds
          the area id (1..n_areas) for every cell that is not -1, and -1
          elsewhere.
        - ``start_points`` is a float array of shape ``(k, 2)``. There is
          one row per area that contains a blank cell: the first blank cell
          met while scanning that area, given as
          ``(row, n_cols - 1 - col) * field_size + field_size / 2``.

    Raises:
        AssertionError: if ``axis`` is not ``'x'`` or ``'y'`` (only while
            assertions are enabled).
        ValueError: if ``fields`` is not 2-D, if ``n_areas`` < 1, or if
            there are fewer lines along ``axis`` than ``n_areas``.
    """
    assert axis in ['x', 'y'], f"Invalid axis: {axis}"

    fields = np.asarray(fields)
    if fields.ndim != 2:
        raise ValueError(f"fields must be 2-D, got shape {fields.shape}")
    if n_areas < 1:
        raise ValueError(f"n_areas must be at least 1, got {n_areas}")

    splited_fields = np.zeros_like(fields)
    # Handle both axes with one code path. For axis 'y' the transposed views
    # turn each column of the grid into a "line", and writes through
    # ``labels`` land directly in ``splited_fields``.
    grid = fields if axis == 'x' else fields.T
    labels = splited_fields if axis == 'x' else splited_fields.T
    n_lines = grid.shape[0]
    if n_lines < n_areas:
        raise ValueError(
            f"Cannot split {n_lines} lines along axis '{axis}' "
            f"into {n_areas} areas"
        )

    blank_spaces_num = np.sum(fields == 0)
    consumed_spaces = 0  # blank cells in areas that are already finished
    start_points = []
    p = 0  # index of the next line to assign
    for area_id in range(1, n_areas + 1):
        is_last_area = area_id == n_areas
        # Cumulative blank-cell count at which this area may stop.
        target = blank_spaces_num * area_id / n_areas
        cur_area_spaces = 0
        start_point_assigned = False
        while True:
            line = grid[p]
            blank_idx = np.flatnonzero(line == 0)
            cur_area_spaces += blank_idx.size
            labels[p] = np.where(line == -1, -1, area_id)
            if not start_point_assigned and blank_idx.size > 0:
                j = int(blank_idx[0])
                # Record the cell as (row, col) in the original grid.
                start_points.append((p, j) if axis == 'x' else (j, p))
                start_point_assigned = True
            p += 1
            if p >= n_lines:
                break
            if is_last_area:
                # The last area takes every remaining line, so trailing
                # lines with no blank cell are still labelled.
                continue
            # Stop once the balance target is met. Also stop when exactly
            # one line per remaining area is left; going further would
            # leave a later area with no line to start from.
            if (consumed_spaces + cur_area_spaces >= target
                    or n_lines - p <= n_areas - area_id):
                break
        consumed_spaces += cur_area_spaces

    # NOTE(review): an area whose lines contain no blank cell gets no start
    # point. The result can then have fewer than n_areas rows, and row k need
    # not belong to area k + 1. Confirm that callers tolerate this.
    # reshape keeps the (k, 2) layout even when no start point was found.
    start_points = np.asarray(start_points).reshape(-1, 2)
    # Mirror the column index (presumably to match a world frame whose
    # second axis runs the other way; TODO confirm), then map each cell
    # index to its centre in units of field_size.
    start_points[:, 1] = fields.shape[1] - start_points[:, 1] - 1
    start_points = start_points * field_size + field_size / 2

    # Invariant: every cell now carries an area id or -1.
    assert np.sum(splited_fields == 0) == 0
    return splited_fields, start_points