# SeAIPalette/Palette/examples/multi_split_wildfire_boustr...
#
# 75 lines
# 2.8 KiB
# Python

import numpy as np
from argparse import ArgumentParser
from copy import deepcopy
from Palette.env import make, SeaEnv
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos import MultiGreedy, MultiWildFire, multi_split_area
def calculate_repeat_steps(fields):
    """Return how many repeated visits are recorded in ``fields``.

    Every positive cell contributes ``value - 1`` to the total. Negative
    cells are ignored (presumably obstacles / out-of-area cells -- confirm
    against Palette.env).

    Args:
        fields: numpy array of per-cell values.

    Returns:
        The repeat count as a Python ``int``.

    Raises:
        AssertionError: if any cell is still 0 (the "Not finish!" check).
    """
    assert not np.any(fields == 0), 'Not finish!'
    visited = fields[fields > 0]
    # sum(v) - count(v) == sum(v - 1) over the positive cells.
    return int(visited.sum() - visited.size)
def generate_agent_fields(fields, ship_num, splited_areas):
    """Build one private copy of ``fields`` per ship.

    The result stacks ``ship_num`` copies of ``fields`` along a new leading
    axis. In copy ``k`` (ship id ``k + 1``), every cell whose
    ``splited_areas`` label is neither ``k + 1`` nor ``-1`` is incremented
    by one. Presumably this makes cells of other ships' areas look already
    visited to that ship -- confirm with the agent implementation.
    ``fields`` itself is left untouched.

    Args:
        fields: Per-cell grid shared by all ships.
        ship_num: Number of ships (ids are ``1..ship_num``).
        splited_areas: Grid of area labels with the same shape as ``fields``.

    Returns:
        numpy array of shape ``(ship_num, *fields.shape)``.
    """
    # np.stack allocates a fresh array, so the caller's grid is never mutated.
    per_ship = np.stack([fields for _ in range(ship_num)], axis=0)
    labelled = splited_areas != -1
    for ship_id, ship_view in enumerate(per_ship, start=1):
        # ship_view is a view into per_ship; the in-place update sticks.
        ship_view[labelled & (splited_areas != ship_id)] += 1
    return per_ship
def main(env: SeaEnv, agent: MultiGreedy):
    """Run one episode of ``agent`` in ``env`` and print a summary.

    On each step the agent receives a deep copy of the latest ``info``
    dict. Its ``'fields'`` entry is replaced by the per-ship stack built by
    ``generate_agent_fields``. The returned actions must not contain
    ``BROKEN``, and the environment is then advanced. When ``env.step``
    reports termination, a report is printed: steps used, repeat steps and
    repeat ratio if ``info['finished']`` is truthy, otherwise ``'Failed!'``.

    Args:
        env: Environment providing ``reset``, ``step``, ``ship_num`` and
            ``splited_areas``.
        agent: Multi-ship policy whose ``step(state, info)`` returns the
            actions and the next ``redecide_direction`` flags.

    Returns:
        None.
    """
    state, info = env.reset()
    # All ships start with the redecide flag set.
    redecide_direction = [True] * env.ship_num
    steps = 0
    terminate = False
    while not terminate:
        steps += 1
        info['redecide_direction'] = redecide_direction
        per_ship_fields = generate_agent_fields(
            info['fields'], env.ship_num, env.splited_areas)
        # Presumably a deep copy so the agent cannot mutate the env's info.
        agent_info = deepcopy(info)
        agent_info['fields'] = per_ship_fields
        acts, redecide_direction = agent.step(state, agent_info)
        assert (acts != BROKEN).all(), f'Invalid: ships output BROKEN step.'
        state, terminate, info = env.step(acts)

    if not info['finished']:
        print(f'Failed!')
        return

    print(f'Congrates! covered all the places!')
    repeat_steps = calculate_repeat_steps(info['fields'])
    units_num = info['fields'].size
    repeat_ratio = repeat_steps / units_num
    print(f'Used steps: {steps}\n'
          f'repeat steps: {repeat_steps}\n'
          f'repeat ratio: {repeat_steps}/{units_num}'
          f'={repeat_ratio*100:.2f}%')
def get_args(argv=None):
    """Parse the command-line options of this example script.

    Args:
        argv: Optional list of argument strings to parse. The default
            ``None`` parses ``sys.argv[1:]``, exactly as before, so existing
            callers are unaffected. Passing a list lets the parser be driven
            from code or tests.

    Returns:
        argparse.Namespace with attributes ``config_name`` (str, required),
        ``axis`` (str, default ``'x'``), ``n_ships`` (int, default ``3``) and
        ``render`` (bool, default ``False``).
    """
    parser = ArgumentParser()
    parser.add_argument('-c', "--config_name", type=str, required=True,
                        help="The name of config file, e.g., map0")
    parser.add_argument('-a', '--axis', type=str, default='x',
                        help='The axis for splitting areas')
    parser.add_argument('-n', '--n-ships', type=int,
                        default=3, help='The number of ships.')
    parser.add_argument('--render', action='store_true',
                        help='render or not')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Build the environment from the CLI options, construct the planner
    # (MultiGreedy over a MultiWildFire) for its ship count, and run main().
    cli_args = get_args()
    sea_env = make(env_config=cli_args.config_name,
                   render=cli_args.render,
                   split_axis=cli_args.axis,
                   n_ships=cli_args.n_ships)
    algo = MultiGreedy(
        wildfires=MultiWildFire(n_ships=sea_env.ship_num),
        ship_num=sea_env.ship_num,
    )
    main(sea_env, algo)