# Source: SeAIPalette/Palette/examples/charge_multi_split_wildfire...
#
# 90 lines
# 3.5 KiB
# Python

import numpy as np
from argparse import ArgumentParser
from copy import deepcopy
from Palette.env import make, SeaEnv
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos import MultiGreedy, MultiWildFire, MultiCharge, multi_split_area
def calculate_repeat_steps(fields):
    """Return the total surplus over 1 across all strictly positive cells.

    Each cell whose value is > 0 contributes ``value - 1``; cells that are
    zero or negative are ignored entirely.

    Args:
        fields: NumPy array of per-cell values.  Positive entries appear to
            be visit counts (TODO confirm against the env's ``fields``).

    Returns:
        The summed surplus, converted to a Python ``int``.
    """
    visited = fields[fields > 0]
    return int(visited.sum() - visited.size)
def generate_agent_fields(fields, ship_num, splited_areas):
    """Build one copy of ``fields`` per ship, bumping cells owned by others.

    Layer ``i`` of the result (ship id ``i + 1``) equals ``fields`` plus 1 on
    every cell whose ``splited_areas`` label is neither ``i + 1`` nor ``-1``.
    The input ``fields`` is not modified.

    Args:
        fields: NumPy array with the current per-cell values.
        ship_num: Number of ships (ids run from 1 to ``ship_num``).
        splited_areas: Array of area labels, indexed like ``fields``; ``-1``
            cells are never bumped.

    Returns:
        Array of shape ``(ship_num, *fields.shape)``.
    """
    stacked = np.stack([deepcopy(fields) for _ in range(ship_num)], axis=0)
    # Iterating an ndarray yields views, so editing `layer` edits `stacked`.
    for index, layer in enumerate(stacked):
        owner = index + 1
        foreign = (splited_areas != owner) & (splited_areas != -1)
        layer[foreign] += 1
    return stacked
def main(env: SeaEnv, agent: MultiGreedy):
    """Run one episode of ``agent`` in ``env`` and print a summary.

    On every step each ship gets its own copy of the map from
    ``generate_agent_fields`` (cells labelled for other ships are bumped by
    one, presumably so a ship treats them as already covered), plus the
    current battery level of every ship.  The loop ends when ``env.step``
    reports termination; the result is then printed.

    Args:
        env: Environment whose ``reset``/``step`` return an ``info`` dict
            containing at least ``'fields'`` (and ``'finished'`` at the end).
        agent: Planner whose ``step(state, info)`` returns the actions and
            the per-ship ``redecide_direction`` flags for the next step.

    Raises:
        RuntimeError: If the agent emits a ``BROKEN`` action.
    """
    state, info = env.reset()
    # Start with every ship's redecide flag set (presumably forcing an
    # initial heading choice -- confirm against MultiGreedy).
    redecide_direction = [True] * env.ship_num
    steps = 0
    while True:
        steps += 1
        info['redecide_direction'] = redecide_direction
        agents_fields = generate_agent_fields(
            info['fields'], env.ship_num, env.splited_areas)
        # Deep copy: the agent gets its own dict, so edits it makes do not
        # reach `info`.
        agent_info = deepcopy(info)
        agent_info['fields'] = agents_fields
        agent_info['batteries'] = [
            env.get_battery(ship_id) for ship_id in range(env.ship_num)]

        acts, redecide_direction = agent.step(state, agent_info)
        # Explicit check rather than `assert`, so it is not stripped under
        # `python -O`; asarray also makes it work if `acts` is a list.
        if np.any(np.asarray(acts) == BROKEN):
            raise RuntimeError('Invalid: ships output BROKEN step.')

        state, terminate, info = env.step(acts)
        if terminate:
            break

    if info['finished']:
        print('Congrates! covered all the places!')
        repeat_steps = calculate_repeat_steps(info['fields'])
        units_num = info['fields'].size
        repeat_ratio = repeat_steps / units_num
        print(
            f'Used steps: {steps}\nrepeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')
    else:
        print('Failed!')
def get_args():
    """Parse the command-line options for this example.

    Returns:
        argparse.Namespace with ``config_name``, ``axis``, ``n_ships``,
        ``battery_capacity``, ``filling_time``, ``render`` and
        ``cover_ratio``.  ``battery_capacity`` defaults to the float
        ``np.inf``: argparse only applies ``type=int`` to string defaults.
    """
    cli = ArgumentParser()
    option = cli.add_argument
    option('-c', '--config_name',
           type=str,
           required=True,
           help='The name of config file, e.g., map0')
    option('-a', '--axis',
           type=str,
           default='x',
           help='The axis for splitting areas')
    option('-n', '--n-ships',
           type=int,
           default=3,
           help='The number of ships.')
    option('-b', '--battery-capacity',
           type=int,
           default=np.inf,
           help='The num of steps a full battery can support.')
    option('-f', '--filling-time',
           type=int,
           default=1,
           help='The num of time steps needed for charging.')
    option('--render',
           action='store_true',
           help='render or not')
    option('--cover-ratio',
           type=float,
           default=1.0)
    parsed = cli.parse_args()
    return parsed
if __name__ == '__main__':
    args = get_args()
    sea_env = make(env_config=args.config_name,
                   render=args.render,
                   split_axis=args.axis,
                   n_ships=args.n_ships,
                   filling_time=args.filling_time,
                   battery_capacity=args.battery_capacity,
                   cover_ratio=args.cover_ratio)
    # Size every planner component from the environment's ship count, the
    # same value `main` iterates over (`env.ship_num`), so they all agree.
    wildfires = MultiWildFire(n_ships=sea_env.ship_num)
    charges = MultiCharge(n_ships=sea_env.ship_num,
                          filling_time=args.filling_time)
    algo = MultiGreedy(
        ship_num=sea_env.ship_num, wildfires=wildfires, charges=charges)
    main(sea_env, algo)
    # NOTE(review): this previously busy-waited for 1e8 loop iterations,
    # presumably to keep the rendered result on screen.  Wait for the user
    # explicitly instead of burning a CPU core for a machine-dependent time.
    if args.render:
        try:
            input('Press Enter to exit...')
        except EOFError:
            # Non-interactive stdin: nothing to wait for.
            pass