# Forked from SeAIPalette/SeAIPalette (original listing: 88 lines, 3.3 KiB, Python).
import numpy as np
|
|
|
|
from argparse import ArgumentParser
|
|
from copy import deepcopy
|
|
|
|
from Palette.env import make, SeaEnv
|
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
|
from Palette.algos import split_area, Boustrophedon, Spiral, WildFire
|
|
from typing import Union
|
|
|
|
|
|
def calculate_repeat_steps(fields):
    """Return how many steps re-entered a cell that was already covered.

    ``fields`` is a numpy array of per-cell values. A value of 0 marks a cell
    that was never reached. A positive value ``k`` is presumably the number
    of times the cell was visited. Entries that are not positive are ignored
    by the count. Every positive cell contributes ``k - 1`` repeats, so the
    result is ``sum(k for k > 0) - count(k > 0)``, returned as a Python int.

    Raises:
        AssertionError: with message ``'Not finish!'`` if any cell still
            equals 0 (coverage is incomplete).
    """
    assert not np.any(fields == 0), 'Not finish!'
    visited_counts = fields[fields > 0]
    # Each visited cell's first entry is not a repeat; subtract one per cell.
    return int(visited_counts.sum() - visited_counts.size)
|
|
|
|
|
|
def main(env: SeaEnv, agent: Union[Boustrophedon, Spiral], wildfire: WildFire):
    """Cover the map area by area with ``agent``, recovering via ``wildfire``.

    The initial field map, with every positive cell reset to 0, is passed to
    ``split_area``. The resulting label map numbers the areas 1..n_areas.
    Areas are processed in order. Before each agent query, 1 is added to
    every cell whose label is neither the current ``area_id`` nor ``-1``.
    Under the convention that 0 marks an unvisited cell, this makes only the
    current area look unvisited to the agent.

    Per step the agent's action is handled as follows:

    * ``FINISHED``: the current area is done; move on to the next one.
    * ``BROKEN``: ``wildfire`` drives the env until its plan is empty (then
      the agent is asked to re-decide its direction) or the env terminates.
    * anything else: forwarded to ``env.step``.

    Whenever the env reports termination, from either an agent action or a
    wildfire action, the outcome is printed and the function returns.

    Args:
        env: environment. ``reset()`` returns ``(state, info)`` and
            ``step(act)`` returns ``(state, terminate, info)``. ``info`` is a
            dict holding at least ``'fields'`` and, on termination,
            ``'finished'``.
        agent: coverage planner queried with ``agent.step(state, info)``.
        wildfire: fallback planner used after the agent reports ``BROKEN``.
    """
    state, info = env.reset()

    # Split the free space on the initial map with visited (positive) cells
    # treated as unvisited (0).
    fields_to_split = deepcopy(info['fields'])
    fields_to_split[fields_to_split > 0] = 0
    splited_area = split_area(fields_to_split)
    n_areas = int(np.max(splited_area))

    for area_id in range(1, n_areas + 1):
        # Cells of other areas (label != area_id) get +1 so the agent treats
        # them as already visited. Label -1 is left untouched (presumably
        # non-sea cells; confirm against split_area). The mask depends only
        # on area_id, so it is built once per area rather than every step.
        other_area_mask = (splited_area != area_id) * (splited_area != -1)
        redecide_direction = True

        while True:
            info['redecide_direction'] = redecide_direction
            modified_fields = deepcopy(info['fields'])
            modified_fields += other_area_mask
            info_for_agent = deepcopy(info)
            info_for_agent['fields'] = modified_fields

            act = agent.step(state, info_for_agent)
            redecide_direction = False

            if act == FINISHED:
                break  # current area done; go to the next one

            if act == BROKEN:
                print('Broken!')
                # Let wildfire drive until its plan runs out or the env ends.
                while True:
                    act = wildfire.step(state, info_for_agent)
                    state, terminate, info = env.step(act)
                    if terminate:
                        # Previously this only skipped to the next area and
                        # kept stepping a terminated env without reporting.
                        _report_result(info)
                        return
                    if wildfire.empty():
                        redecide_direction = True
                        break
            else:
                state, terminate, info = env.step(act)
                if terminate:
                    _report_result(info)
                    return
    # NOTE(review): if every area reports FINISHED without the env ever
    # terminating, the function returns here without printing a result.


def _report_result(info):
    """Print the outcome of a terminated episode described by ``info``.

    If ``info['finished']`` is truthy, print the number of repeated steps in
    ``info['fields']`` and its ratio to the total number of cells. Otherwise
    print ``'Failed!'``.
    """
    if not info['finished']:
        print('Failed!')
        return
    print('Congrates! covered all the places!')
    repeat_steps = calculate_repeat_steps(info['fields'])
    # NOTE(review): the denominator counts every cell of the map, including
    # cells that are never meant to be covered; confirm this normalisation.
    units_num = info['fields'].size
    repeat_ratio = repeat_steps / units_num
    print(
        f'repeat steps: {repeat_steps}\n'
        f'repeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')
|
|
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: argument strings to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``config_name`` (str, required), ``algo``
        (``'b'`` for Boustrophedon or ``'s'`` for Spiral, default ``'b'``)
        and ``render`` (bool, default False).
    """
    parser = ArgumentParser()
    parser.add_argument('-c', "--config_name", type=str, required=True,
                        help="The name of config file, e.g., map0")
    parser.add_argument('-a', '--algo', type=str, choices=['b', 's'],
                        default='b',
                        help='The algorithm is Boustrophedon or Spiral.')
    parser.add_argument('--render', action='store_true',
                        help='render or not')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry: build the env, pick the coverage planner chosen on the
    # command line, and run the area-by-area coverage loop.
    args = get_args()
    sea_env = make(env_config=args.config_name, render=args.render)
    planner_factories = {'b': Boustrophedon, 's': Spiral}
    if args.algo not in planner_factories:
        raise ValueError(f'invalid args.algo: {args.algo}')
    algo = planner_factories[args.algo]()
    wildfire = WildFire()
    main(sea_env, algo, wildfire)
|