# Forked from SeAIPalette/SeAIPalette (159 lines, 4.6 KiB, Python)
import os
|
|
import yaml
|
|
import numpy as np
|
|
|
|
from argparse import ArgumentParser
|
|
from copy import deepcopy
|
|
from easydict import EasyDict
|
|
from typing import List
|
|
|
|
from Palette.algos.multi_split_area import multi_split_area
|
|
from Palette.env.engine import SeaEnvEngine
|
|
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
|
|
|
|
|
|
class SeaEnv(object):
    """Environment wrapper around ``SeaEnvEngine`` for several ships.

    No engine exists until :meth:`reset` is called; ``step``, ``fields``,
    ``splited_areas``, ``create_ships`` and ``get_battery`` all read
    ``self.sea_env_engine`` and therefore require a prior ``reset()``.
    """

    def __init__(
        self,
        cfg: EasyDict,
        draw: bool,
        n_ships: int,
        battery_capacity: int,
        split_axis: str = 'x',
        filling_time: int = 1,
        cover_ratio: float = 1.0
    ) -> None:
        """Store the environment settings; the engine is built in ``reset()``.

        Args:
            cfg: Environment config. All of its items are passed as keyword
                arguments to ``SeaEnvEngine``; an optional ``max_step`` entry
                caps the number of steps per episode, and
                ``cfg.field_size[0]`` is forwarded to ``multi_split_area``.
            draw: Forwarded to the engine as ``draw_screen``.
            n_ships: Number of ships, and of areas the map is split into.
            battery_capacity: Battery capacity given to every created ship
                (``make`` passes ``np.inf`` by default).
            split_axis: Axis along which ``multi_split_area`` splits the map.
            filling_time: Forwarded to the engine.
            cover_ratio: Forwarded to the engine.
        """
        super().__init__()
        self.draw = draw
        self.cfg = cfg
        self.split_axis = split_axis
        # None means the episode length is not capped by this wrapper.
        self.max_step = self.cfg.get("max_step")
        self.cur_step = 0
        # Passed to the engine as ``init_ships``; ships are instead created
        # explicitly in reset() after the map has been split.
        self.init_ships = False
        self.n_ships = n_ships
        self.battery_capacity = battery_capacity
        self.filling_time = filling_time
        self.cover_ratio = cover_ratio

    @property
    def ship_num(self):
        """Number of ships this environment was configured with."""
        return self.n_ships

    @property
    def fields(self):
        """A deep copy of the engine's ``fields`` (safe for callers to mutate)."""
        return deepcopy(self.sea_env_engine.fields)

    @property
    def splited_areas(self):
        """Areas produced by ``multi_split_area`` during the last ``reset()``."""
        return self.sea_env_engine.splited_areas

    def step(self, action: List[int]):
        """Advance the engine by one frame.

        Args:
            action: Actions for the ships (presumably one per ship, e.g.
                ``UP``/``DOWN``/``LEFT``/``RIGHT``) — passed unchanged to
                ``SeaEnvEngine.frame_step``.

        Returns:
            ``(frame, done, info)`` as returned by ``frame_step``, except that
            ``done`` is forced to True once ``max_step`` steps have been taken.
        """
        cur_frame, done, info = self.sea_env_engine.frame_step(action)
        self.cur_step += 1
        if self.max_step is not None and self.cur_step >= self.max_step:
            done = True
        return cur_frame, done, info

    def reset(self):
        """Build a fresh engine, place the ships and return ``(state, info)``.

        The map is split into ``n_ships`` areas with ``multi_split_area`` and
        ships are created at the returned start points, each with a
        ``start_r`` of pi and the configured battery capacity. The whole
        construction is repeated while the engine reports the new episode as
        already done.
        NOTE(review): this loops forever if every rebuilt engine is done.
        """
        self.cur_step = 0
        done = True
        while done:
            self.sea_env_engine = SeaEnvEngine(
                draw_screen=self.draw,
                init_ships=self.init_ships,
                filling_time=self.filling_time,
                cover_ratio=self.cover_ratio,
                **self.cfg
            )
            self.sea_env_engine.splited_areas, start_points = multi_split_area(
                fields=self.sea_env_engine.fields,
                n_areas=self.n_ships,
                axis=self.split_axis,
                field_size=self.cfg.field_size[0],
            )
            self.sea_env_engine.create_ship(
                start_position=start_points,
                start_r=[np.pi] * self.n_ships,
                battery_capacity=self.battery_capacity
            )
            state, done, info = self.sea_env_engine.current()
        return state, info

    def seed(self, seed: int = None) -> List[int]:
        """Seed NumPy's global RNG and return ``[seed]``."""
        np.random.seed(seed)
        return [seed]

    def render(self, mode='human'):
        """No-op; drawing, if any, is done by the engine (see ``draw``)."""
        pass

    def close(self):
        """No-op; this wrapper holds no resources of its own."""
        pass

    def create_ships(self, start_position, start_r, battery_capacity=None):
        """Create ships on the current engine via ``SeaEnvEngine.create_ship``.

        Args:
            start_position: Start positions, forwarded to the engine.
            start_r: Start ``r`` values, forwarded to the engine.
            battery_capacity: Battery capacity for these ships. Defaults to
                the capacity this environment was configured with, so ships
                created here match the ones created by ``reset()``.
        """
        if battery_capacity is None:
            battery_capacity = self.battery_capacity
        self.sea_env_engine.create_ship(
            start_position=start_position,
            start_r=start_r,
            battery_capacity=battery_capacity,
        )

    def get_battery(self, ship_id: int):
        """Return the engine's battery reading for ship ``ship_id``."""
        return self.sea_env_engine.get_battery(ship_id)
|
|
|
|
|
|
def make(
    env_config: str,
    render: bool,
    split_axis: str,
    n_ships: int,
    battery_capacity: int = np.inf,
    filling_time: int = 1,
    cover_ratio: float = 1.0
):
    """Build a ``SeaEnv`` from a named map config.

    Loads ``Palette/maps/<env_config>.yaml`` (resolved relative to the
    current working directory) and uses its ``env`` section as the
    environment configuration.

    Args:
        env_config: Config name without extension, e.g. ``"map0"``.
        render: Forwarded to ``SeaEnv`` as ``draw``.
        split_axis: Axis along which the map is split between ships.
        n_ships: Number of ships.
        battery_capacity: Per-ship battery capacity; ``np.inf`` by default.
        filling_time: Forwarded to ``SeaEnv``.
        cover_ratio: Forwarded to ``SeaEnv``.

    Returns:
        A new ``SeaEnv``; call ``reset()`` on it before stepping.
    """
    config_path = os.path.join('Palette', 'maps', f'{env_config}.yaml')
    # Explicit encoding so the map parses identically regardless of locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = EasyDict(yaml.safe_load(f))

    return SeaEnv(cfg=config.env, draw=render,
                  split_axis=split_axis, n_ships=n_ships,
                  battery_capacity=battery_capacity, filling_time=filling_time,
                  cover_ratio=cover_ratio)
|
|
|
|
|
|
if __name__ == '__main__':
    import time

    def get_args():
        """Parse the command-line options for this scripted demo run."""
        parser = ArgumentParser()
        parser.add_argument('-c', "--config_name", type=str, required=True,
                            help="The name of config file, e.g., map0")
        parser.add_argument('-a', '--axis', type=str,
                            default='x', help="The axis for splitting areas.")
        parser.add_argument('--render', action='store_true',
                            help='render or not')
        # Replaces a CPU-burning busy loop that was used to slow the run down.
        parser.add_argument('--delay', type=float, default=0.0,
                            help='Seconds to pause after each step.')
        return parser.parse_args()

    args = get_args()

    # Scripted action sequence; every entry holds one action per ship.
    acts = []
    acts.extend([[UP, DOWN] for _ in range(49)])
    acts.extend([[RIGHT, LEFT] for _ in range(5)])
    acts.extend([[LEFT, RIGHT] for _ in range(5)])
    acts.extend([[DOWN, UP] for _ in range(20)])
    acts.extend([[RIGHT, LEFT] for _ in range(200)])

    # ``n_ships`` is required by make(); it must match the width of the
    # scripted action vectors.
    sea_env = make(env_config=args.config_name,
                   render=args.render, split_axis=args.axis,
                   n_ships=len(acts[0]))
    # The engine only exists after reset(); stepping earlier would fail.
    sea_env.reset()

    steps = 0
    # Iterate the script directly so running out of actions ends the demo
    # instead of raising IndexError.
    for act in acts:
        s, t, _ = sea_env.step(act)
        if args.delay > 0:
            time.sleep(args.delay)
        steps += 1
        print(s)

        if t:
            break

    print(steps)
|