Use gymnasium APIs
This commit is contained in:
parent
afb29ac76e
commit
16568cb016
|
@ -14,7 +14,7 @@ from math import ceil
|
|||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation, Slerp
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
import omnigibson as og
|
||||
|
@ -398,10 +398,11 @@ class Environment(gym.Env, GymObservable, Recreatable):
|
|||
of actions
|
||||
|
||||
Returns:
|
||||
4-tuple:
|
||||
5-tuple:
|
||||
- dict: state, i.e. next observation
|
||||
- float: reward, i.e. reward at this current timestep
|
||||
- bool: done, i.e. whether this episode is terminated
|
||||
- bool: terminated, i.e. whether this episode ended due to a failure or success
|
||||
- bool: truncated, i.e. whether this episode ended due to a time limit etc.
|
||||
- dict: info, i.e. dictionary with any useful information
|
||||
"""
|
||||
try:
|
||||
|
@ -440,10 +441,21 @@ class Environment(gym.Env, GymObservable, Recreatable):
|
|||
info["last_observation"] = obs
|
||||
obs = self.reset()
|
||||
|
||||
# Hacky way to check for time limit info to split terminated and truncated
|
||||
terminated = False
|
||||
truncated = False
|
||||
for tc, tc_data in info["done"]["termination_conditions"].items():
|
||||
if tc_data["done"]:
|
||||
if tc == "timeout":
|
||||
truncated = True
|
||||
else:
|
||||
terminated = True
|
||||
assert (terminated or truncated) == done, "Terminated and truncated must match done!"
|
||||
|
||||
# Increment step
|
||||
self._current_step += 1
|
||||
|
||||
return obs, reward, done, info
|
||||
return obs, reward, terminated, truncated, info
|
||||
except:
|
||||
raise ValueError(f"Failed to execute environment step {self._current_step} in episode {self._current_episode}")
|
||||
|
||||
|
|
|
@ -27,7 +27,8 @@ class EnvironmentWrapper(Wrapper):
|
|||
4-tuple:
|
||||
- (dict) observations from the environment
|
||||
- (float) reward from the environment
|
||||
- (bool) whether the current episode is completed or not
|
||||
- (bool) whether the current episode is terminated
|
||||
- (bool) whether the current episode is truncated
|
||||
- (dict) misc information
|
||||
"""
|
||||
return self.env.step(action)
|
||||
|
|
|
@ -14,7 +14,7 @@ from omnigibson.macros import gm
|
|||
from omnigibson.utils.python_utils import meets_minimum_version
|
||||
|
||||
try:
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import tensorboard
|
||||
|
@ -31,10 +31,10 @@ except ModuleNotFoundError:
|
|||
"pip install torch\n"
|
||||
"pip install stable-baselines3==1.7.0\n"
|
||||
"pip install tensorboard\n"
|
||||
"Also, please update gym to >=0.26.1 after installing sb3: pip install gym>=0.26.1")
|
||||
"Also, please use gymnasium instead of gym: pip install gymnasium>=0.28.1")
|
||||
exit(1)
|
||||
|
||||
assert meets_minimum_version(gym.__version__, "0.26.1"), "Please install/update gym to version >= 0.26.1"
|
||||
assert meets_minimum_version(gym.__version__, "0.28.1"), "Please install/update gym to version >= 0.28.1"
|
||||
|
||||
# We don't need object states nor transitions rules, so we disable them now, and also enable flatcache for maximum speed
|
||||
gm.ENABLE_OBJECT_STATES = False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
from collections.abc import Iterable
|
||||
import omnigibson as og
|
||||
from omnigibson.objects.object_base import BaseObject
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
|
||||
from omnigibson.robots.locomotion_robot import LocomotionRobot
|
||||
from omnigibson.utils.python_utils import classproperty
|
||||
|
|
|
@ -2,7 +2,7 @@ from abc import ABCMeta
|
|||
from omnigibson.prims.xform_prim import XFormPrim
|
||||
from omnigibson.utils.python_utils import classproperty, assert_valid_key, Registerable
|
||||
from omnigibson.utils.gym_utils import GymObservable
|
||||
from gym.spaces import Space
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
# Registered sensors
|
||||
|
@ -97,7 +97,7 @@ class BaseSensor(XFormPrim, GymObservable, Registerable, metaclass=ABCMeta):
|
|||
obs_space = dict()
|
||||
for modality, space in self._obs_space_mapping.items():
|
||||
if modality in self._modalities:
|
||||
if isinstance(space, Space):
|
||||
if isinstance(space, gym.Space):
|
||||
# Directly add this space
|
||||
obs_space[modality] = space
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import time
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
|
||||
import omnigibson as og
|
||||
from omnigibson.sensors.sensor_base import BaseSensor
|
||||
|
|
|
@ -201,16 +201,22 @@ class BaseTask(GymObservable, Registerable, metaclass=ABCMeta):
|
|||
# Get all dones and successes from individual termination conditions
|
||||
dones = []
|
||||
successes = []
|
||||
for termination_condition in self._termination_conditions.values():
|
||||
info = dict() if info is None else info
|
||||
if "termination_conditions" not in info:
|
||||
info["termination_conditions"] = dict()
|
||||
for name, termination_condition in self._termination_conditions.items():
|
||||
d, s = termination_condition.step(self, env, action)
|
||||
dones.append(d)
|
||||
successes.append(s)
|
||||
info["termination_conditions"][name] = {
|
||||
"done": d,
|
||||
"success": s,
|
||||
}
|
||||
# Any True found corresponds to a done / success
|
||||
done = sum(dones) > 0
|
||||
success = sum(successes) > 0
|
||||
|
||||
# Populate info
|
||||
info = dict() if info is None else info
|
||||
info["success"] = success
|
||||
return done, info
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import gym
|
||||
import gymnasium as gym
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
|
|
Loading…
Reference in New Issue