Use gymnasium APIs

This commit is contained in:
Cem Gökmen 2023-11-22 17:14:04 -08:00
parent afb29ac76e
commit 16568cb016
11 changed files with 37 additions and 18 deletions

View File

@ -14,7 +14,7 @@ from math import ceil
import cv2
from matplotlib import pyplot as plt
import gym
import gymnasium as gym
import numpy as np
from scipy.spatial.transform import Rotation, Slerp

View File

@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
import omnigibson as og
@ -398,10 +398,11 @@ class Environment(gym.Env, GymObservable, Recreatable):
of actions
Returns:
4-tuple:
5-tuple:
- dict: state, i.e. next observation
- float: reward, i.e. reward at this current timestep
- bool: done, i.e. whether this episode is terminated
- bool: terminated, i.e. whether this episode ended due to a failure or success
- bool: truncated, i.e. whether this episode ended due to a time limit etc.
- dict: info, i.e. dictionary with any useful information
"""
try:
@ -440,10 +441,21 @@ class Environment(gym.Env, GymObservable, Recreatable):
info["last_observation"] = obs
obs = self.reset()
# Hacky way to check for time limit info to split terminated and truncated
terminated = False
truncated = False
for tc, tc_data in info["done"]["termination_conditions"].items():
if tc_data["done"]:
if tc == "timeout":
truncated = True
else:
terminated = True
assert (terminated or truncated) == done, "Terminated and truncated must match done!"
# Increment step
self._current_step += 1
return obs, reward, done, info
return obs, reward, terminated, truncated, info
except:
raise ValueError(f"Failed to execute environment step {self._current_step} in episode {self._current_episode}")

View File

@ -27,7 +27,8 @@ class EnvironmentWrapper(Wrapper):
4-tuple:
- (dict) observations from the environment
- (float) reward from the environment
- (bool) whether the current episode is completed or not
- (bool) whether the current episode is terminated
- (bool) whether the current episode is truncated
- (dict) misc information
"""
return self.env.step(action)

View File

@ -14,7 +14,7 @@ from omnigibson.macros import gm
from omnigibson.utils.python_utils import meets_minimum_version
try:
import gym
import gymnasium as gym
import torch as th
import torch.nn as nn
import tensorboard
@ -31,10 +31,10 @@ except ModuleNotFoundError:
"pip install torch\n"
"pip install stable-baselines3==1.7.0\n"
"pip install tensorboard\n"
"Also, please update gym to >=0.26.1 after installing sb3: pip install gym>=0.26.1")
"Also, please use gymnasium instead of gym: pip install gymnasium>=0.28.1")
exit(1)
assert meets_minimum_version(gym.__version__, "0.26.1"), "Please install/update gym to version >= 0.26.1"
assert meets_minimum_version(gym.__version__, "0.28.1"), "Please install/update gym to version >= 0.28.1"
# We don't need object states nor transitions rules, so we disable them now, and also enable flatcache for maximum speed
gm.ENABLE_OBJECT_STATES = False

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
from copy import deepcopy
import numpy as np
import gym
import gymnasium as gym
from collections.abc import Iterable
import omnigibson as og
from omnigibson.objects.object_base import BaseObject

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
import gym
import gymnasium as gym
from omnigibson.robots.locomotion_robot import LocomotionRobot
from omnigibson.utils.python_utils import classproperty

View File

@ -2,7 +2,7 @@ from abc import ABCMeta
from omnigibson.prims.xform_prim import XFormPrim
from omnigibson.utils.python_utils import classproperty, assert_valid_key, Registerable
from omnigibson.utils.gym_utils import GymObservable
from gym.spaces import Space
import gymnasium as gym
# Registered sensors
@ -97,7 +97,7 @@ class BaseSensor(XFormPrim, GymObservable, Registerable, metaclass=ABCMeta):
obs_space = dict()
for modality, space in self._obs_space_mapping.items():
if modality in self._modalities:
if isinstance(space, Space):
if isinstance(space, gym.Space):
# Directly add this space
obs_space[modality] = space
else:

View File

@ -1,6 +1,6 @@
import numpy as np
import time
import gym
import gymnasium as gym
import omnigibson as og
from omnigibson.sensors.sensor_base import BaseSensor

View File

@ -201,16 +201,22 @@ class BaseTask(GymObservable, Registerable, metaclass=ABCMeta):
# Get all dones and successes from individual termination conditions
dones = []
successes = []
for termination_condition in self._termination_conditions.values():
info = dict() if info is None else info
if "termination_conditions" not in info:
info["termination_conditions"] = dict()
for name, termination_condition in self._termination_conditions.items():
d, s = termination_condition.step(self, env, action)
dones.append(d)
successes.append(s)
info["termination_conditions"][name] = {
"done": d,
"success": s,
}
# Any True found corresponds to a done / success
done = sum(dones) > 0
success = sum(successes) > 0
# Populate info
info = dict() if info is None else info
info["success"] = success
return done, info

View File

@ -1,4 +1,4 @@
import gym
import gymnasium as gym
from abc import ABCMeta, abstractmethod
import numpy as np

View File

@ -21,7 +21,7 @@ setup(
zip_safe=False,
packages=find_packages(),
install_requires=[
"gym>=0.26",
"gymnasium>=0.28.1",
"numpy>=1.20.0",
"GitPython",
"transforms3d>=0.3.1",