# SeAIPalette/Palette/env/sea_env.py
#
# 159 lines
# 4.6 KiB
# Python

import os
import yaml
import numpy as np
from argparse import ArgumentParser
from copy import deepcopy
from easydict import EasyDict
from typing import List
from Palette.algos.multi_split_area import multi_split_area
from Palette.env.engine import SeaEnvEngine
from Palette.constants import UP, DOWN, LEFT, RIGHT, FINISHED, BROKEN
class SeaEnv(object):
    """Episode-level wrapper around a ``SeaEnvEngine`` driving several ships.

    The engine itself is only created by :meth:`reset`; every property or
    method that touches ``self.sea_env_engine`` therefore requires a prior
    call to :meth:`reset`.
    """

    def __init__(
        self,
        cfg: EasyDict,
        draw: bool,
        n_ships: int,
        battery_capacity: int,
        split_axis: str = 'x',
        filling_time: int = 1,
        cover_ratio: float = 1.0
    ) -> None:
        """Store the environment settings; no engine is built here.

        Args:
            cfg: Engine configuration. It is unpacked as keyword arguments
                into ``SeaEnvEngine`` on reset; an optional ``"max_step"``
                entry caps the episode length.
            draw: Forwarded to the engine as ``draw_screen``.
            n_ships: Number of ships (and of areas the map is split into).
            battery_capacity: Forwarded to the engine when ships are created
                during reset.
            split_axis: Axis handed to ``multi_split_area``.
            filling_time: Forwarded to the engine.
            cover_ratio: Forwarded to the engine.
        """
        super().__init__()
        self.draw = draw
        self.cfg = cfg
        self.split_axis = split_axis
        # A missing "max_step" entry means the episode has no step limit.
        self.max_step = self.cfg["max_step"] if "max_step" in self.cfg else None
        self.cur_step = 0
        self.init_ships = False
        self.n_ships = n_ships
        self.battery_capacity = battery_capacity
        self.filling_time = filling_time
        self.cover_ratio = cover_ratio

    @property
    def ship_num(self):
        """Number of ships configured for this environment."""
        return self.n_ships

    @property
    def fields(self):
        """Deep copy of the engine's ``fields`` (callers cannot mutate it)."""
        engine_fields = self.sea_env_engine.fields
        return deepcopy(engine_fields)

    @property
    def splited_areas(self):
        """The engine's ``splited_areas`` object itself (not a copy)."""
        engine = self.sea_env_engine
        return engine.splited_areas

    def step(self, action: List[int]):
        """Advance the engine by one frame.

        Args:
            action: Passed unchanged to the engine's ``frame_step``.

        Returns:
            ``(frame, done, info)`` as produced by the engine, except that
            ``done`` is forced to ``True`` once ``max_step`` steps have been
            taken since the last reset.
        """
        frame, finished, extra = self.sea_env_engine.frame_step(action)
        self.cur_step += 1
        limit_reached = (
            self.max_step is not None and self.cur_step >= self.max_step
        )
        return frame, (True if limit_reached else finished), extra

    def reset(self):
        """Build a fresh engine, split the map and place the ships.

        The whole construction is repeated until the engine's ``current()``
        reports a state that is not already done.

        Returns:
            ``(state, info)`` of the first usable state.
        """
        self.cur_step = 0
        while True:
            engine = SeaEnvEngine(
                draw_screen=self.draw,
                init_ships=self.init_ships,
                filling_time=self.filling_time,
                cover_ratio=self.cover_ratio,
                **self.cfg
            )
            self.sea_env_engine = engine

            areas, spawn_points = multi_split_area(
                fields=engine.fields,
                n_areas=self.n_ships,
                axis=self.split_axis,
                field_size=self.cfg.field_size[0],
            )
            engine.splited_areas = areas

            # Every ship starts with the same heading value of pi.
            heading = 1.0 * np.pi
            engine.create_ship(
                start_position=spawn_points,
                start_r=[heading for _ in range(self.n_ships)],
                battery_capacity=self.battery_capacity
            )

            state, done, info = engine.current()
            if not done:
                return state, info

    def seed(self, seed: int = None) -> List[int]:
        """Seed NumPy's global random generator and return ``[seed]``."""
        np.random.seed(seed)
        used_seeds = [seed]
        return used_seeds

    def render(self, mode='human'):
        """No-op placeholder; ``mode`` is ignored."""
        return None

    def close(self):
        """No-op placeholder."""
        return None

    def create_ships(self, start_position, start_r):
        """Ask the engine to create ships at the given positions/headings.

        NOTE(review): unlike ``reset``, this call does not forward
        ``battery_capacity`` -- confirm the engine provides a default.
        """
        engine = self.sea_env_engine
        engine.create_ship(start_position=start_position, start_r=start_r)

    def get_battery(self, ship_id: int):
        """Return whatever the engine's ``get_battery`` reports for ``ship_id``."""
        engine = self.sea_env_engine
        return engine.get_battery(ship_id)
def make(
    env_config: str,
    render: bool,
    split_axis: str,
    n_ships: int,
    battery_capacity: int = np.inf,
    filling_time: int = 1,
    cover_ratio: float = 1.0
):
    """Build a :class:`SeaEnv` from a named map configuration.

    Args:
        env_config: Base name of the YAML file under ``Palette/maps``
            (resolved relative to the current working directory).
        render: Passed to ``SeaEnv`` as ``draw``.
        split_axis: Axis used when splitting the map between ships.
        n_ships: Number of ships.
        battery_capacity: Battery capacity per ship; unlimited by default.
        filling_time: Forwarded to ``SeaEnv``.
        cover_ratio: Forwarded to ``SeaEnv``.

    Returns:
        The constructed ``SeaEnv``; it has not been reset yet.
    """
    map_path = os.path.join('Palette', 'maps', f'{env_config}.yaml')
    with open(map_path, 'r') as stream:
        raw_settings = yaml.safe_load(stream)

    # Only the "env" section of the file configures the environment.
    env_section = EasyDict(raw_settings).env
    return SeaEnv(
        cfg=env_section,
        draw=render,
        split_axis=split_axis,
        n_ships=n_ships,
        battery_capacity=battery_capacity,
        filling_time=filling_time,
        cover_ratio=cover_ratio,
    )
if __name__ == '__main__':
    # Manual smoke test: replay a fixed action script on a chosen map.
    import time

    def get_args():
        """Parse the command-line options of the demo script."""
        parser = ArgumentParser()
        parser.add_argument('-c', "--config_name", type=str, required=True,
                            help="The name of config file, e.g., map0")
        parser.add_argument('-a', '--axis', type=str,
                            default='x', help="The axis for splitting areas.")
        parser.add_argument('--render', action='store_true',
                            help='render or not')
        return parser.parse_args()

    args = get_args()

    # Scripted actions; every entry holds one action per ship (two ships).
    acts = []
    acts.extend([[UP, DOWN] for _ in range(49)])
    acts.extend([[RIGHT, LEFT] for _ in range(5)])
    acts.extend([[LEFT, RIGHT] for _ in range(5)])
    acts.extend([[DOWN, UP] for _ in range(20)])
    acts.extend([[RIGHT, LEFT] for _ in range(200)])

    # ``n_ships`` is a required argument of ``make``; derive it from the
    # script so the action width always matches the number of ships.
    sea_env = make(env_config=args.config_name,
                   render=args.render, split_axis=args.axis,
                   n_ships=len(acts[0]))
    # The engine only exists after ``reset``; stepping before it would fail.
    sea_env.reset()

    # Short pause between frames so a rendered run can be followed by eye.
    frame_delay = 0.05

    steps = 0
    # Iterating over the script (instead of indexing it inside ``while True``)
    # ends the run cleanly when the actions are exhausted.
    for act in acts:
        s, t, _ = sea_env.step(act)
        if args.render:
            time.sleep(frame_delay)
        steps += 1
        print(s)
        if t:
            break
    print(steps)