SeAIPalette/Palette/examples/wildfire_boustrophedon.py

77 lines
2.6 KiB
Python
Raw Normal View History

2022-05-30 21:36:26 +08:00
import numpy as np
from argparse import ArgumentParser
from Palette.env import make, SeaEnv
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
from Palette.algos import Boustrophedon, WildFire
def calculate_repeat_steps(fields):
    """Return sum(positive entries of ``fields``) - count(positive entries).

    The positive entries presumably hold per-cell visit counts (the result is
    reported by ``main`` as "repeat steps"), so every visit beyond the first
    one to a cell is counted once. Non-positive entries are ignored.

    Args:
        fields: numpy array supporting boolean-mask indexing.

    Returns:
        The difference truncated to a Python ``int``.
    """
    visited = fields[fields > 0]
    # ``visited.size`` equals the number of True entries in the mask.
    return int(visited.sum() - visited.size)
def _report_coverage(steps, fields):
    """Print the success banner together with step and repeat statistics.

    Args:
        steps: number of ``env.step`` calls made during the episode.
        fields: the ``info['fields']`` array last returned by the environment.
    """
    print('Congrates! covered all the places!')
    print(f'steps: {steps}')
    repeat_steps = calculate_repeat_steps(fields)
    units_num = fields.size
    # Guard against an empty map so the ratio cannot raise ZeroDivisionError.
    repeat_ratio = repeat_steps / units_num if units_num else 0.0
    print(
        f'repeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')


def _report_termination(steps, info):
    """Report the outcome of an episode the environment flagged as terminated."""
    if info['finished']:
        _report_coverage(steps, info['fields'])
    else:
        print('Failed!')


def main(env: SeaEnv, agent: Boustrophedon, wildfire: WildFire):
    """Run one coverage episode, falling back to WildFire when the sweep breaks.

    Each outer iteration asks ``agent`` for an action on ``state[0]``
    (presumably the single ship's observation -- the script calls ``make``
    with ``n_ships=1``; confirm). The first element of the action selects
    the branch:

    * ``FINISHED`` -- the agent reports full coverage; print the report, stop.
    * ``BROKEN``   -- hand control to ``wildfire`` and step the env with its
      actions until ``wildfire.empty()`` (after which the agent is asked to
      re-decide its direction) or until the env terminates.
    * otherwise    -- step the env with the agent's action.

    Whenever the env signals termination, ``info['finished']`` selects between
    the coverage report and ``'Failed!'``. Every exit path reports the same
    way; the printed step count is the number of ``env.step`` calls made.

    Args:
        env: environment with ``reset() -> (state, info)`` and
            ``step(act) -> (state, terminate, info)``.
        agent: Boustrophedon planner; ``agent.step(state, info)`` returns an
            action whose first element may be ``FINISHED`` or ``BROKEN``.
        wildfire: WildFire planner used to recover after ``BROKEN``.
    """
    state, info = env.reset()
    # Read by agent.step via ``info``; True on the first decision and right
    # after a WildFire recovery finishes.
    redecide_direction = True
    all_steps = 0  # counts env.step calls only
    while True:
        state = state[0]
        info['redecide_direction'] = redecide_direction
        act = agent.step(state, info)
        redecide_direction = False

        if act[0] == FINISHED:
            _report_coverage(all_steps, info['fields'])
            return

        if act[0] == BROKEN:
            print('Broken!')
            while True:
                # NOTE(review): on the first pass ``state`` is the ``state[0]``
                # slice taken above, but after env.step it is the full state
                # returned by the env, so WildFire.step sees both forms --
                # confirm which one it expects.
                act = wildfire.step(state, info)
                state, terminate, info = env.step(act)
                all_steps += 1
                if terminate:
                    # Previously returned silently here; report like the
                    # normal branch so a finish during recovery is visible.
                    _report_termination(all_steps, info)
                    return
                if wildfire.empty():
                    redecide_direction = True
                    break
        else:
            state, terminate, info = env.step(act)
            all_steps += 1
            if terminate:
                _report_termination(all_steps, info)
                return
def get_args():
    """Parse this example's command-line options.

    Options:
        -c / --config_name  (required) config file name, e.g. ``map0``.
        --render            flag; False unless given.
        --cover-ratio       float, default 1.0 (stored as ``cover_ratio``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        '-c', '--config_name',
        type=str,
        required=True,
        help='The name of config file, e.g., map0',
    )
    arg_parser.add_argument('--render', action='store_true', help='render or not')
    arg_parser.add_argument('--cover-ratio', type=float, default=1.0)
    parsed = arg_parser.parse_args()
    return parsed
if __name__ == '__main__':
    cli_args = get_args()
    # One ship; split_axis='x' is passed through to make() unchanged.
    env_instance = make(
        env_config=cli_args.config_name,
        render=cli_args.render,
        n_ships=1,
        split_axis='x',
        cover_ratio=cli_args.cover_ratio,
    )
    # Arguments are evaluated left to right: Boustrophedon() then WildFire().
    main(env_instance, Boustrophedon(), WildFire())
    # NOTE(review): CPU-bound busy loop of 1e8 iterations that runs after
    # main() returns, even without --render. Presumably a crude delay to keep
    # the final rendered frame visible -- confirm, and consider time.sleep(),
    # input(), or gating on cli_args.render instead.
    for _ in range(100_000_000):
        pass