305 lines
13 KiB
Python
305 lines
13 KiB
Python
import numpy as np
|
|
import pytest
|
|
import torch as th
|
|
|
|
import omnigibson as og
|
|
import omnigibson.utils.transform_utils as T
|
|
from omnigibson.robots import LocomotionRobot
|
|
|
|
|
|
def test_arm_control():
|
|
# Create env
|
|
cfg = {
|
|
"scene": {
|
|
"type": "Scene",
|
|
},
|
|
"objects": [],
|
|
"robots": [
|
|
{
|
|
"type": "FrankaPanda",
|
|
"obs_modalities": "rgb",
|
|
"position": [150, 150, 100],
|
|
"orientation": [0, 0, 0, 1],
|
|
"action_normalize": False,
|
|
},
|
|
{
|
|
"type": "Fetch",
|
|
"obs_modalities": "rgb",
|
|
"position": [150, 150, 105],
|
|
"orientation": [0, 0, 0, 1],
|
|
"action_normalize": False,
|
|
},
|
|
{
|
|
"type": "Tiago",
|
|
"obs_modalities": "rgb",
|
|
"position": [150, 150, 110],
|
|
"orientation": [0, 0, 0, 1],
|
|
"action_normalize": False,
|
|
},
|
|
{
|
|
"type": "A1",
|
|
"obs_modalities": "rgb",
|
|
"position": [150, 150, 115],
|
|
"orientation": [0, 0, 0, 1],
|
|
"action_normalize": False,
|
|
},
|
|
{
|
|
"type": "R1",
|
|
"obs_modalities": "rgb",
|
|
"position": [150, 150, 120],
|
|
"orientation": [0, 0, 0, 1],
|
|
"action_normalize": False,
|
|
},
|
|
],
|
|
}
|
|
env = og.Environment(configs=cfg)
|
|
|
|
# Define error functions to use
|
|
def check_zero_error(curr_position, init_position, tol=1e-2):
|
|
return th.norm(curr_position - init_position).item() < tol
|
|
|
|
def check_forward_error(curr_position, init_position, tol=1e-2, forward_tol=1e-2):
|
|
# x should be positive
|
|
return (curr_position[0] - init_position[0]).item() > forward_tol and th.norm(
|
|
curr_position[[1, 2]] - init_position[[1, 2]]
|
|
).item() < tol
|
|
|
|
def check_side_error(curr_position, init_position, tol=1e-2, side_tol=1e-2):
|
|
# y should be positive
|
|
return (curr_position[1] - init_position[1]).item() > side_tol and th.norm(
|
|
curr_position[[0, 2]] - init_position[[0, 2]]
|
|
).item() < tol
|
|
|
|
def check_up_error(curr_position, init_position, tol=1e-2, up_tol=1e-2):
|
|
# z should be positive
|
|
return (curr_position[2] - init_position[2]).item() > up_tol and th.norm(
|
|
curr_position[[0, 1]] - init_position[[0, 1]]
|
|
).item() < tol
|
|
|
|
def check_ori_error(curr_orientation, init_orientation, tol=0.1):
|
|
ori_err_normalized = th.norm(
|
|
T.quat2axisangle(T.mat2quat(T.quat2mat(init_orientation).T @ T.quat2mat(curr_orientation)))
|
|
).item() / (np.pi * 2)
|
|
ori_err = np.abs(np.pi * 2 * (np.round(ori_err_normalized) - ori_err_normalized))
|
|
return ori_err < tol
|
|
|
|
# All functions take in (target, curr, init) tuple
|
|
err_checks = {
|
|
"pose_delta_ori": {
|
|
"zero": {
|
|
"pos": lambda target, curr, init: check_zero_error(curr, init),
|
|
"ori": lambda target, curr, init: check_ori_error(curr, init),
|
|
},
|
|
"forward": {
|
|
"pos": lambda target, curr, init: check_forward_error(curr, init),
|
|
"ori": lambda target, curr, init: check_ori_error(curr, init),
|
|
},
|
|
"side": {
|
|
"pos": lambda target, curr, init: check_side_error(curr, init),
|
|
"ori": lambda target, curr, init: check_ori_error(curr, init),
|
|
},
|
|
"up": {
|
|
"pos": lambda target, curr, init: check_up_error(curr, init),
|
|
"ori": lambda target, curr, init: check_ori_error(curr, init),
|
|
},
|
|
"rotate": {
|
|
"pos": lambda target, curr, init: check_zero_error(curr, init),
|
|
"ori": None,
|
|
},
|
|
"base_move": {
|
|
"pos": lambda target, curr, init: check_zero_error(
|
|
curr, init, tol=0.02
|
|
), # Slightly bigger tolerance with base moving
|
|
"ori": lambda target, curr, init: check_ori_error(curr, init),
|
|
},
|
|
},
|
|
"absolute_pose": {
|
|
"zero": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr),
|
|
},
|
|
"forward": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr),
|
|
},
|
|
"side": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr),
|
|
},
|
|
"up": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr),
|
|
},
|
|
"rotate": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr, tol=0.05),
|
|
},
|
|
"base_move": {
|
|
"pos": lambda target, curr, init: check_zero_error(target, curr, tol=5e-3),
|
|
"ori": lambda target, curr, init: check_ori_error(target, curr),
|
|
},
|
|
},
|
|
}
|
|
|
|
n_steps = {
|
|
"pose_delta_ori": {
|
|
"zero": 10,
|
|
"forward": 10,
|
|
"side": 10,
|
|
"up": 10,
|
|
"rotate": 10,
|
|
"base_move": 30,
|
|
},
|
|
"absolute_pose": {
|
|
"zero": 50,
|
|
"forward": 50,
|
|
"side": 50,
|
|
"up": 50,
|
|
"rotate": 50,
|
|
"base_move": 50,
|
|
},
|
|
}
|
|
|
|
# Position the robots, reset them, and keep them still
|
|
for i, robot in enumerate(env.robots):
|
|
robot.set_position_orientation(
|
|
position=th.tensor([0.0, i * 5.0, 0.0]), orientation=T.euler2quat(th.tensor([0.0, 0.0, np.pi / 3]))
|
|
)
|
|
robot.reset()
|
|
robot.keep_still()
|
|
|
|
# Take 10 steps to stabilize
|
|
for _ in range(10):
|
|
og.sim.step()
|
|
|
|
# Update initial state
|
|
env.scene.update_initial_state()
|
|
|
|
# Reset the environment and keep all robots still
|
|
env.reset()
|
|
for i, robot in enumerate(env.robots):
|
|
robot.keep_still()
|
|
|
|
# Record initial eef pose of all robots
|
|
initial_eef_pose = dict()
|
|
for i, robot in enumerate(env.robots):
|
|
initial_eef_pose[robot.name] = {arm: robot.get_relative_eef_pose(arm=arm) for arm in robot.arm_names}
|
|
|
|
for controller in ["InverseKinematicsController", "OperationalSpaceController"]:
|
|
for controller_mode in ["pose_delta_ori", "absolute_pose"]:
|
|
controller_kwargs = {
|
|
"mode": controller_mode,
|
|
}
|
|
if controller_mode == "absolute_pose":
|
|
controller_kwargs["command_input_limits"] = None
|
|
controller_kwargs["command_output_limits"] = None
|
|
actions = {
|
|
"zero": dict(),
|
|
"forward": dict(),
|
|
"side": dict(),
|
|
"up": dict(),
|
|
"rotate": dict(),
|
|
"base_move": dict(),
|
|
}
|
|
|
|
for i, robot in enumerate(env.robots):
|
|
controller_config = {f"arm_{arm}": {"name": controller, **controller_kwargs} for arm in robot.arm_names}
|
|
robot.reload_controllers(controller_config)
|
|
|
|
# Define actions to use
|
|
zero_action = th.zeros(robot.action_dim)
|
|
forward_action = th.zeros(robot.action_dim)
|
|
side_action = th.zeros(robot.action_dim)
|
|
up_action = th.zeros(robot.action_dim)
|
|
rot_action = th.zeros(robot.action_dim)
|
|
|
|
# Populate actions and targets
|
|
for arm in robot.arm_names:
|
|
c_name = f"arm_{arm}"
|
|
start_idx = 0
|
|
init_eef_pos, init_eef_quat = initial_eef_pose[robot.name][arm]
|
|
for c in robot.controller_order:
|
|
if c == c_name:
|
|
break
|
|
start_idx += robot.controllers[c].command_dim
|
|
if controller_mode == "pose_delta_ori":
|
|
forward_action[start_idx] = 0.1
|
|
side_action[start_idx + 1] = 0.1
|
|
up_action[start_idx + 2] = 0.1
|
|
rot_action[start_idx + 3] = 0.1
|
|
elif controller_mode == "absolute_pose":
|
|
for act in [zero_action, forward_action, side_action, up_action, rot_action]:
|
|
act[start_idx : start_idx + 3] = init_eef_pos.clone()
|
|
act[start_idx + 3 : start_idx + 6] = T.quat2axisangle(init_eef_quat.clone())
|
|
forward_action[start_idx] += 0.1
|
|
side_action[start_idx + 1] += 0.1
|
|
up_action[start_idx + 2] += 0.1
|
|
rot_action[start_idx + 3 : start_idx + 6] = T.quat2axisangle(
|
|
T.quat_multiply(T.euler2quat(th.tensor([th.pi / 10, 0, 0])), init_eef_quat.clone())
|
|
)
|
|
|
|
else:
|
|
raise ValueError(f"Got invalid controller mode: {controller}")
|
|
actions["zero"][robot.name] = zero_action
|
|
actions["forward"][robot.name] = forward_action
|
|
actions["side"][robot.name] = side_action
|
|
actions["up"][robot.name] = up_action
|
|
actions["rotate"][robot.name] = rot_action
|
|
|
|
# Add base movement action if locomotion robot
|
|
base_move_action = zero_action.clone()
|
|
if isinstance(robot, LocomotionRobot):
|
|
c_name = "base"
|
|
start_idx = 0
|
|
for c in robot.controller_order:
|
|
if c == c_name:
|
|
break
|
|
start_idx += robot.controllers[c].command_dim
|
|
base_move_action[start_idx] = 0.1
|
|
actions["base_move"][robot.name] = base_move_action
|
|
|
|
# For each action set, reset all robots, then run actions and see if arm moves in expected way
|
|
for action_name, action in actions.items():
|
|
# Reset the environment and keep all robots still
|
|
env.reset()
|
|
for i, robot in enumerate(env.robots):
|
|
robot.keep_still()
|
|
|
|
# Take N steps with given action and check for error
|
|
for _ in range(n_steps[controller_mode][action_name]):
|
|
env.step(action)
|
|
|
|
for i, robot in enumerate(env.robots):
|
|
for arm in robot.arm_names:
|
|
# Make sure no arm joints are at their limit
|
|
normalized_qpos = robot.get_joint_positions(normalized=True)[robot.arm_control_idx[arm]]
|
|
assert not th.any(th.abs(normalized_qpos) == 1.0), (
|
|
f"controller [{controller}], mode [{controller_mode}], robot [{robot.model_name}], arm [{arm}], action [{action_name}]:\n"
|
|
f"Some joints are at their limit (normalized values): {normalized_qpos}"
|
|
)
|
|
|
|
init_pos, init_quat = initial_eef_pose[robot.name][arm]
|
|
curr_pos, curr_quat = robot.get_relative_eef_pose(arm=arm)
|
|
arm_controller = robot.controllers[f"arm_{arm}"]
|
|
arm_goal = arm_controller.goal
|
|
target_pos = arm_goal["target_pos"]
|
|
target_quat = (
|
|
arm_goal["target_quat"]
|
|
if controller == "InverseKinematicsController"
|
|
else T.mat2quat(arm_goal["target_ori_mat"])
|
|
)
|
|
pos_check = err_checks[controller_mode][action_name]["pos"]
|
|
if pos_check is not None:
|
|
is_valid_pos = pos_check(target_pos, curr_pos, init_pos)
|
|
assert is_valid_pos, (
|
|
f"Got mismatch for controller [{controller}], mode [{controller_mode}], robot [{robot.model_name}], action [{action_name}]\n"
|
|
f"target_pos: {target_pos}, curr_pos: {curr_pos}, init_pos: {init_pos}"
|
|
)
|
|
ori_check = err_checks[controller_mode][action_name]["ori"]
|
|
if ori_check is not None:
|
|
is_valid_ori = ori_check(target_quat, curr_quat, init_quat)
|
|
assert is_valid_ori, (
|
|
f"Got mismatch for controller [{controller}], mode [{controller_mode}], robot [{robot.model_name}], action [{action_name}]\n"
|
|
f"target_quat: {target_quat}, curr_quat: {curr_quat}, init_quat: {init_quat}"
|
|
)
|