forked from SeAIPalette/SeAIPalette
75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
import numpy as np
|
|
|
|
from argparse import ArgumentParser
|
|
from copy import deepcopy
|
|
|
|
from Palette.env import make, SeaEnv
|
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
|
from Palette.algos import MultiGreedy, MultiWildFire, multi_split_area
|
|
|
|
|
|
def calculate_repeat_steps(fields):
    """Return the surplus of the positive entries of ``fields`` over their count.

    The result is ``sum(positive entries) - number of positive entries``,
    converted to ``int``.
    NOTE(review): presumably every positive entry is a per-cell visit count,
    so the result is the number of repeated visits -- confirm with the env.

    Args:
        fields: NumPy array; entries that are not positive are ignored by the
            computation, but no entry may be exactly zero.

    Returns:
        The surplus described above as a Python ``int``.

    Raises:
        AssertionError: If any entry of ``fields`` equals zero.
    """
    zero_cells = np.sum(fields == 0)
    assert zero_cells == 0, f'Not finish!'

    positive_mask = fields > 0
    positive_total = np.sum(fields[positive_mask])
    positive_count = np.sum(positive_mask)
    return int(positive_total - positive_count)
|
|
|
|
|
|
def generate_agent_fields(fields, ship_num, splited_areas):
    """Build one private copy of ``fields`` per ship and mark foreign cells.

    The copies are stacked along a new leading axis.  In the copy belonging
    to ship ``k`` (1-based), every cell whose ``splited_areas`` label is
    neither ``k`` nor ``-1`` is incremented by one; ``fields`` itself is not
    modified.
    NOTE(review): presumably ``splited_areas`` holds the 1-based id of the
    ship each cell is assigned to, with ``-1`` for cells assigned to no
    ship -- confirm against the environment.

    Args:
        fields: Array copied once per ship.
        ship_num: Number of ships, i.e. number of copies to stack.
        splited_areas: Array of area labels, compared element-wise against
            the ship ids and ``-1``.

    Returns:
        Array with ``ship_num`` entries along axis 0, one adjusted copy of
        ``fields`` per ship.
    """
    per_ship_copies = [deepcopy(fields) for _ in range(ship_num)]
    agents_fields = np.stack(per_ship_copies, axis=0)

    # Iterating the stacked array yields views, so the in-place increment
    # below writes straight into ``agents_fields``.
    for ship_id, ship_view in enumerate(agents_fields, start=1):
        belongs_to_other_ship = (splited_areas != ship_id) & (splited_areas != -1)
        ship_view[belongs_to_other_ship] += 1

    return agents_fields
|
|
|
|
|
|
def main(env: SeaEnv, agent: MultiGreedy):
    """Run one episode of ``env`` under ``agent`` and print the outcome.

    On every iteration the agent receives a deep copy of the environment's
    ``info`` dict in which ``'fields'`` is replaced by the per-ship stack
    from ``generate_agent_fields`` and ``'redecide_direction'`` holds the
    flags returned by the previous ``agent.step`` call (all ``True`` on the
    first iteration).  The loop ends when ``env.step`` reports termination;
    a success summary or ``Failed!`` is then printed to stdout.

    Args:
        env: Environment; this function uses ``reset()``, ``step(acts)``,
            ``ship_num`` and ``splited_areas``.
        agent: Controller whose ``step(state, info)`` returns
            ``(acts, redecide_direction)``.

    Raises:
        AssertionError: If any action returned by the agent equals
            ``BROKEN``, or if ``calculate_repeat_steps`` rejects the final
            fields.
    """
    state, info = env.reset()
    redecide_direction = [True for _ in range(env.ship_num)]

    steps = 0
    terminate = False
    while not terminate:
        steps += 1

        # The flags are written into the environment's own info dict before
        # it is copied, so the agent's copy carries them as well.
        info['redecide_direction'] = redecide_direction
        per_ship_fields = generate_agent_fields(
            info['fields'], env.ship_num, env.splited_areas)
        agent_view = deepcopy(info)
        agent_view['fields'] = per_ship_fields

        acts, redecide_direction = agent.step(state, agent_view)
        assert (acts != BROKEN).all(), f'Invalid: ships output BROKEN step.'

        state, terminate, info = env.step(acts)

    if not info['finished']:
        print(f'Failed!')
        return

    print(f'Congrates! covered all the places!')
    repeat_steps = calculate_repeat_steps(info['fields'])
    units_num = info['fields'].size
    repeat_ratio = repeat_steps / units_num
    print(
        f'Used steps: {steps}\nrepeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')
|
|
|
|
|
|
def get_args():
    """Parse this script's command-line options.

    Options:
        -c / --config_name  (str, required)  name of the config file.
        -a / --axis         (str, default 'x')  axis for splitting areas.
        -n / --n-ships      (int, default 3)  number of ships.
        --render            (flag)  stored as ``True`` when given.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, with the
        attributes ``config_name``, ``axis``, ``n_ships`` and ``render``.
    """
    cli = ArgumentParser()

    # (flags, keyword options) for each option, in registration order.
    option_table = (
        (('-c', "--config_name"),
         dict(type=str, required=True,
              help="The name of config file, e.g., map0")),
        (('-a', '--axis'),
         dict(type=str, default='x',
              help='The axis for splitting areas')),
        (('-n', '--n-ships'),
         dict(type=int, default=3,
              help='The number of ships.')),
        (('--render',),
         dict(action='store_true',
              help='render or not')),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the environment from the CLI options, wire
    # the greedy multi-ship agent to it, and run a single episode.
    args = get_args()

    env_options = {
        'env_config': args.config_name,
        'render': args.render,
        'split_axis': args.axis,
        'n_ships': args.n_ships,
    }
    sea_env = make(**env_options)

    # The agent is sized from the environment's ship count, not from the
    # CLI value.
    algo = MultiGreedy(
        ship_num=sea_env.ship_num,
        wildfires=MultiWildFire(n_ships=sea_env.ship_num),
    )

    main(sea_env, algo)
|