forked from SeAIPalette/SeAIPalette
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
|
import numpy as np
|
||
|
|
||
|
from argparse import ArgumentParser
|
||
|
|
||
|
from Palette.env import make, SeaEnv
|
||
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
||
|
from Palette.algos import Boustrophedon, WildFire
|
||
|
|
||
|
|
||
|
def calculate_repeat_steps(fields):
    """Return the summed excess above one over the positive entries of *fields*.

    Equivalent to ``sum(v - 1 for v in fields.flat if v > 0)``: the sum of all
    positive entries minus the number of positive entries. Entries that are
    zero or negative are ignored.

    NOTE(review): presumably each positive entry is a per-cell visit counter,
    which would make the result the number of repeated visits - confirm
    against the environment that fills ``info['fields']``.

    Args:
        fields: Array-like of numbers of any shape (NumPy array, nested
            lists, ...). It is converted with ``np.asarray`` so that the
            boolean mask below also works for plain Python sequences.

    Returns:
        The excess as a built-in ``int``; ``0`` when there is no positive
        entry (including empty input).
    """
    values = np.asarray(fields)
    positive = values[values > 0]
    # Each positive entry contributes (value - 1): subtract one per entry.
    return int(positive.sum() - positive.size)
|
||
|
|
||
|
|
||
|
def _report_outcome(info, all_steps):
    """Print the end-of-episode summary for a terminated environment.

    Prints ``Failed!`` when ``info['finished']`` is falsy; otherwise prints
    the step count together with the repeat-step statistics derived from
    ``info['fields']``.
    """
    if not info['finished']:
        print('Failed!')
        return

    print('Congrates! covered all the places!')
    print(f'steps: {all_steps}')
    repeat_steps = calculate_repeat_steps(info['fields'])
    units_num = info['fields'].size
    repeat_ratio = repeat_steps / units_num
    print(
        f'repeat steps: {repeat_steps}\nrepeat ratio: {repeat_steps}/{units_num}={repeat_ratio*100:.2f}%')


def main(env: SeaEnv, agent: Boustrophedon, wildfire: WildFire) -> None:
    """Run one episode, driving *env* with *agent* and falling back to *wildfire*.

    Each iteration asks ``agent.step`` for an action and dispatches on its
    first element:

    * ``FINISHED`` - print the repeat-step count and return.
    * ``BROKEN``   - hand control to ``wildfire.step`` until
      ``wildfire.empty()`` is true, then ask the agent to re-decide its
      direction on the next iteration.
    * otherwise    - apply the action with ``env.step``.

    Whenever ``env.step`` reports termination, the outcome summary is printed
    and the function returns. This includes termination during the wildfire
    phase, which previously returned without printing any result.

    Args:
        env: Environment providing ``reset()`` and ``step(action)``.
        agent: Primary planner; its ``step(state, info)`` returns an action
            whose first element may be ``FINISHED`` or ``BROKEN``.
        wildfire: Fallback planner providing ``step(state, info)`` and
            ``empty()``.
    """
    state, info = env.reset()
    redecide_direction = True
    all_steps = 0

    while True:
        all_steps += 1
        # The agent is given only the first element of the environment state.
        state = state[0]
        info['redecide_direction'] = redecide_direction
        act = agent.step(state, info)
        redecide_direction = False

        if act[0] == FINISHED:
            print('Congrates! covered all the places!')
            repeat_steps = calculate_repeat_steps(info['fields'])
            print(f'{repeat_steps} steps repeated!')
            return

        if act[0] == BROKEN:
            print('Broken!')
            # No environment step was taken for this iteration, so undo the
            # increment made at the top of the loop before counting the
            # fallback steps.
            all_steps -= 1
            while True:
                all_steps += 1
                # NOTE(review): on the first pass `state` is the already
                # indexed `state[0]`; on later passes it is the raw value
                # returned by `env.step` - confirm which form
                # `wildfire.step` expects.
                act = wildfire.step(state, info)
                state, terminate, info = env.step(act)
                if terminate:
                    _report_outcome(info, all_steps)
                    return
                if wildfire.empty():
                    redecide_direction = True
                    break
            continue

        state, terminate, info = env.step(act)
        if terminate:
            _report_outcome(info, all_steps)
            return
|
||
|
|
||
|
|
||
|
def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the process command line.

    Returns:
        An ``argparse.Namespace`` with the attributes ``config_name`` (str,
        required), ``render`` (bool, ``False`` unless ``--render`` is given)
        and ``cover_ratio`` (float, default ``1.0``).
    """
    parser = ArgumentParser()
    parser.add_argument('-c', "--config_name", type=str, required=True,
                        help="The name of config file, e.g., map0")
    parser.add_argument('--render', action='store_true',
                        help='render or not')
    parser.add_argument('--cover-ratio', type=float, default=1.0,
                        help='cover ratio setting (default: %(default)s)')
    return parser.parse_args(argv)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Build a single-ship environment from the CLI options, then run one
    # episode with the Boustrophedon planner and the WildFire fallback.
    cli_args = get_args()
    env_options = dict(
        env_config=cli_args.config_name,
        render=cli_args.render,
        n_ships=1,
        split_axis='x',
        cover_ratio=cli_args.cover_ratio,
    )
    environment = make(**env_options)
    planner = Boustrophedon()
    fallback = WildFire()
    main(environment, planner, fallback)

    # NOTE(review): empty busy loop of 1e8 iterations; presumably a crude
    # delay before the process exits (e.g. to keep a render window visible)
    # - confirm the intent before replacing it with a sleep.
    for _ in range(100_000_000):
        pass
|