SeAIPalette/Palette/examples/split.py

90 lines
3.3 KiB
Python
Raw Permalink Normal View History

2022-05-30 21:36:26 +08:00
import numpy as np
from argparse import ArgumentParser
from copy import deepcopy
from Palette.env import make, SeaEnv
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos import split_area, Boustrophedon, Spiral, WildFire
from typing import Union
def calculate_repeat_steps(fields):
    """Return how many redundant visits a finished field map records.

    The result is the sum, over all strictly positive cells, of
    ``value - 1``. That equals the number of repeat steps if each positive
    value is a per-cell visit count (presumably so; confirm against the env).
    Cells with non-positive values are ignored.

    Args:
        fields: numpy array of per-cell values. No cell may be 0.

    Returns:
        int: total of ``value - 1`` over all positive cells.

    Raises:
        AssertionError: if any cell equals 0 (coverage not finished).
    """
    assert not np.any(fields == 0), 'Not finish!'
    # Boolean-mask out everything that is not strictly positive.
    positive_cells = fields[fields > 0]
    visits_total = positive_cells.sum()
    # Each positive cell's first visit is not a repeat, so subtract one per cell.
    return int(visits_total - positive_cells.size)
def main(env: SeaEnv, agent: Union[Boustrophedon, Spiral], wildfire: WildFire) -> None:
    """Cover the map one sub-area at a time, falling back to WildFire when stuck.

    Flow:
    1. Partition the field map returned by ``env.reset()`` with ``split_area``.
    2. Handle area labels ``1..N`` in turn with ``agent``. While one area is
       being worked on, every cell belonging to another area is bumped by +1
       in the copy of ``fields`` handed to the agent.
    3. When the agent returns ``BROKEN``, let the ``wildfire`` planner drive
       the env until the env's termination flag is set or
       ``wildfire.empty()`` is true. Then ask the agent to re-decide its
       direction.

    When a regular agent step makes the env terminate, a summary is printed:
    repeat steps and ratio if ``info['finished']``, otherwise 'Failed!'.
    The function then returns. If every area ends without that happening, it
    returns without printing a summary.

    Args:
        env: environment. ``reset()`` yields ``(state, info)`` and
            ``step(act)`` yields ``(state, terminate, info)``. ``info``
            carries at least ``'fields'`` and ``'finished'``.
        agent: coverage planner. Its ``step`` returns an action or one of
            the ``FINISHED`` / ``BROKEN`` sentinels.
        wildfire: fallback planner used after the agent reports ``BROKEN``.
    """
    state, info = env.reset()
    # Partition a copy of the initial field map in which every positive cell
    # has been reset to 0.
    fields_to_split = deepcopy(info['fields'])
    fields_to_split[fields_to_split > 0] = 0
    # Label map: the largest label is taken as the number of areas, which
    # are processed as labels 1..n_areas.
    # NOTE(review): label -1 seems to mark cells outside every area (it is
    # excluded from the mask below). Confirm against split_area.
    splited_area = split_area(fields_to_split)
    n_areas = int(np.max(splited_area))
    for area_id in range(1, n_areas+1):
        # Signal the agent (via info['redecide_direction']) to re-decide its
        # direction on the first step in each area.
        redecide_direction = True
        cur_area_terminate = False
        while True:
            info['redecide_direction'] = redecide_direction
            # Add 1 to every cell whose label is neither area_id nor -1. This
            # is presumably so the agent sees other areas as already covered
            # and stays inside the current one.
            modified_fields = deepcopy(info['fields'])
            modified_fields += ((splited_area != area_id)
                                * (splited_area != -1))
            info_for_agent = deepcopy(info)
            info_for_agent['fields'] = modified_fields
            act = agent.step(state, info_for_agent)
            redecide_direction = False
            if act == FINISHED:
                # Stop working on this area and move on to the next label.
                cur_area_terminate = True
                break
            elif act == BROKEN:
                # Agent cannot continue: let WildFire drive the env until the
                # env's termination flag is set or WildFire is empty, then
                # re-decide.
                print(f'Broken!')
                while True:
                    # NOTE(review): wildfire always receives the info_for_agent
                    # built before BROKEN. It is never refreshed from the info
                    # returned by env.step below. Confirm this is intended.
                    act = wildfire.step(state, info_for_agent)
                    # NOTE(review): this is the same env.step return value that
                    # the regular branch below treats as the end of the run
                    # (it prints a summary and returns). If it is set here,
                    # control only moves on to the next area, and the summary
                    # is never printed. Confirm this is intended.
                    state, cur_area_terminate, info = env.step(act)
                    if cur_area_terminate:
                        break
                    if wildfire.empty():
                        redecide_direction = True
                        break
            else:
                state, terminate, info = env.step(act)
                if terminate:
                    # Run is over: report the outcome and stop.
                    if info['finished']:
                        print(f'Congrates! covered all the places!')
                        repeat_steps = calculate_repeat_steps(info['fields'])
                        # NOTE: .size counts every cell of the map, including
                        # non-positive ones that calculate_repeat_steps skips.
                        units_num = info['fields'].size
                        repeat_ratio = repeat_steps / units_num
                        print(
                            f'repeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')
                    else:
                        print(f'Failed!')
                    return
            if cur_area_terminate:
                break
def get_args():
    """Parse the command-line options of this example script.

    Returns:
        argparse.Namespace with ``config_name`` (str, required), ``algo``
        (``'b'`` for Boustrophedon or ``'s'`` for Spiral, default ``'b'``)
        and ``render`` (bool flag).
    """
    cli = ArgumentParser()
    cli.add_argument(
        '-c',
        '--config_name',
        type=str,
        required=True,
        help='The name of config file, e.g., map0',
    )
    cli.add_argument(
        '-a',
        '--algo',
        type=str,
        choices=['b', 's'],
        default='b',
        help='The algorithm is Boustrophedon or Spiral.',
    )
    cli.add_argument(
        '--render',
        action='store_true',
        help='render or not',
    )
    namespace = cli.parse_args()
    return namespace
if __name__ == '__main__':
    # Script entry point: build the env from the named config, pick the
    # coverage planner selected on the command line, and run the demo.
    args = get_args()
    sea_env = make(env_config=args.config_name, render=args.render)
    planner_by_flag = {'b': Boustrophedon, 's': Spiral}
    if args.algo not in planner_by_flag:
        raise ValueError(f'invalid args.algo: {args.algo}')
    algo = planner_by_flag[args.algo]()
    wildfire = WildFire()
    main(sea_env, algo, wildfire)