SeAIPalette/Palette/examples/wildfire_boustrophedon.py

77 lines
2.6 KiB
Python

import numpy as np
from argparse import ArgumentParser
from Palette.env import make, SeaEnv
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos import Boustrophedon, WildFire
def calculate_repeat_steps(fields):
    """Return the total amount by which the positive entries of ``fields`` exceed 1.

    Computed as ``sum(positive entries) - number of positive entries``;
    entries <= 0 are ignored. Presumably ``fields`` holds per-cell visit
    counts (TODO confirm), in which case this is the number of steps that
    landed on an already-visited cell.

    Args:
        fields: NumPy array of per-cell values.

    Returns:
        The difference converted to a Python ``int`` (truncated toward zero).
    """
    positive = fields > 0
    excess = np.sum(fields[positive]) - np.sum(positive)
    return int(excess)
def _report_termination(info, all_steps):
    """Print the outcome of an episode that ``env.step`` reported as terminated.

    Args:
        info: The ``info`` dict returned by the terminating ``env.step`` call;
            ``info['finished']`` selects the success report and
            ``info['fields']`` (a NumPy array) feeds the repeat statistics.
        all_steps: Number of environment steps taken in the episode.
    """
    if info['finished']:
        print('Congrates! covered all the places!')
        print(f'steps: {all_steps}')
        repeat_steps = calculate_repeat_steps(info['fields'])
        units_num = info['fields'].size
        repeat_ratio = repeat_steps / units_num
        print(
            f'repeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')
    else:
        print('Failed!')


def main(env: SeaEnv, agent: Boustrophedon, wildfire: WildFire):
    """Run one episode: boustrophedon sweep with wildfire-based recovery.

    The boustrophedon ``agent`` proposes actions each iteration. If the first
    element of its action is ``FINISHED`` the episode summary is printed and
    the function returns. If it is ``BROKEN``, control passes to ``wildfire``,
    whose actions are executed until ``wildfire.empty()`` is true; the agent
    is then asked to re-decide its direction. Otherwise the agent's action is
    executed directly. Whenever ``env.step`` reports termination, the outcome
    is printed and the function returns.

    Args:
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(act)`` returns ``(state, terminate, info)``.
        agent: Boustrophedon planner; ``step(state, info)`` returns an action
            sequence whose first element is compared to ``FINISHED``/``BROKEN``.
        wildfire: Recovery planner used while the agent reports ``BROKEN``.
    """
    state, info = env.reset()
    redecide_direction = True
    all_steps = 0
    while True:
        all_steps += 1
        state = state[0]
        # Tell the agent whether it must choose a fresh sweep direction
        # (true on the first step and right after a wildfire recovery).
        info['redecide_direction'] = redecide_direction
        act = agent.step(state, info)
        redecide_direction = False
        if act[0] == FINISHED:
            print('Congrates! covered all the places!')
            repeat_steps = calculate_repeat_steps(info['fields'])
            print(f'{repeat_steps} steps repeated!')
            return
        if act[0] == BROKEN:
            print('Broken!')
            # No env step was taken for this iteration's increment; undo it
            # before counting the recovery steps so the total stays exact.
            all_steps -= 1
            while True:
                all_steps += 1
                # NOTE(review): the first call receives the already-unwrapped
                # ``state[0]`` while later calls receive the raw ``env.step``
                # state — confirm WildFire.step accepts both forms.
                act = wildfire.step(state, info)
                state, terminate, info = env.step(act)
                if terminate:
                    # Previously this exit returned silently, dropping the
                    # episode outcome; report it like the normal path does.
                    _report_termination(info, all_steps)
                    return
                if wildfire.empty():
                    redecide_direction = True
                    break
            continue
        state, terminate, info = env.step(act)
        if terminate:
            _report_termination(info, all_steps)
            return
def get_args(argv=None):
    """Parse the command-line options for this example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, matching the
            previous behavior.

    Returns:
        ``argparse.Namespace`` with ``config_name`` (str, required),
        ``render`` (bool, default False) and ``cover_ratio`` (float,
        default 1.0).
    """
    parser = ArgumentParser()
    parser.add_argument('-c', "--config_name", type=str, required=True,
                        help="The name of config file, e.g., map0")
    parser.add_argument('--render', action='store_true',
                        help='render or not')
    parser.add_argument('--cover-ratio', type=float, default=1.0,
                        help='value passed to the environment as cover_ratio '
                             '(default: 1.0)')
    return parser.parse_args(argv)
if __name__ == '__main__':
    import time

    args = get_args()
    sea_env = make(env_config=args.config_name,
                   render=args.render, n_ships=1, split_axis='x',
                   cover_ratio=args.cover_ratio)
    algo = Boustrophedon()
    wildfire = WildFire()
    main(sea_env, algo, wildfire)
    # Pause before exiting with an idle, fixed-length delay. A busy loop
    # (e.g. iterating an empty range 1e8 times) pegs a CPU core for a
    # machine-dependent time instead.
    # NOTE(review): presumably this keeps the rendered view on screen —
    # confirm, and consider pausing only when --render is given.
    exit_pause_seconds = 3.0
    time.sleep(exit_pause_seconds)