From 3cbbc214319a48749ebfdf41e77565965622ab74 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 17:18:42 +0800 Subject: [PATCH 01/49] Migrate the config of tasks --- .../pour_water_agent_v3/agent_config.json | 31 ++ .../agent_config_dual.json | 31 ++ .../pour_water_agent_v3/fast_gym_config.json | 386 ++++++++++++++++++ .../rearrangement_agent_v3/agent_config.json | 31 ++ .../fast_gym_config.json | 378 +++++++++++++++++ 5 files changed, 857 insertions(+) create mode 100644 configs/gym/agent/pour_water_agent_v3/agent_config.json create mode 100644 configs/gym/agent/pour_water_agent_v3/agent_config_dual.json create mode 100644 configs/gym/agent/pour_water_agent_v3/fast_gym_config.json create mode 100644 configs/gym/agent/rearrangement_agent_v3/agent_config.json create mode 100644 configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config.json b/configs/gym/agent/pour_water_agent_v3/agent_config.json new file mode 100644 index 0000000..3172f3c --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/agent_config.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt_for_correction" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "PourWaterAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} \ No newline at end of file diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json new file mode 100644 index 0000000..53e707c --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt_for_correction" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "DualPourWaterAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} \ No newline at end of file diff --git a/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json new file mode 100644 index 0000000..b06bbb7 --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json @@ -0,0 +1,386 @@ +{ + "id": "PourWaterAgent-v3", + "max_episodes": 5, + "env": { + "events": { + "init_table_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "table"}, + "position_range": [[0.0, 0.0, -0.04], [0.0, 0.0, 0.04]], + "relative_position": true + } + }, + "init_bottle_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "bottle"}, + "position_range": [[-0.08, -0.12, 0.0], [0.08, 0.04, 0.0]], + "relative_position": true + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cup"}, + "position_range": [[-0.08, -0.04, 0.0], [0.04, 0.06, 0.0]], + "relative_position": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "bottle" + }, + "value": [[ + [0.32243, 0.03245, 0.94604, 0.025], + [0.00706, -0.99947, 0.03188, -0.0 ], + [0.94657, -0.0036 , -0.32249, 0.0 ], + [0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "cup" + }, + "value": [[ + [ 0.32039, -0.03227, 0.94673, 0.0 ], + [ 0.00675, -0.99932, -0.03635, 0.0 ], + [ 0.94726, 0.01803, -0.31996, 0.0 ], + [ 0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "left_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "left_arm_base", + "to_matrix": true + } + }, + { + "name": "right_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "right_arm_base", + "to_matrix": true + } + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "bottle" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "cup" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "cup", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "bottle", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_table_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/WoodTable" + } + }, + "random_bottle_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "bottle"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Bottle" + } + }, + "random_cup_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "cup"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Cup" + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + } + }, + "dataset": { + "instruction": { + "lang": "Pour water from the bottle into the mug." + }, + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 5 + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": false, + "enable_depth": false, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.7186357573287919, -0.054534732904795505, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "right_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "left_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_1", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [1.7, 0, 2.3], + "target": [0.6, 0, 0.8] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_2", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.8], + "target": [0.7, 0, 0.9] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_3", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.3], + "target": [0.7, 0, 0.9] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 100.0, + "init_pos": [2, 0, 2], + "radius": 20.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid":"cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, 0.1, 0.9], + "body_scale":[0.75, 0.75, 1.0], + "max_convex_hull_num": 8 + }, + { + "uid":"bottle", + "shape": { + "shape_type": "Mesh", + "fpath": "ScannedBottle/kashijia_processed.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, -0.1, 0.932], + "body_scale":[1, 1, 1], + "max_convex_hull_num": 8 + } + ] +} \ No newline at end of file diff --git a/configs/gym/agent/rearrangement_agent_v3/agent_config.json b/configs/gym/agent/rearrangement_agent_v3/agent_config.json new file mode 100644 index 0000000..5394169 --- /dev/null +++ b/configs/gym/agent/rearrangement_agent_v3/agent_config.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt_for_correction" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "RearrangementAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} diff --git a/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json new file mode 100644 index 0000000..117c4e5 --- /dev/null +++ b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json @@ -0,0 +1,378 @@ +{ + "id": "RearrangementAgent-v3", + "max_episodes": 10, + "env": { + "events": { + "init_table_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "table"}, + "position_range": [[0.0, 0.0, -0.04], [0.0, 0.0, 0.04]], + "relative_position": true + } + }, + "init_fork_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "fork"}, + "position_range": [[0.0, -0.05, 0.0], [0.1, 0.05, 0.0]], + "rotation_range": [[0, 0, -45], [0, 0, 45]], + "relative_position": true, + "relative_rotation": true + } + }, + "init_spoon_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "spoon"}, + "position_range": [[0.0, -0.05, 0.0], [0.1, 0.05, 0.0]], + "rotation_range": [[0, 0, -45], [0, 0, 45]], + "relative_position": true, + "relative_rotation": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_uids": ["fork", "spoon"], + "value": [[ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]] + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "plate" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "fork" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "spoon" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "spoon", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "fork", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_table_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/WoodTable" + } + }, + "random_plate_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "plate"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Plate" + } + }, + "random_fork_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "fork"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Spoon" + } + }, + "random_spoon_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "spoon"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Spoon" + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + } + }, + "dataset": { + "instruction": { + "lang": "Place the spoon and fork neatly into the plate on the table." + }, + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 125 + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": false, + "enable_depth": false, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.7186357573287919, -0.054534732904795505, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "right_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "left_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_1", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [1.7, 0, 2.3], + "target": [0.6, 0, 0.8] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_2", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.8], + "target": [0.7, 0, 0.9] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_3", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.3], + "target": [0.7, 0, 0.9] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 100.0, + "init_pos": [2, 0, 2], + "radius": 20.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 2.0, + "dynamic_friction": 1.5, + "restitution": 0.1 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.691], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid": "plate", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/plate/3_center.ply", + "compute_uv": true + }, + "body_scale": [0.001, 0.001, 0.001], + "init_pos": [0.5, 0.0, 1.0], + "init_rot": [180, 0, 0] + }, + { + "uid": "fork", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/fork/standard_fork_scale.ply", + "compute_uv": true + }, + "body_scale": [1.0, 1.0, 1.0], + "init_pos": [0.5, 0.21, 1.0], + "attrs" : { + "mass": 0.01, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.05, + "min_position_iters": 16, + "min_velocity_iters": 4 + } + }, + { + "uid": "spoon", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/spoon/standard_spoon_a_rescale.ply", + "compute_uv": true + }, + "body_scale": [1.0, 1.0, 1.0], + "init_pos": [0.5, -0.21, 1.0], + "attrs" : { + "mass": 0.01, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.05, + "min_position_iters": 16, + "min_velocity_iters": 4 + } + } + ] +} From f4dbddf084798ca5206b1d0803a18f0984ffe311 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 18:10:34 +0800 Subject: [PATCH 02/49] Migrate execute script --- embodichain/lab/scripts/run_agent.py | 267 +++++++++++++++++++++++++++ 1 file changed, 267 insertions(+) create mode 100644 embodichain/lab/scripts/run_agent.py diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py new file mode 100644 index 0000000..a10f27c --- /dev/null +++ b/embodichain/lab/scripts/run_agent.py @@ -0,0 +1,267 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import gymnasium +import numpy as np +import argparse +import os +import torch + +from threading import Thread +from tqdm import tqdm +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.lab.scripts.generate_video import visualize_data_dict +from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, +) +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.lab.sim.cfg import MarkerCfg + + +def generate_function( + env, + time_id: int = 0, + online_training: bool = False, + save_path: str = "", + save_video: bool = False, + debug_mode: bool = False, + regenerate: bool = True, + **kwargs, +): + """ + Generate and execute a sequence of actions in the environment. + + This function resets the environment, generates and executes action trajectories, + collects data, and optionally saves videos of the episodes. It supports both online + and offline data generation modes. + + Args: + env: The environment instance. + time_id (int, optional): Identifier for the current time step or episode. + online_training (bool, optional): Whether to use online data generation. + save_path (str, optional): Path to save generated videos. + save_video (bool, optional): Whether to save episode videos. + debug_mode (bool, optional): Enable debug mode for visualization and logging. + regenerate (bool, optional): Whether enable regenerating if existed. + **kwargs: Additional keyword arguments for data generation. + + Returns: + list or bool: Returns a list of data dicts if online_training is True, + otherwise returns True if generation is successful. + """ + + def wait_for_threads(threads): + for t in threads: + t.join() + + vis_threads = [] + + while True: # repeat until success + env.reset() + + ret = [] + trajectory_idx = 0 + + env.create_demo_action_list(regenerate=regenerate) + + # --------------------------------------------------------- + # SUCCESS CASE + # --------------------------------------------------------- + if not debug_mode and env.is_task_success().item(): + + dataset_id = f"time_{time_id}_trajectory_{trajectory_idx}" + + # online training: dataset may not be saved every iteration + if online_training: + dataset_id += "_online_generated" + num_samples = kwargs.get("num_samples", 0) + is_save_dataset = time_id < num_samples + + data_dict = env.to_dataset(id=dataset_id if is_save_dataset else None) + ret.append(data_dict) + else: + data_dict = env.to_dataset(id=dataset_id) + + # episode id + episode = getattr(env, "get_current_episode", lambda: time_id)() + + # video saving + if save_video: + video_path = os.path.join(save_path, f"episode_{episode}") + if online_training: + vis_thread = Thread( + target=visualize_data_dict, + args=(data_dict["data"], video_path), + daemon=True, + ) + vis_thread.start() + vis_threads.append(vis_thread) + else: + visualize_data_dict(data_dict["data"], video_path) + + break # success + + # --------------------------------------------------------- + # FAILURE CASE + # --------------------------------------------------------- + else: + log_warning("Task fail, Skip to next generation and retry.") + continue # retry until success + + wait_for_threads(vis_threads) + return ret if online_training else True + + +def main(args, env, gym_config): + is_online_training = os.path.exists(args.online_config) + if is_online_training: + + log_info("Start online data generation.", color="green") + assert os.path.exists(args.online_config), "{} does not exist.".format( + args.online_config + ) + + online_config = load_json(args.online_config) + online_callback = OnlineGenerator(**online_config) + + generator_func = lambda time_id, **kwargs: generate_function( + env, + time_id, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + regenerate=args.regenerate, + **kwargs, + ) + online_callback.generator(generator_func, **online_config) + else: + log_info("Start offline data generation.", color="green") + for i in range(gym_config["max_episodes"]): + generate_function( + env, + i, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + debug_mode=args.debug_mode, + regenerate=args.regenerate, + ) + + if args.headless: + env.reset(options={"final": True}) + + +if __name__ == "__main__": + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--render_backend", + help="The rendering backend to use for the simulation.", + default="egl", + type=str, + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--save_video", + help="Whether to save data as video.", + default=False, + action="store_true", + ) + parser.add_argument( + "--save_path", help="path", default="./outputs/thirdviewvideo", type=str + ) + parser.add_argument( + "--debug_mode", + help="Enable debug mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + parser.add_argument("--online_config", type=str, help="online_config", default="") + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument( + "--task_name", type=str, help="Name of the task.", required=True + ) + + # Agent related configs + parser.add_argument( + "--agent_config", type=str, help="agent_config", default=None, required=True + ) + parser.add_argument( + "--regenerate", + type=bool, + help="Whether regenerate code if already existed.", + default=False, + ) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + agent_config = load_json(args.agent_config) + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make( + id=gym_config["id"], + cfg=cfg, + agent_config=agent_config, + task_name=args.task_name, + ) + main(args, env, gym_config) From 82869982ddb262a99f5da7a61caf3f1ae423bf53 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 18:15:22 +0800 Subject: [PATCH 03/49] Migrate base_agent_env --- .../envs/tasks/tableware/base_agent_env.py | 395 ++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py new file mode 100644 index 0000000..a0e106c --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -0,0 +1,395 @@ +import torch +from embodichain.utils import logger +import traceback +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import os +from embodichain.toolkits.interfaces import extract_drive_calls, draw_axis +from embodichain.agents.hierarchy.code_agent import format_execution_history +from embodichain.agents.hierarchy.validation_agent import ( + save_obs_image, + get_obj_position_info, +) + + +class BaseAgentEnv: + + def _init_agents(self, agent_config, task_name): + from embodichain.agents.hierarchy.task_agent import TaskAgent + from embodichain.agents.hierarchy.code_agent import CodeAgent + from embodichain.agents.hierarchy.validation_agent import ValidationAgent + from embodichain.agents.hierarchy.llm import ( + create_llm, + task_llm, + code_llm, + validation_llm, + ) + + if agent_config.get("TaskAgent") is not None: + self.task_agent = TaskAgent( + task_llm, + **agent_config["Agent"], + **agent_config["TaskAgent"], + task_name=task_name, + ) + self.code_agent = CodeAgent( + code_llm, + **agent_config["Agent"], + **agent_config.get("CodeAgent"), + task_name=task_name, + ) + self.validation_agent = ValidationAgent( + validation_llm, + task_name=task_name, + task_description=self.code_agent.prompt_kwargs.get("task_prompt")[ + "content" + ], + basic_background=self.code_agent.prompt_kwargs.get("basic_background")[ + "content" + ], + atom_actions=self.code_agent.prompt_kwargs.get("atom_actions")["content"], + ) + + def get_states(self): + # TODO: only support num_env = 1 for now + # store robot states in each env.reset + self.init_qpos = self.robot.get_qpos().squeeze(0) + + self.left_arm_joints = self.robot.get_joint_ids(name="left_arm") + self.right_arm_joints = self.robot.get_joint_ids(name="right_arm") + self.left_eef_joints = self.robot.get_joint_ids(name="left_eef") + self.right_eef_joints = self.robot.get_joint_ids(name="right_eef") + + self.left_arm_init_qpos = self.init_qpos[self.left_arm_joints] + self.right_arm_init_qpos = self.init_qpos[self.right_arm_joints] + + self.left_arm_init_xpos = self.robot.compute_fk( + self.left_arm_init_qpos, name="left_arm", to_matrix=True + ).squeeze(0) + self.right_arm_init_xpos = self.robot.compute_fk( + self.right_arm_init_qpos, name="right_arm", to_matrix=True + ).squeeze(0) + + self.left_arm_current_qpos = self.left_arm_init_qpos + self.right_arm_current_qpos = self.right_arm_init_qpos + + self.left_arm_current_xpos = self.left_arm_init_xpos + self.right_arm_current_xpos = self.right_arm_init_xpos + + self.left_arm_base_pose = self.robot.get_control_part_base_pose( + "left_arm", to_matrix=True + ).squeeze(0) + self.right_arm_base_pose = self.robot.get_control_part_base_pose( + "right_arm", to_matrix=True + ).squeeze(0) + + self.open_state = torch.tensor([0.05]) + self.close_state = torch.tensor([0.0]) + self.left_arm_current_gripper_state = self.open_state + self.right_arm_current_gripper_state = self.open_state + + # store some useful obj information + init_obj_info = {} + obj_uids = self.sim.get_rigid_object_uid_list() + for obj_name in obj_uids: + obj = self.sim.get_rigid_object(obj_name) + obj_pose = obj.get_local_pose(to_matrix=True).squeeze(0) + obj_height = obj_pose[2, 3] # Extract the height (z-coordinate) + obj_grasp_pose = self.affordance_datas.get( + f"{obj_name}_grasp_pose_object", None + ) + init_obj_info[obj_name] = { + "pose": obj_pose, # Store the full pose (4x4 matrix) + "height": obj_height, # Store the height (z-coordinate) + "grasp_pose_obj": ( + obj_grasp_pose.squeeze(0) if obj_grasp_pose is not None else None + ), # Store the grasp pose if available + } + self.init_obj_info = init_obj_info + + # -------------------- Common getters / setters -------------------- + + def get_obs_for_agent(self): + obs = self.get_obs(get_valid_sensor_data=True) + rgb = obs["sensor"]["cam_high"]["color"].squeeze(0) + valid_rgb_1 = obs["sensor"]["valid_cam_1"]["color"].squeeze(0) + valid_rgb_2 = obs["sensor"]["valid_cam_2"]["color"].squeeze(0) + valid_rgb_3 = obs["sensor"]["valid_cam_3"]["color"].squeeze(0) + + # obs_image_path = save_obs_image(obs_image=self.get_obs_for_agent()["rgb_1"], save_dir='./', step_id=0) + + return { + "rgb": rgb, + "valid_rgb_1": valid_rgb_1, + "valid_rgb_2": valid_rgb_2, + "valid_rgb_3": valid_rgb_3, + } + + # depth = obs["sensor"]["cam_high"]["depth"].squeeze(0) + # mask = obs["sensor"]["cam_high"]["mask"].squeeze(0) + # semantic_mask = obs["sensor"]["cam_high"]["semantic_mask_l"].squeeze(0) + # return {"rgb": rgb, "depth": depth, "mask": mask, "semantic_mask": semantic_mask} + + def get_current_qpos_agent(self): + return self.left_arm_current_qpos, self.right_arm_current_qpos + + def set_current_qpos_agent(self, arm_qpos, is_left): + if is_left: + self.left_arm_current_qpos = arm_qpos + else: + self.right_arm_current_qpos = arm_qpos + + def get_current_xpos_agent(self): + return self.left_arm_current_xpos, self.right_arm_current_xpos + + def set_current_xpos_agent(self, arm_xpos, is_left): + if is_left: + self.left_arm_current_xpos = arm_xpos + else: + self.right_arm_current_xpos = arm_xpos + + def get_current_gripper_state_agent(self): + return self.left_arm_current_gripper_state, self.right_arm_current_gripper_state + + def set_current_gripper_state_agent(self, arm_gripper_state, is_left): + if is_left: + self.left_arm_current_gripper_state = arm_gripper_state + else: + self.right_arm_current_gripper_state = arm_gripper_state + + # -------------------- IK / FK -------------------- + def get_arm_ik(self, target_xpos, is_left, qpos_seed=None): + control_part = "left_arm" if is_left else "right_arm" + ret, qpos = self.robot.compute_ik( + name=control_part, pose=target_xpos, joint_seed=qpos_seed + ) + return ret.all().item(), qpos.squeeze(0) + + def get_arm_fk(self, qpos, is_left): + control_part = "left_arm" if is_left else "right_arm" + xpos = self.robot.compute_fk( + name=control_part, qpos=torch.as_tensor(qpos), to_matrix=True + ) + return xpos.squeeze(0) + + # -------------------- get only code for action list -------------------- + def generate_code_for_actions(self, regenerate=False, **kwargs): + logger.log_info( + f"Generate code for creating action list for {self.code_agent.task_name}.", + color="green", + ) + + # # Task planning (not used currently) + # print(f"\033[92m\nStart task planning.\n\033[0m") + # task_agent_input = self.task_agent.get_composed_observations(env=self) + # query = self.task_agent.generate(**task_agent_input, regenerate=regenerate, **kwargs) + + # Code generation + print(f"\033[94m\nStart code generation.\n\033[0m") + code_agent_input = self.code_agent.get_composed_observations( + env=self, regenerate=regenerate, **kwargs + ) + code_file_path, kwargs, code = self.code_agent.generate(**code_agent_input) + return code_file_path, kwargs, code + + # -------------------- get action list -------------------- + def create_demo_action_list(self, regenerate=False): + code_file_path, kwargs, _ = self.generate_code_for_actions( + regenerate=regenerate + ) + action_list = self.code_agent.act(code_file_path, **kwargs) + return action_list + + def create_demo_action_list_with_self_correction(self, **kwargs): + logger.log_info( + f"Generate code for creating action list for {self.code_agent.task_name} with self correction.", + color="green", + ) + + # Create log file name with timestamp + import datetime + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + log_dir = ( + Path(database_agent_prompt_dir) + / self.code_agent.task_name + / "self_correction_logs" + / timestamp + ) + os.makedirs(log_dir, exist_ok=True) + img_dir = log_dir / "observation_images" + + kwargs.setdefault("env", self) + kwargs.setdefault("log_dir", log_dir) + kwargs.setdefault("file_path", log_dir / "agent_generated_code.py") + kwargs.setdefault("md_path", log_dir / "agent_llm_responses.md") + kwargs.setdefault("last_task_plan", "None.") + kwargs.setdefault("last_executed_failure", "None.") + kwargs.setdefault("last_executed_history", "None.") + + # TODO: rethink which part should be divided to task / code agents. Important! + # TODO: use the task agent to select which needs the validation (mainly interaction with the objects), not all steps. + # TODO: add logs + # TODO: maybe use a sequence of images for task planning + + step_id = 0 + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_1"], + save_dir=img_dir / "cam_1", + step_id=step_id, + ) + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_2"], + save_dir=img_dir / "cam_2", + step_id=step_id, + ) + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_3"], + save_dir=img_dir / "cam_3", + step_id=step_id, + ) + + task_agent_input = self.task_agent.get_composed_observations(**kwargs) + code_agent_input = self.code_agent.get_composed_observations(**kwargs) + while True: + exec_code = [] + print(f"\033[94m\nStart task planning.\n\033[0m") + task_plan, plan_list, validation_list = ( + self.task_agent.generate_for_correction( + img_dir=img_dir / "cam_1", **task_agent_input + ) + ) + + # TODO: maybe here I need to insert an error-occurred agent, calling some error-occurred apis, maybe with correction action too. + # TODO:maybe the validation agent can provide correction action, and no need to generate the subsequent full task by the task agent. + + print(f"\033[92m\nStart code generation.\n\033[0m") + code_agent_input, code = self.code_agent.generate_according_to_task_plan( + task_plan=task_plan, **code_agent_input + ) + drive_list = extract_drive_calls(code) + for action_id, single_action in enumerate(drive_list): + try: + # ---------- execute ---------- + self.code_agent.act_single_action(single_action, **code_agent_input) + exec_success = True + exec_trace = None + + # # # # TODO: manually adjust the bottle pose for testing + # if step_id == 2: + # + # # pose = torch.tensor( + # # [[[0.99989, -0.00457, -0.01415, 0.72850], + # # [0.00457, 0.99999, -0.00041, -0.20441], + # # [0.01415, 0.00034, 0.99990, 0.92571], + # # [0.00000, 0.00000, 0.00000, 1.00000]]], + # # dtype=torch.float32 + # # ) + # # self.sim.get_rigid_object('bottle').set_local_pose(pose) + # + # pose = torch.tensor( + # [[[0.99989, -0.00457, -0.01415, 0.722850], + # [0.00457, 0.99999, -0.00041, 0.20441], + # [0.01415, 0.00034, 0.99990, 0.92571], + # [0.00000, 0.00000, 0.00000, 1.00000]]], + # dtype=torch.float32 + # ) + # self.sim.get_rigid_object('cup').set_local_pose(pose) + # + # # pose = self.sim.get_rigid_object('spoon').get_local_pose(to_matrix=True).squeeze(0) + # # pose[0, 3] = 0.6 + # # pose[1, 3] = -0.35 + # # pose[2, 3] = 0.8 + # # self.sim.get_rigid_object('spoon').set_local_pose(pose.unsqueeze(0)) + # + # for i in range(5): + # _ = self.step(action=self.robot.get_qpos()) + + except Exception: + exec_success = False + exec_trace = traceback.format_exc() + print(f"Execution failed:\n{exec_trace}") + + # ---------- step transition ---------- + step_id += 1 + + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_1"], + save_dir=img_dir / "cam_1", + step_id=step_id, + ) + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_2"], + save_dir=img_dir / "cam_2", + step_id=step_id, + ) + save_obs_image( + obs_image=self.get_obs_for_agent()["valid_rgb_3"], + save_dir=img_dir / "cam_3", + step_id=step_id, + ) + + # ---------- post-execution handling ---------- + if exec_success: + if code_agent_input.get("validation_agent"): + print( + f"\033[33mStarting validation with condition '{validation_list[action_id]}'!\033[0m" + ) + validation_info = self.validation_agent.validate_single_action( + single_action, + plan_list[action_id], + validation_list[action_id], + img_dir, + get_obj_position_info(self), + ) + + if "SUCCESS" in validation_info: + print(f"\033[33mValid info:\n{validation_info}\033[0m") + is_success = True + exec_code.append(plan_list[action_id]) + continue + else: + print(f"\033[31mValid info:\n{validation_info}\033[0m") + info = ( + "Validation Result: FAILED\n\n" + "Failed Step (currently executing step):\n" + f"{plan_list[action_id]}\n\n" + "Failure Analysis (why this step failed):\n" + f"{validation_info}" + ) + history = ( + "Executed History (previous steps):\n" + f"{format_execution_history(exec_code)}\n\n" + ) + is_success = False + else: + is_success = True + exec_code.append(plan_list[action_id]) + continue + else: + info = ( + "Action Execution: FAILED\n\n" + "Failed Step (currently executing step):\n" + f"{plan_list[action_id]}\n\n" + "Execution Error Trace:\n" + f"{exec_trace}\n\n" + "Note: You may try `force_valid=True` for the current action to find the nearest valid pose." + ) + history = ( + "Executed History (previous steps):\n" + f"{format_execution_history(exec_code)}\n\n" + ) + + is_success = False + + task_agent_input["last_task_plan"] = task_plan + task_agent_input["last_executed_failure"] = info + task_agent_input["last_executed_history"] = history + break + + if single_action == drive_list[-1] and is_success: + # ---------- termination ---------- + print( + "\033[91mExecuted all the plans. The task is considered complete.\033[0m" + ) + break From 9af2785f03194512c82c7efe40d62e0dfa5b73cb Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 18:19:02 +0800 Subject: [PATCH 04/49] Migrate prompt template --- embodichain/agents/mllm/prompt/__init__.py | 8 + embodichain/agents/mllm/prompt/code_prompt.py | 138 +++++++++++ embodichain/agents/mllm/prompt/task_prompt.py | 216 ++++++++++++++++++ 3 files changed, 362 insertions(+) create mode 100644 embodichain/agents/mllm/prompt/__init__.py create mode 100644 embodichain/agents/mllm/prompt/code_prompt.py create mode 100644 embodichain/agents/mllm/prompt/task_prompt.py diff --git a/embodichain/agents/mllm/prompt/__init__.py b/embodichain/agents/mllm/prompt/__init__.py new file mode 100644 index 0000000..55bc408 --- /dev/null +++ b/embodichain/agents/mllm/prompt/__init__.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .task_prompt import TaskPrompt +from .code_prompt import CodePrompt diff --git a/embodichain/agents/mllm/prompt/code_prompt.py b/embodichain/agents/mllm/prompt/code_prompt.py new file mode 100644 index 0000000..efbaae7 --- /dev/null +++ b/embodichain/agents/mllm/prompt/code_prompt.py @@ -0,0 +1,138 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from embodichain.utils.utility import encode_image, encode_image_from_path + + +class CodePrompt: + @staticmethod + def one_stage_prompt(**kwargs) -> ChatPromptTemplate: + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are an AI assistant that can generate python code to execute robot arms." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "text", + "text": ( + "Generate a Python code snippet that accomplishes the following task:\n" + "{query}\n\n" + "You must strictly follow the rules and available functions described below:\n" + "{code_prompt}\n\n" + "Here are some reference examples of the expected output code:\n" + "{code_example}\n\n" + ), + } + ] + ), + ] + ) + return prompt.invoke(kwargs) + + @staticmethod + def unified_prompt(observations, **kwargs): + """ + Unified Vision→Code prompt: + - Model observes the image + - Understands the scene and the task goal + - Generates final executable Python code using atomic robot APIs + """ + + # Encode the image + observation = observations["rgb"] + kwargs.update({"observation": encode_image(observation)}) + + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a reliable Vision-Language-Code robot assistant. " + "You observe an image, understand the scene and the task goal, " + "and generate correct Python code using ONLY the allowed atomic robot actions. " + "Your final output must be a single Python code block." + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "### Task Goal\n" + "{task_prompt}\n\n" + "### Environment Background\n" + "{basic_background}\n\n" + "### Allowed Atomic Actions\n" + "{atom_actions}\n\n" + "### Code Rules\n" + "{code_prompt}\n\n" + "### Reference Code Examples\n" + "{code_example}\n\n" + "### Final Instructions\n" + "Understand the scene from the image and generate final executable Python code " + "that performs the task using ONLY the allowed atomic actions.\n\n" + "Your entire response must be EXACTLY one Python code block:\n" + "```python\n" + "# your solution code here\n" + "```\n" + ), + }, + ] + ), + ] + ) + + return prompt.invoke(kwargs) + + @staticmethod + def one_stage_prompt_according_to_task_plan(**kwargs) -> ChatPromptTemplate: + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a reliable robot control code generator.\n" + "Your task is to generate Python code that executes robot arm actions.\n\n" + "CRITICAL RULES:\n" + "- The TASK PLAN defines the available atomic actions, rules, and execution logic.\n" + "- You MUST strictly follow the TASK PLAN.\n" + "- The CONSTRAINTS section contains additional global constraints you must obey.\n" + "- Do NOT invent new actions, functions, parameters, or control flow.\n" + "- You MAY include Python comments (# ...) inside the code.\n" + "- Your ENTIRE response MUST be a single Python code block.\n" + "- The code block MUST be directly executable without modification.\n" + "- Do NOT include any text, explanation, or markdown outside the Python code block.\n" + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "text", + "text": ( + "TASK PLAN (atomic actions, rules, and intended behavior):\n" + "{task_plan}\n\n" + "GLOBAL CONSTRAINTS (must be satisfied):\n" + "{code_prompt}\n\n" + "REFERENCE CODE (style and structure only; do NOT copy logic):\n" + "{code_example}\n\n" + "Generate the corrected Python code now." + ), + } + ] + ), + ] + ) + return prompt.invoke(kwargs) diff --git a/embodichain/agents/mllm/prompt/task_prompt.py b/embodichain/agents/mllm/prompt/task_prompt.py new file mode 100644 index 0000000..3ba547f --- /dev/null +++ b/embodichain/agents/mllm/prompt/task_prompt.py @@ -0,0 +1,216 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from embodichain.utils.utility import encode_image +from embodichain.utils.utility import encode_image, encode_image_from_path + + +class TaskPrompt: + @staticmethod + def one_stage_prompt(observations, **kwargs): + """ + Hybrid one-pass prompt: + Step 1: VLM analyzes the image and extracts object IDs. + Step 2: LLM generates task instructions using only those IDs. + """ + # Encode image + observation = observations["rgb"] + kwargs.update({"observation": encode_image(observation)}) + + # Build hybrid prompt + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a precise and reliable robotic manipulation planner. " + "Given a camera observation and a task description, you must generate " + "a clear, step-by-step task plan for a robotic arm. " + "All actions must strictly use the provided atomic API functions, " + "and the plan must be executable without ambiguity." + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "Here is the latest camera observation.\n" + "First, analyze the scene in the image.\n" + "Then, using the context below, produce an actionable task plan.\n\n" + "**Environment background:** \n{basic_background}\n\n" + '**Task goal:** \n"{task_prompt}"\n\n' + "**Available atomic actions:** \n{atom_actions}\n" + ), + }, + ] + ), + ] + ) + + # Return the prompt template and kwargs to be executed by the caller + return prompt.invoke(kwargs) + + @staticmethod + def two_stage_prompt(observations, **kwargs): + # for VLM generate image descriptions + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are a helpful assistant to operate a robotic arm with a camera to generate task plans according to descriptions." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/jpg;base64,{observation}", + }, + }, + { + "type": "text", + "text": "What is in the image? Return answer with their potential effects.", + }, + ] + ), + ] + ) + + observation = observations["rgb"] + kwargs.update({"observation": encode_image(observation)}) + # for LLM generate task descriptions + prompt_query = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are a helpful assistant to operate a robotic arm with a camera to generate task plans according to descriptions." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/jpg;base64,{observation}", + }, + }, + { + "type": "text", + "text": "Here is analysis for this image: {query}.", + }, + { + "type": "text", + "text": ( + "Using the context below, produce an actionable task plan.\n\n" + "**Environment background:** \n{basic_background}\n\n" + '**Task goal:** \n"{task_prompt}"\n\n' + "**Available atomic actions:** \n{atom_actions}\n" + ), + }, + ] + ), + ] + ) + + return [prompt.invoke(kwargs), {"prompt": prompt_query, "kwargs": kwargs}] + + @staticmethod + def one_stage_prompt_for_correction(obs_image_path, **kwargs): + """ + Hybrid one-pass prompt: + Step 1: VLM analyzes the image and extracts object IDs. + Step 2: LLM generates task instructions using only those IDs. + """ + # Encode image + kwargs.update({"observation": encode_image_from_path(obs_image_path)}) + + # Build hybrid prompt + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a robotic manipulation planner operating STRICTLY in the robot base coordinate frame.\n\n" + "COORDINATE FRAME RULE (NON-NEGOTIABLE):\n" + "- ALL spatial reasoning and motion descriptions (left/right/front/back/up/down, offsets, rotations)\n" + " are defined ONLY in the robot base coordinate frame, oriented from the base looking outward along +x (toward the end-effector).\n" + "- The camera is positioned in front of the robot, facing the arm and looking toward the robot base.\n" + "- Due to this viewpoint, the rendered image is HORIZONTALLY MIRRORED relative to the robot base frame.\n" + "- LEFT–RIGHT in the image MUST be inverted when reasoning:\n" + " * Image left → Robot right\n" + " * Image right → Robot left\n" + "- Vertical orientation is preserved:\n" + " * Image up → Robot up\n" + " * Image down → Robot down\n" + "- Always reason as if you are physically located at the robot base, facing along +x.\n" + "- For your output, you must use the robot base frame and explicitly account for this horizontal mirroring when interpreting the image " + "(e.g., What appears as “left” in the image corresponds to “right” in the robot base frame, and vice versa. " + "Vertical orientation is preserved: what appears as “up” in the image is also “up” in the robot base frame.).\n\n" + "HARD CONSTRAINT:\n" + "- Any reasoning based on image left/right, visual perspective, or camera orientation is VALID.\n" + "- If a direction cannot be inferred from the robot base frame, you must state it explicitly." + "- Each arm may execute at most one atomic action per step. If multiple atomic actions are required, " + "they must be distributed across multiple steps.\n" + "- Both arms may operate in the same step, but each arm may execute at most ONE atomic action per step. " + "If only one arm needs to act (e.g., a single-arm step or recovery), the other arm should remain idle.\n\n" + "TASK:\n" + "- Given the observation and task, produce a step-by-step plan using ONLY the provided atomic API.\n" + "- The plan must be executable without ambiguity.\n\n" + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "Here is the latest camera observation.\n" + "IMPORTANT: The current image may NOT represent the initial state of the task. " + "It may correspond to an intermediate step where some actions have already been executed.\n\n" + "First, analyze the scene in the image to infer the current state.\n" + "Then, using the context below, produce the remaining actionable task plan from this state onward.\n\n" + "**Environment background:** \n" + "{basic_background}\n\n" + '**Task goal:** \n"' + '{task_prompt}"\n\n' + "**Available atomic actions:** \n" + "{atom_actions}\n" + "**Failed Task Plan (Reference)::**\n" + "{last_task_plan}\n\n" + "**Executed history (reference only):**\n" + "{last_executed_history}\n\n" + "**Most recent failure (CRITICAL):**\n" + "{last_executed_failure}\n\n" + "**REQUIRED OUTPUT**\n" + "[PLANS]:\n" + "Step 1: (...)\n" + "..." + "Step N: (...)\n\n" + "[VALIDATION_CONDITIONS]:\n" + "Step 1: \n" + "..." + "Step N: \n\n" + "VALIDATION_CONDITIONS MUST include the robot arm and relevant object(s), and whether the object(s) should be held or not.\n" + "Produce the COMPLETE remaining task plan." + ), + }, + ] + ), + ] + ) + + return prompt.invoke(kwargs) From 60c168cd084ddbc6842af723d2eff4cede484c92 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 18:23:52 +0800 Subject: [PATCH 05/49] Migrate the core code of agent --- embodichain/agents/hierarchy/__init__.py | 3 + embodichain/agents/hierarchy/agent_base.py | 41 +++ embodichain/agents/hierarchy/code_agent.py | 325 +++++++++++++++++ embodichain/agents/hierarchy/llm.py | 36 ++ embodichain/agents/hierarchy/task_agent.py | 178 ++++++++++ .../agents/hierarchy/validation_agent.py | 330 ++++++++++++++++++ 6 files changed, 913 insertions(+) create mode 100644 embodichain/agents/hierarchy/__init__.py create mode 100644 embodichain/agents/hierarchy/agent_base.py create mode 100644 embodichain/agents/hierarchy/code_agent.py create mode 100644 embodichain/agents/hierarchy/llm.py create mode 100644 embodichain/agents/hierarchy/task_agent.py create mode 100644 embodichain/agents/hierarchy/validation_agent.py diff --git a/embodichain/agents/hierarchy/__init__.py b/embodichain/agents/hierarchy/__init__.py new file mode 100644 index 0000000..d56fc16 --- /dev/null +++ b/embodichain/agents/hierarchy/__init__.py @@ -0,0 +1,3 @@ +from langchain_openai import AzureChatOpenAI +from langchain_openai import ChatOpenAI +import os diff --git a/embodichain/agents/hierarchy/agent_base.py b/embodichain/agents/hierarchy/agent_base.py new file mode 100644 index 0000000..75b4dae --- /dev/null +++ b/embodichain/agents/hierarchy/agent_base.py @@ -0,0 +1,41 @@ +from abc import ABCMeta, abstractmethod +import os +import cv2 +from embodichain.utils.utility import load_json, load_txt +from embodichain.agents.mllm.prompt import * +from embodichain.data import database_agent_prompt_dir, database_2d_dir +from embodichain.utils.utility import encode_image + + +class AgentBase(metaclass=ABCMeta): + def __init__(self, **kwargs) -> None: + + assert ( + "prompt_kwargs" in kwargs.keys() + ), "Key prompt_kwargs must exist in config." + + for key, value in kwargs.items(): + setattr(self, key, value) + + # Preload and store prompt contents inside self.prompt_kwargs + for key, val in self.prompt_kwargs.items(): + if val["type"] == "text": + file_path = os.path.join(database_agent_prompt_dir, val["name"]) + val["content"] = load_txt(file_path) # ← store content here + else: + raise ValueError( + f"Now only support `text` type but {val['type']} is given." + ) + + def generate(self, *args, **kwargs): + pass + + def act(self, *args, **kwargs): + pass + + def get_composed_observations(self, **kwargs): + ret = {"observations": kwargs.get("env").get_obs_for_agent()} + for key, val in self.prompt_kwargs.items(): + ret[key] = val["content"] + ret.update(kwargs) + return ret diff --git a/embodichain/agents/hierarchy/code_agent.py b/embodichain/agents/hierarchy/code_agent.py new file mode 100644 index 0000000..a3eac99 --- /dev/null +++ b/embodichain/agents/hierarchy/code_agent.py @@ -0,0 +1,325 @@ +from embodichain.agents.hierarchy.agent_base import AgentBase +from langchain_core.prompts import ChatPromptTemplate +import os +import numpy as np +import functools +from typing import Dict, Tuple, Any +from embodichain.toolkits.code_generation import ( + ExecutableOutputParser, + OutputFormatting, +) +from embodichain.toolkits.toolkits import ToolkitsBase +from embodichain.agents.mllm.prompt import CodePrompt +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import re +import importlib.util +from langchain_core.messages import HumanMessage +from datetime import datetime +from embodichain.utils.utility import encode_image +import base64 + + +def format_execution_history(execution_history): + if not execution_history or len(execution_history) == 0: + return "None." + + return "\n\n".join(f"{i}. {entry}" for i, entry in enumerate(execution_history, 1)) + + +def extract_python_code_and_text(llm_response: str) -> Tuple[str, str]: + """ + Extract exactly ONE python code block from the LLM response, + and return: + - code: the content inside the python block + - text: all remaining explanation text (outside the code block) + + Raises ValueError if zero or multiple python blocks are found. + """ + + pattern = r"```python\s*(.*?)\s*```" + matches = list(re.finditer(pattern, llm_response, re.DOTALL)) + + if len(matches) == 0: + raise ValueError("No python code block found in LLM response.") + if len(matches) > 1: + raise ValueError("Multiple python code blocks found in LLM response.") + + match = matches[0] + code = match.group(1).strip() + + # Optional sanity check + if not code.startswith("#") and not code.startswith("drive("): + raise ValueError( + f"Invalid code block content. Expected `drive(...)` or `# TASK_COMPLETE`, got:\n{code}" + ) + + # Extract remaining text (before + after the code block) + text_before = llm_response[: match.start()].strip() + text_after = llm_response[match.end() :].strip() + + explanation_text = "\n\n".join(part for part in [text_before, text_after] if part) + + return code, explanation_text + + +def format_llm_response_md( + llm_analysis: str, # plain-text explanation (NO code) + extracted_code: str, # validated executable code + step_id: int = None, + execution_history: str = None, + obs_image_path: Path = None, + md_file_path: Path = None, +) -> str: + ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + header = f"## Step: {step_id if step_id is not None else '-'} | {ts}\n\n" + + history_block = "" + if execution_history: + history_block = ( + "### Execution History (Input to LLM)\n\n" + "```\n" + f"{execution_history}\n" + "```\n\n" + ) + + image_block = "" + if obs_image_path is not None and md_file_path is not None: + try: + rel_path = obs_image_path.relative_to(md_file_path.parent) + except ValueError: + # Fallback: just use filename + rel_path = obs_image_path.name + + image_block = ( + "### Observation Image\n\n" f"![]({Path(rel_path).as_posix()})\n\n" + ) + + body = ( + image_block + history_block + "### LLM Analysis\n\n" + f"{llm_analysis.strip()}\n\n" + "### Executed Code\n\n" + "```python\n" + f"{extracted_code.strip()}\n" + "```\n\n" + "---\n\n" + ) + + return header + body + + +class CodeAgent(AgentBase): + query_prefix = "# " + query_suffix = "." + prompt: ChatPromptTemplate + prompt_kwargs: Dict[str, Dict] + + def __init__(self, llm, **kwargs) -> None: + super().__init__(**kwargs) + self.llm = llm + + def generate(self, **kwargs): + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = log_dir / "agent_generated_code.py" + + # Check if the file already exists + if not kwargs.get("regenerate", False): + if file_path.exists(): + print(f"Code file already exists at {file_path}, skipping writing.") + return file_path, kwargs, None + + # Generate code via LLM + prompt = getattr(CodePrompt, self.prompt_name)( + **kwargs, + ) + + # insert feedback if exists + if len(kwargs.get("error_messages", [])) != 0: + # just use the last one + last_code = kwargs["generated_codes"][-1] + last_error = kwargs["error_messages"][-1] + last_observation = ( + kwargs.get("observation_feedbacks")[-1] + if kwargs.get("observation_feedbacks") + else None + ) + + # Add extra human message with feedback + feedback_msg = self.build_feedback_message( + last_code, last_error, last_observation + ) + prompt.messages.append(feedback_msg) + + llm_code = self.llm.invoke(prompt) + + # Normalize content + llm_code = getattr(llm_code, "content", str(llm_code)) + + print(f"\033[92m\nCode agent output:\n{llm_code}\n\033[0m") + + # Write the code to the file if it does not exist + match = re.search(r"```python\n(.*?)\n```", llm_code, re.DOTALL) + if match: + code_to_save = match.group(1).strip() + else: + code_to_save = llm_code.strip() + + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, "w") as f: + f.write(code_to_save) + print(f"Generated function code saved to {file_path}") + + return file_path, kwargs, code_to_save + + def act(self, code_file_path, **kwargs): + # Dynamically import the generated function from the .py file + spec = importlib.util.spec_from_file_location( + "generated_function", code_file_path + ) + generated_function_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(generated_function_module) + + # Ensure that the function exists and call it with kwargs + if hasattr(generated_function_module, "create_agent_action_list"): + result = generated_function_module.create_agent_action_list( + **kwargs + ) # Call the function with kwargs + print("Function executed successfully.") + return result + else: + raise AttributeError( + "The function 'create_agent_action_list' was not found in the generated code." + ) + + def build_feedback_message( + self, last_code: str, last_error: str, last_observation: str = None + ) -> HumanMessage: + + useful_info = ( + "The error may be caused by:\n" + "1. You did not follow the basic background information, especially the world coordinate system with its xyz directions.\n" + "2. You did not take into account the NOTE given in the atomic actions or in the example functions.\n" + "3. You did not follow the steps of the task descriptions.\n" + ) + + # Optional observation section + observation_text = "" + if last_observation is not None: + observation_text = ( + "\nThe visual observation feedback of the execution process was:\n" + "```\n" + str(last_observation) + "\n```\n" + ) + + return HumanMessage( + content=( + "Your previously generated code was:\n" + "```\n" + last_code + "\n```\n\n" + "When this code was executed in the test environment, it failed with the following error:\n" + "```\n" + + last_error + + "```\n" + + observation_text + + "\n" + + useful_info + + "\nAnalyze the cause of the failure and produce a corrected version of the code. " + "Modify only what is necessary to fix the issue. The corrected code must:\n" + " - strictly use only the allowed atomic API functions,\n" + " - be executable and unambiguous,\n" + " - directly resolve the error shown above.\n\n" + "Your entire response must be EXACTLY one Python code block:\n" + "```python\n" + "# corrected solution code\n" + "```\n" + ) + ) + + def generate_according_to_task_plan(self, task_plan, **kwargs): + # Generate code via LLM + prompt = getattr(CodePrompt, self.prompt_name)(task_plan=task_plan, **kwargs) + + llm_code = self.llm.invoke(prompt) + llm_code = getattr(llm_code, "content", str(llm_code)) + + match = re.search(r"```python\n(.*?)\n```", llm_code, re.DOTALL) + if match: + llm_code = match.group(1).strip() + else: + llm_code = llm_code.strip() + + print(f"\033[92m\nCode agent output:\n{llm_code}\n\033[0m") + + return kwargs, llm_code + + def act_single_action(self, code: str, **kwargs): + import ast + + # ---- 0. Build execution namespace ---- + ns = { + "__builtins__": __builtins__, + "kwargs": kwargs, # visible for **kwargs injection + } + + # ---- 1. Executor-controlled import ---- + try: + exec( + "from embodichain.toolkits.interfaces import *", + ns, + ns, + ) + except Exception as e: + raise RuntimeError( + "Failed to import embodichain.toolkits.interfaces in act_single_action" + ) from e + + # ---- 2. Parse generated code ---- + tree = ast.parse(code) + body = tree.body + + # ---------- AST transformer: inject **kwargs everywhere ---------- + class InjectKwargs(ast.NodeTransformer): + def visit_Call(self, node): + self.generic_visit(node) + + # Check if **kwargs already exists + has_kwargs = any( + kw.arg is None + and isinstance(kw.value, ast.Name) + and kw.value.id == "kwargs" + for kw in node.keywords + ) + + if not has_kwargs: + node.keywords.append( + ast.keyword( + arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()) + ) + ) + + return node + + transformer = InjectKwargs() + + # ---- 3. Execute actions step by step ---- + for step_id, node in enumerate(body, start=1): + try: + node = transformer.visit(node) + ast.fix_missing_locations(node) + + step_mod = ast.Module([node], type_ignores=[]) + compiled = compile( + step_mod, filename=f"", mode="exec" + ) + + print( + f"\033[95m\nExecuting the current action {code} with **kwargs\033[0m" + ) + exec(compiled, ns, ns) + + except Exception as e: + raise RuntimeError( + f"Execution failed at step {step_id} with action {code}:\n{e}" + ) + + print("\033[95m\nThe current action step executed successfully.\033[0m") diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py new file mode 100644 index 0000000..a38463d --- /dev/null +++ b/embodichain/agents/hierarchy/llm.py @@ -0,0 +1,36 @@ +import os +from langchain_openai import ChatOpenAI, AzureChatOpenAI + +# ------------------------------------------------------------------------------ +# Environment configuration +# ------------------------------------------------------------------------------ + +os.environ["ALL_PROXY"] = "" +os.environ["all_proxy"] = "" +os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" +os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" +os.environ["OPENAI_API_VERSION"] = "2024-10-21" +os.environ["AZURE_OPENAI_ENDPOINT"] = "YOUR_ENDPOINT_HERE" + +# ------------------------------------------------------------------------------ +# LLM factory +# ------------------------------------------------------------------------------ + + +def create_llm(*, temperature=0.0, model="gpt-4o"): + return ChatOpenAI( + temperature=temperature, + model=model, + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_BASE_URL"), + ) + + +# ------------------------------------------------------------------------------ +# LLM instances +# ------------------------------------------------------------------------------ + +task_llm = create_llm(temperature=0.0, model="gpt-4o") +code_llm = create_llm(temperature=0.0, model="gemini-2.5-flash-lite") +validation_llm = create_llm(temperature=0.0, model="gemini-3-flash-preview") +view_selection_llm = create_llm(temperature=0.0, model="gemini-2.5-flash-lite") diff --git a/embodichain/agents/hierarchy/task_agent.py b/embodichain/agents/hierarchy/task_agent.py new file mode 100644 index 0000000..fd7516c --- /dev/null +++ b/embodichain/agents/hierarchy/task_agent.py @@ -0,0 +1,178 @@ +from typing import List, Dict, Tuple +from embodichain.agents.hierarchy.agent_base import AgentBase +from langchain_core.prompts import ChatPromptTemplate +from embodichain.data import database_2d_dir +from embodichain.utils.utility import load_txt, encode_image +from embodichain.agents.mllm.prompt import TaskPrompt +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +from langchain_core.messages import HumanMessage +import numpy as np + +# from openai import OpenAI +import os +import time +import cv2 +import glob +import json +import re + +USEFUL_INFO = """The error may be caused by: +1. You did not follow the basic background information, especially the world coordinate system with its xyz directions. +2. You did not take into account the NOTE given in the atom actions or in the example functions. +3. You did not follow the steps of the task descriptions.\n +""" + + +def extract_plan_and_validation(text: str) -> Tuple[str, List[str], List[str]]: + def get_section(src: str, name: str, next_name) -> str: + if next_name: + pat = re.compile( + rf"\[{name}\]\s*:\s*(.*?)\s*(?=\[{next_name}\]\s*:|\Z)", + re.DOTALL | re.IGNORECASE, + ) + else: + pat = re.compile( + rf"\[{name}\]\s*:\s*(.*?)\s*\Z", + re.DOTALL | re.IGNORECASE, + ) + m = pat.search(src) + return m.group(1).strip() if m else "" + + step_re = re.compile( + r"Step\s*\d+\s*:.*?(?=Step\s*\d+\s*:|\Z)", + re.DOTALL | re.IGNORECASE, + ) + + # ---- plans ---- + plans_raw = get_section(text, "PLANS", "VALIDATION_CONDITIONS") + plan_steps = [m.group(0).rstrip() for m in step_re.finditer(plans_raw)] + plan_str = "\n".join(plan_steps) + + # normalized plan list (strip "Step k:") + plan_list = [] + for step in plan_steps: + content = re.sub(r"^Step\s*\d+\s*:\s*", "", step, flags=re.IGNORECASE).strip() + if content: + plan_list.append(content) + + # ---- validations ---- + vals_raw = get_section(text, "VALIDATION_CONDITIONS", None) + validation_list = [] + for m in step_re.finditer(vals_raw): + content = re.sub( + r"^Step\s*\d+\s*:\s*", "", m.group(0), flags=re.IGNORECASE + ).strip() + if content: + validation_list.append(content) + + return plan_str, plan_list, validation_list + + +class TaskAgent(AgentBase): + prompt: ChatPromptTemplate + object_list: List[str] + target: np.ndarray + prompt_name: str + prompt_kwargs: Dict[str, Dict] + + def __init__(self, llm, **kwargs) -> None: + super().__init__(**kwargs) + self.llm = llm + + def generate(self, **kwargs) -> str: + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = log_dir / "agent_generated_plan.txt" + + # Check if the file already exists + if not kwargs.get("regenerate", False): + if file_path.exists(): + print(f"Plan file already exists at {file_path}, skipping writing.") + return load_txt(file_path) + + # Generate query via LLM + prompts_ = getattr(TaskPrompt, self.prompt_name)(**kwargs) + if isinstance(prompts_, list): + # TODO: support two-stage prompts with feedback + start_time = time.time() + response = self.llm.invoke(prompts_[0]) + query = response.content + print( + f"\033[92m\nSystem tasks output ({np.round(time.time()-start_time, 4)}s):\n{query}\n\033[0m" + ) + for prompt in prompts_[1:]: + temp = prompt["kwargs"] + temp.update({"query": query}) + start_time = time.time() + response = self.llm.invoke(prompt["prompt"].invoke(temp)) + query = response.content + print( + f"\033[92m\nSystem tasks output({np.round(time.time()-start_time, 4)}s):\n{query}\n\033[0m" + ) + else: + # insert feedback if exists + if len(kwargs.get("error_messages", [])) != 0: + # just use the last one + last_plan = kwargs["generated_plans"][-1] + last_code = kwargs["generated_codes"][-1] + last_error = kwargs["error_messages"][-1] + + # Add extra human message with feedback + feedback_msg = self.build_feedback_message( + last_plan, last_code, last_error + ) + prompts_.messages.append(feedback_msg) + + response = self.llm.invoke(prompts_) + print(f"\033[92m\nTask agent output:\n{response.content}\n\033[0m") + + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, "w") as f: + f.write(response.content) + print(f"Generated task plan saved to {file_path}") + + return response.content + + def act(self, *args, **kwargs): + return super().act(*args, **kwargs) + + def build_feedback_message( + self, last_plan: str, last_code: str, last_error: str + ) -> HumanMessage: + return HumanMessage( + content=( + "Your previous plan was:\n" + "```\n" + last_plan + "\n```\n\n" + "This plan led the code agent to generate the following code according to your plan:\n" + "```\n" + last_code + "\n```\n\n" + "When this code was executed in the test environment, it failed with the following error:\n" + "```\n" + last_error + "\n```\n\n" + USEFUL_INFO + "\n" + "Please analyze the failure, revise your plan, and provide sufficient instructions to correct the issue, " + "so that the code agent can generate a correct and executable solution based on your plan. " + "Your updated plan must strictly adhere to the atomic API functions and avoid ambiguous actions." + ) + ) + + def generate_for_correction(self, img_dir, **kwargs): + # Generate task plan via LLM + image_files = glob.glob(os.path.join(img_dir, "obs_step_*.png")) + if len(image_files) < 1: + raise ValueError("Need at least one observation images for validation.") + # sort by step index + image_files_sorted = sorted( + image_files, + key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), + ) + obs_image_path = image_files_sorted[-1] # the current image + prompt = getattr(TaskPrompt, self.prompt_name)( + obs_image_path=obs_image_path, **kwargs + ) + + response = self.llm.invoke(prompt).content + print(f"\033[94m\nTask agent output:\n{response}\n\033[0m") + + task_plan, plan_list, validation_list = extract_plan_and_validation(response) + + return task_plan, plan_list, validation_list diff --git a/embodichain/agents/hierarchy/validation_agent.py b/embodichain/agents/hierarchy/validation_agent.py new file mode 100644 index 0000000..7ad33a5 --- /dev/null +++ b/embodichain/agents/hierarchy/validation_agent.py @@ -0,0 +1,330 @@ +from langchain_core.prompts import ChatPromptTemplate +import os +from langchain_core.messages import SystemMessage, HumanMessage +from abc import ABCMeta +from embodichain.utils.utility import encode_image_from_path +import glob +from embodichain.agents.hierarchy.llm import view_selection_llm + + +def save_obs_image(obs_image, save_dir, step_id=None): + """ + Save observation image using encode_image() and return its file path. + """ + import base64 + from embodichain.utils.utility import encode_image + + if obs_image is None: + return None + + if isinstance(save_dir, str): + from pathlib import Path + + save_dir = Path(save_dir) + + save_dir.mkdir(parents=True, exist_ok=True) + + name = f"obs_step_{step_id}.png" if step_id is not None else "obs.png" + img_path = save_dir / name + + # Encode to base64 + base64_image = encode_image(obs_image) + + # Decode base64 → bytes + img_bytes = base64.b64decode(base64_image) + + # Write to file + with open(img_path, "wb") as f: + f.write(img_bytes) + + return img_path + + +def get_obj_position_info(env): + import json + + position_info = {} + obj_uids = env.sim.get_rigid_object_uid_list() + for obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0)[:3, 3] + position_info[obj_name] = target_obj_pose.tolist() + return json.dumps(position_info, indent=4) + + +class ValidationAgent(metaclass=ABCMeta): + + def __init__(self, llm, **kwargs) -> None: + super().__init__() + for key, value in kwargs.items(): + setattr(self, key, value) + self.llm = llm + + def validate(self, step_names, problematic_code, error_message, image_files): + # Construct the prompt + prompt = f""" + Analyze the execution of the following robot task: + + Task name: {self.task_name} + Task description: {self.task_description} + Basic background knowledge: {self.basic_background} + + You will be given images showing each step of the execution. For the step sequence: + {', '.join(step_names)} + + Provide the following analysis: + 1. Determine whether each step was executed correctly. + 2. If a step failed, identify which one and explain the cause. + 3. Decide whether the full task succeeded or failed. + 4. If the task failed, provide a precise and detailed explanation. + + Below is a potentially problematic piece of code and the corresponding execution error: + + ```python + {problematic_code} + # Execution error: + {error_message} + Explain whether (and how) this code contributed to the observed failure. + """ + + # Prepare message content for API call + user_content = [] + + # Add textual prompt + user_content.append({"type": "text", "text": prompt}) + + # Add images and step names + for img_path in image_files: + filename = os.path.basename(img_path) + first_underscore_pos = filename.find("_") + if first_underscore_pos != -1: + step_name = filename[first_underscore_pos + 1 :].rsplit(".", 1)[0] + else: + step_name = filename.rsplit(".", 1)[0] + + # Add step name + user_content.append({"type": "text", "text": f"Step: {step_name}"}) + + # Add image as base64 + base64_image = encode_image_from_path(img_path) + user_content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + } + ) + + messages = [ + SystemMessage( + content="You are a robot task execution analysis expert. Please analyze the provided image sequence." + ), + HumanMessage(content=user_content), + ] + + response = self.llm.invoke(messages) + return response.content + + def select_best_view_dir( + self, img_dirs: dict, action_description: str, valid_condition: str + ): + """ + img_dirs: { + "cam_1": Path, + "cam_2": Path, + "cam_3": Path + } + """ + + # --- collect final images --- + last_images = {} + for cam_id, cam_dir in img_dirs.items(): + imgs = sorted( + glob.glob(os.path.join(cam_dir, "obs_step_*.png")), + key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), + ) + if imgs: + last_images[cam_id] = imgs[-1] + + if not last_images: + raise ValueError("No images found in any camera directory.") + + # --- system prompt --- + system_prompt = ( + "You are a robot perception assistant specialized in VIEW SELECTION.\n\n" + "TASK:\n" + "- You are given ONE final observation image from EACH camera view.\n" + "- Your job is NOT to judge success or failure.\n" + "- Your job is ONLY to select the SINGLE camera view that is MOST SUITABLE\n" + " for OBJECT-LEVEL validation of the action result.\n\n" + "ACTION CONTEXT:\n" + "- The robot has just executed ONE atomic action.\n" + "- You are given the action intention and the expected object-level outcome\n" + " ONLY to help you decide which view best reveals that outcome.\n\n" + "SELECTION CRITERIA (PRIORITY ORDER):\n" + "- Prefer views with:\n" + " * the clearest visibility of the relevant object(s)\n" + " * minimal occlusion by the arm or environment\n" + " * the clearest evidence related to the expected object-level result\n" + " (e.g., contact, separation, support, stability)\n\n" + "STRICT CONSTRAINTS:\n" + "- Do NOT judge robot motion quality or execution accuracy.\n" + "- Do NOT reason about numeric values (distance, angle, offset).\n" + "- Do NOT decide whether the action succeeded or failed.\n" + "- If multiple views are acceptable, choose the clearest overall view.\n\n" + "OUTPUT FORMAT (STRICT):\n" + "Output EXACTLY ONE of the following tokens:\n" + "- cam_1\n" + "- cam_2\n" + "- cam_3\n" + ) + + # --- human content --- + human_content = [ + { + "type": "text", + "text": ( + "Select the best camera view for object-level validation.\n\n" + "--------------------------------------------------\n" + "ACTION DESCRIPTION (INTENT ONLY):\n" + f"{action_description}\n\n" + "EXPECTED OBJECT-LEVEL RESULT (REFERENCE ONLY):\n" + f"{valid_condition}\n" + "--------------------------------------------------" + ), + } + ] + + for cam_id, img_path in last_images.items(): + img_b64 = encode_image_from_path(img_path) + human_content.extend( + [ + {"type": "text", "text": f"View candidate: {cam_id}"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + }, + ] + ) + + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=human_content), + ] + + response = view_selection_llm.invoke(messages).content.strip() + + if response not in img_dirs: + raise ValueError(f"Invalid camera selection from LLM: {response}") + + return response + + def validate_single_action( + self, + current_action, + action_description, + valid_condition, + img_dir, + obj_position_info, + ): + # --- camera directories --- + img_dirs = { + "cam_1": img_dir / "cam_1", + "cam_2": img_dir / "cam_2", + "cam_3": img_dir / "cam_3", + } + + # === Stage 1: select best view === + selected_cam = self.select_best_view_dir( + img_dirs, action_description, valid_condition + ) + selected_dir = img_dirs[selected_cam] + print(f"\033[38;5;214mSelected camera for validation: {selected_cam}\033[0m") + + # === Stage 2: load FULL sequence from selected view === + image_files = glob.glob(os.path.join(selected_dir, "obs_step_*.png")) + if len(image_files) < 2: + raise ValueError("Need at least two observation images for validation.") + + # Sort images by step index + image_files_sorted = sorted( + image_files, + key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), + ) + + # Encode ALL images in sequence + encoded_images = [encode_image_from_path(p) for p in image_files_sorted] + + system_prompt = ( + "You are a helpful robot manipulation ACTION VALIDATOR.\n\n" + "ROLE:\n" + "- Judge ONLY the OBJECT-LEVEL outcome of ONE atomic action.\n" + "- Do NOT judge robot motion, planning, or execution quality.\n\n" + "CORE ASSUMPTIONS:\n" + "- The robot arm motion itself is correct by definition.\n" + "- Any failure must be due to incorrect OBJECT interaction or state.\n\n" + "COORDINATE RULE:\n" + "- The image is horizontally mirrored: image left ↔ robot right, image right ↔ robot left.\n" + "- Vertical direction is preserved.\n" + "- Use robot base frame terminology in your final judgment.\n\n" + "EVALUATION RULES:\n" + "- Focus on the FINAL image.\n" + "- Earlier images are context only.\n" + "- Do NOT infer numeric precision or motion quality.\n" + "- Ignore minor offsets or simulation noise.\n\n" + "DECISION POLICY:\n" + "- If visual evidence contradicts the expected object state → FAILURE.\n" + "- If visual evidence clearly matches the expected object state → SUCCESS.\n" + ) + + prompt = f""" + Validate the result of ONE atomic robot action. + + -------------------------------------------------- + ACTION + -------------------------------------------------- + {action_description} + + -------------------------------------------------- + EXPECTED OBJECT-LEVEL OUTCOME + -------------------------------------------------- + {valid_condition} + + -------------------------------------------------- + INPUT + -------------------------------------------------- + You are given an ordered image sequence. + - The FINAL image shows the state AFTER the action. + + -------------------------------------------------- + OUTPUT FORMAT (STRICT) + -------------------------------------------------- + Output EXACTLY one of the following. + + IMPORTANT: + - You MUST explicitly state the correctness of BOTH arms in the Evidence. + + [ACTION_SUCCESS] + - Evidence: + + [ACTION_FAILED] + - Reason: + - Evidence: + """ + + # Build multimodal message with ALL images + human_content = [{"type": "text", "text": prompt}] + for img_b64 in encoded_images: + human_content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + } + ) + + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=human_content), + ] + + llm_response = self.llm.invoke(messages) + return llm_response.content From ee5a5683f18c625043b51dd986f2148333d5ed73 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 16 Jan 2026 18:30:23 +0800 Subject: [PATCH 06/49] Migrate prompt file --- .../DualPourWaterAgent-v3/task_prompt.txt | 5 + .../PourWaterAgent-v3/task_prompt.txt | 5 + .../RearrangementAgent-v3/task_prompt.txt | 8 ++ .../database/agent_prompt/atom_actions.txt | 136 ++++++++++++++++++ .../agent_prompt/basic_background.txt | 42 ++++++ .../database/agent_prompt/code_example.txt | 35 +++++ .../database/agent_prompt/code_prompt.txt | 7 + 7 files changed, 238 insertions(+) create mode 100644 embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt create mode 100644 embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt create mode 100644 embodichain/database/agent_prompt/atom_actions.txt create mode 100644 embodichain/database/agent_prompt/basic_background.txt create mode 100644 embodichain/database/agent_prompt/code_example.txt create mode 100644 embodichain/database/agent_prompt/code_prompt.txt diff --git a/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt new file mode 100644 index 0000000..6b2205e --- /dev/null +++ b/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt @@ -0,0 +1,5 @@ +Task: +Use both arms to pour water from the bottle into the cup. +First, grasp the bottle with right arm and the cup with left arm simultaneously. Then lift the cup by 0.10 m, and then move it to [0.55, 0.05] to prepare for pouring. +Then hold the bottle at a relative offset of [0.05, −0.10, 0.125] with respect to the cup for starting pouring. +After pouring, place the bottle to [0.7, −0.1] and place the cup at [0.6, 0.1] simultaneously. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt new file mode 100644 index 0000000..be7eb3b --- /dev/null +++ b/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt @@ -0,0 +1,5 @@ +Task: +Use a single robotic arm to pour water from the bottle into the cup. +Position the bottle at an offset of [0.05, −0.10, 0.125] relative to the cup’s position during pouring. +After completing the pour, return the bottle to the location [0.7, −0.1]. +The cup should remain stationary throughout the task. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt new file mode 100644 index 0000000..280bd22 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt @@ -0,0 +1,8 @@ +Task: +Use both arms to rearrange a fork and a spoon on opposite sides of a plate. Perform the following steps **simultaneously**. + +1. Grasp the fork with the left arm and the spoon with the right arm. + +2. Reorient both end-effectors to a downward-facing pose. + +3. Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/atom_actions.txt b/embodichain/database/agent_prompt/atom_actions.txt new file mode 100644 index 0000000..257464c --- /dev/null +++ b/embodichain/database/agent_prompt/atom_actions.txt @@ -0,0 +1,136 @@ +### Atom Functions for Robot Arm Control +Each atomic function returns a list of joint-space trajectories (list[np.ndarray]). +All functions support an optional argument: + +force_valid: bool + If True, the system will automatically correct an invalid target pose by + projecting it to the nearest valid pose. Use this option carefully: + enable it only for actions where small spatial deviations are acceptable + and will not compromise task correctness. Default is False. + +Use the following functions exactly as defined. Do not invent new APIs or parameters. + +"grasp": + def grasp(robot_name: str, obj_name: str, pre_grasp_dis: float, **kwargs) -> list[np.ndarray] + + Moves the specified arm to the target object’s affordance-based grasp pose and executes a grasp by closing the gripper. + + The function plans a two-stage trajectory: + (1) from the current pose to a pre-grasp pose offset from the object, and + (2) from the pre-grasp pose to the final grasp pose, followed by gripper closure. + + Upon completion, the gripper is closed and the target object is expected to be stably held by the gripper. + + Example: + grasp(robot_name='right_arm', obj_name='bottle', pre_grasp_dis=0.10) # Moves the right arm to a pre-grasp pose 10 cm from the bottle, then to the grasp pose and closes the gripper to grasp the bottle. + +"place_on_table": + def place_on_table(robot_name: str, obj_name: str, x: float, y: float, pre_place_dis: float, **kwargs) -> list[np.ndarray] + + Moves the specified robot arm with the target object to the desired [x, y] location on the table and opens the gripper to place the object. + The z-coordinate is automatically adjusted based on the table height and the object’s dimensions. + This function assumes that the robot is already holding the object and that the task is to place it on the table at the specified coordinates. + Remember that when you need to place some objects on the table at specific coordinates, use this function without using other movement atom actions. + Otherwise, **if you need to place some objects relative to some place, then use "move_relative_to_object" first to move to the desired position, then use "open_gripper" to release the object.** + + Example: + place_on_table(robot_name='right_arm', obj_name='bottle', x=0.1, y=0.5, pre_place_dis=0.08) # Moves the right arm to a pre-place position 8 cm from the table, then places the bottle at the specified [0.1, 0.5] location on the table and opens the gripper. + +"move_relative_to_object": + def move_relative_to_object(robot_name: str, obj_name: str, + x_offset=0, y_offset=0, z_offset=0, + **kwargs) -> list[np.ndarray] + Moves the end-effector to a pose defined relative to the target object: + target = object_position + [x_offset, y_offset, z_offset] + Orientation is preserved. + Example: + move_relative_to_object(robot_name='right_arm', obj_name='cup', + x_offset=0.05, y_offset=0.10, z_offset=0.10) # Moves the right arm’s end-effector to a spot located 5 cm forward, 10 cm to the left, and 10 cm above the cup, while preserving the current gripper orientation. + move_relative_to_object(robot_name='right_arm', obj_name='cup', + x_offset=-0.05, y_offset=-0.10, z_offset=0.10) # Moves the right arm’s end-effector to a spot located 5 cm backward, 10 cm to the right, and 10 cm above the cup, while preserving the current gripper orientation. + +"move_to_absolute_position": + def move_to_absolute_position(robot_name: str, + x=None, y=None, z=None, + **kwargs) -> list[np.ndarray] + Moves the end-effector to an absolute (x, y, z) position in world coordinates. + Any coordinate set to None remains unchanged. + Orientation is preserved. + Example: + move_to_absolute_position(robot_name='right_arm', x=0.10, y=0.10, z=None) # Moves the end-effector to the absolute world position (x=0.10 m, y=0.10 m) while leaving z unchanged, and preserves the orientation. + +"move_by_relative_offset": + def move_by_relative_offset(robot_name: str, + dx=0.0, dy=0.0, dz=0.0, mode='extrinsic', + **kwargs) -> list[np.ndarray] + Moves the end-effector by a relative translation: + new_position = current_position + [dx, dy, dz] + The offset is applied along the specified axes using the given mode, while preserving the original end-effector orientation. + Mode can be 'extrinsic' (world frame) or 'intrinsic' (end-effector frame). If you want to move along the world axes, use 'extrinsic'. If you want to move along the end-effector’s local axes, use "intrinsic". + Example: + move_by_relative_offset(robot_name='right_arm', dx=0.05, dy=-0.10, dz=0.20, mode='extrinsic') # Translates the end-effector by +5 cm in x (front), −10 cm in y (right), +20 cm in z (above) in the world coordinate, relative to its current position, with orientation preserved. + move_by_relative_offset(robot_name='right_arm', dx=0, dy=0, dz=0.1, mode='intrinsic') # Translates the end-effector by +10 cm in z (forward) in the EEF coordinate, meaning that it moves forward relative to its current facing direction, with orientation preserved. + +"rotate_eef" + def rotate_eef(robot_name: str, degree: float, **kwargs) -> list[np.ndarray] + Rotates the wrist roll joint (joint index 5) of the specified arm by the + given number of degrees. End-effector position is preserved. + Example: + rotate_eef(robot_name='right_arm', degree=-90) # Rotates the right arm’s wrist-roll joint by −45° (counterclockwise), while keeping the end-effector position unchanged. This is a joint-level rotation, not a full orientation override. + rotate_eef(robot_name='right_arm', degree=90) # Rotates the right arm’s wrist-roll joint by 45° (clockwise), while keeping the end-effector position unchanged. This is a joint-level rotation, not a full orientation override. + Typical use cases: + Pouring or tilting a grasped object. + Rotating the gripper around its forward axis without translating the end effector. + After rotating, you typically need to apply an opposite rotation back to return to the original pose. + Usage notes: + Rotation sign convention: negative = counterclockwise, positive = clockwise, viewed along the end-effector forward axis. + For pouring with the right arm, a common pattern is: first apply a negative rotation to start pouring, then apply a positive rotation to return. + For the left arm, the sign convention is typically reversed. + +"orient_eef": + def orient_eef(robot_name: str, + direction: str = 'front', # 'front' or 'down' + **kwargs) -> list[np.ndarray] + Reorients the end-effector to a predefined canonical orientation in the + WORLD coordinate frame, while keeping the EE’s current position fixed. + This function replaces the entire 3×3 orientation matrix of the current + end-effector pose. + Usage notes: + This function should only be used when you explicitly need to override the end-effector’s full orientation. + This differs from rotate_eef(). orient_eef performs a full orientation override of the end-effector, not a single-joint rotation. For tasks like pouring, no need to use it. + For general wrist rotation, prefer using rotate_eef instead. + For aligning tasks, use "front" or "down" orientations as needed. + Supported orientations: + • 'front' : Align the end-effector so its direction faces forward. + • 'down' : Align the end-effector so its direction faces downward. + Example: + orient_eef(robot_name='right_arm', direction='front') # Reorients the right arm’s end-effector so it faces forward + +"back_to_initial_pose": + def back_to_initial_pose(robot_name: str, **kwargs) -> list[np.ndarray] + Returns the specified arm to its predefined initial joint configuration + stored in the environment. + Example: + back_to_initial_pose(robot_name='right_arm') # Returns the right arm back to its predefined initial joint configuration stored in the environment, regardless of its current pose. + +"close_gripper": + def close_gripper(robot_name: str, **kwargs) -> list[np.ndarray] + Closes the arm’s gripper using a short (10-step) gripper-only trajectory. + Example: + close_gripper(robot_name='right_arm') # Closes the right gripper using a short, smooth 10-step gripper-only trajectory. + +"open_gripper": + def open_gripper(robot_name: str, **kwargs) -> list[np.ndarray] + Opens the arm’s gripper using a short (10-step) gripper-only trajectory. + Example: + open_gripper(robot_name='right_arm') # Opens the right gripper using a 10-step gripper-only trajectory. + +### Drive Function (Trajectory Synchronization) +"drive": + def drive(left_arm_action=None, right_arm_action=None, **kwargs) -> list[torch.Tensor] + Wraps one or two arm trajectories into synchronized full-robot actions. + • If only one arm action is provided, the other arm stays idle. + • If both are provided, they are temporally aligned and executed together. + • The actions are obtained from the output of the above functions. + Example: + drive(left_arm_action=left_actions, right_arm_action=right_actions) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/basic_background.txt b/embodichain/database/agent_prompt/basic_background.txt new file mode 100644 index 0000000..dc6d1c3 --- /dev/null +++ b/embodichain/database/agent_prompt/basic_background.txt @@ -0,0 +1,42 @@ +The environment uses a right-handed world coordinate system, where 1 unit equals 1 meter. +All robot poses are represented as 4×4 homogeneous transformation matrices. + +The robot base coordinate frame is the ONLY authoritative frame for all spatial reasoning, planning, and action generation. + +CAMERA AND IMAGE INTERPRETATION + +The camera is positioned in front of the robot, facing the robot arm and looking toward the robot base. +Because of this viewpoint, the rendered image is horizontally mirrored relative to the robot base frame. +This mirroring affects LEFT–RIGHT only. There is NO vertical or depth inversion. + +Mirror mapping (image → robot base frame): + +* Image left corresponds to robot right +* Image right corresponds to robot left +* Image up corresponds to robot up +* Image down corresponds to robot down + +REQUIRED REASONING PERSPECTIVE (NON-NEGOTIABLE) + +You must ignore the camera and rendered image orientation when reasoning. +All spatial reasoning must be performed as if you are physically located at the robot base, looking outward along the robot’s +x (forward) direction. + +Do NOT reason from the camera viewpoint. +Do NOT trust left/right as shown in the image. +Always remap image left/right before reasoning. + +ROBOT BASE COORDINATE DEFINITIONS + +All directions below are defined strictly in the robot base frame: + +* Moving forward increases x +* Moving backward decreases x +* Moving left increases y (appears as right in the image) +* Moving right decreases y (appears as left in the image) +* Moving up increases z +* Moving down decreases z + +ROBOT INITIALIZATION AND TERMINATION + +Both robot arms start in predefined initial configurations with their end-effectors open. +At task completion, both arms must be returned to their initial poses. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/code_example.txt b/embodichain/database/agent_prompt/code_example.txt new file mode 100644 index 0000000..c2952fe --- /dev/null +++ b/embodichain/database/agent_prompt/code_example.txt @@ -0,0 +1,35 @@ +# Python scripts +# Use the right arm to grasp bottle, move to the target location (x=0.2, y=0.1), and then open the gripper to release the object. + +```python +# Step 1 — Reach and grasp the bottle +drive( + right_arm_action=grasp( + robot_name="right_arm", + obj_name="bottle", + ), +) + +# Step 2 — Move to target location +drive( + right_arm_action=move_to_absolute_position( + robot_name="right_arm", + x=0.2, + y=0.1, + ), +) + +# Step 3 — Open gripper to release the object +drive( + right_arm_action=open_gripper( + robot_name="right_arm", + ), +) + +# Step 4 — Return the arm to the initial pose +drive( + right_arm_action=back_to_initial_pose( + robot_name="right_arm", + ), +) +``` \ No newline at end of file diff --git a/embodichain/database/agent_prompt/code_prompt.txt b/embodichain/database/agent_prompt/code_prompt.txt new file mode 100644 index 0000000..3fadf1c --- /dev/null +++ b/embodichain/database/agent_prompt/code_prompt.txt @@ -0,0 +1,7 @@ +Constraints: +- Every atomic action MUST be executed via a single drive(...) call. +- Each drive(...) call must directly contain the atomic action(s); do NOT define actions separately and then pass them into drive. +- For single-arm execution: specify the active arm’s action and explicitly set the unused arm to None within the same drive(...) call. +- For dual-arm execution: both arms’ actions MUST be specified within the same drive(...) call. +- Use exactly one drive(...) call per step; no exceptions. +- Output MUST be executable Python code only: no explanations, no comments, no markdown, and no extra text. \ No newline at end of file From b1212afbe0f5b5f4c3c6ae723ac4f297a4bbe0a2 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 13:38:35 +0800 Subject: [PATCH 07/49] Update pyproject and gitignore --- .gitignore | 3 ++- pyproject.toml | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index cbe0649..040955d 100644 --- a/.gitignore +++ b/.gitignore @@ -179,7 +179,8 @@ materials embodichain/toolkits/outputs/* embodichain/toolkits/outputs/* -embodichain/database/* +embodichain/database/tmp/* +embodichain/database/train/* 3rdparty/ Log/ diff --git a/pyproject.toml b/pyproject.toml index 50fef25..a770d55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ dependencies = [ "dexsim_engine==0.3.9", "setuptools>=78.1.1", "gymnasium>=0.29.1", + "langchain==0.2.14", + "langchain-openai==0.1.22", "toppra==0.6.3", "pin", "pin-pink", @@ -49,6 +51,8 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "zmq==0.0.0", + "pycocotools", ] [project.optional-dependencies] From 52a2b457eb57267fc3c21f2eb675c82c177103d9 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 14:09:07 +0800 Subject: [PATCH 08/49] Migrate generate video script --- embodichain/lab/scripts/generate_video.py | 175 ++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 embodichain/lab/scripts/generate_video.py diff --git a/embodichain/lab/scripts/generate_video.py b/embodichain/lab/scripts/generate_video.py new file mode 100644 index 0000000..d0a89d7 --- /dev/null +++ b/embodichain/lab/scripts/generate_video.py @@ -0,0 +1,175 @@ +from embodichain.utils.logger import log_info, log_warning + +try: + import h5ffmpeg as hf +except Exception as e: + log_warning("Fail to import h5ffmpeg.") +import h5py +import argparse +import numpy as np +import os +from tqdm import tqdm +from dexsim.utility import images_to_video +from typing import Dict, Callable, Tuple +from embodichain.utils.visualizer import draw_keypoints, draw_action_distribution +from embodichain.data.enum import EefType, JointType, Modality, PrivilegeType +from embodichain.data.data_engine.indices_unifier import ActionIndicesGenerator + + +class VideoCreator: + def __init__(self) -> None: + pass + + @staticmethod + def _sub_function( + images, + output_path, + video_key, + exteroceptions: Dict = None, + multiplier: int = 1, + drawer: Callable = lambda x: x, + ): + for key in images.keys(): + imgs = images[key] + if imgs is None: + log_warning(f"No images found for key: {key}. Skipping.") + continue + img_list = [] + for i in tqdm(range(imgs.shape[0])): + image_i = drawer(imgs[i] * multiplier) + if exteroceptions is not None and len(exteroceptions[key]) != 0: + image_i = draw_keypoints( + image_i, exteroceptions[key][i].reshape(-1, 2) + ) + img_list.append(image_i) + + images_to_video(img_list, output_path, f"{key}_{video_key}") + + @staticmethod + def monocular_save( + observations: Dict, + video_key: str, + output_path: str, + multiplier: int = 1, + drawer: Callable = lambda x: x, + draw_exteroception: bool = True, + ): + images = observations[video_key] + if ( + PrivilegeType.EXTEROCEPTION.value in observations.keys() + and draw_exteroception + ): + exteroceptions = observations[PrivilegeType.EXTEROCEPTION.value] + else: + exteroceptions = None + VideoCreator._sub_function( + images, + output_path, + video_key, + exteroceptions, + multiplier, + drawer, + ) + + +def visualize_data_dict(f: Dict, output_path: str): + observations = f["observations"] + + if PrivilegeType.MASK.value in observations.keys(): + VideoCreator.monocular_save( + observations, + PrivilegeType.MASK.value, + output_path, + 255, + draw_exteroception=False, + ) + + if Modality.GEOMAP.value in observations.keys(): + from embodichain.utils.utility_3d import gen_disp_colormap + + VideoCreator.monocular_save( + observations, + Modality.GEOMAP.value, + output_path, + 1, + lambda x: (gen_disp_colormap(x).transpose(1, 2, 0) * 255).astype(np.uint8), + draw_exteroception=False, + ) + + VideoCreator.monocular_save(observations, Modality.IMAGES.value, output_path) + + +def main(args): + + data_path = args.data_path + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + assert data_path.endswith(".hdf5"), "Data path must have format of .hdf5" + with h5py.File(data_path, "r") as f: + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + import hdfdict + + data = hdfdict.load(data_path) + data = CompressedVideoHDF5( + output_path, chunks=data["chunks"].item() + ).safe_filter(data) + + # NOTE: DO NOT USE THIS IN SCRIPT, IT IS FOR DEBUGGING PURPOSES ONLY + # slice_id = 20 + # data_copy = hdfdict.load(data_path) + # data_copy = CompressedVideoHDF5(output_path).safe_filter(data_copy, slice_id=slice_id) + + # a = data["observations"]["images"]["cam_high"][slice_id] + # b = data_copy["observations"]["images"]["cam_high"] + # print(a, b.shape) + # delta = a-b[slice_id] + # print(np.linalg.norm(delta)) + + visualize_data_dict(data, output_path) + if "robot_meta" in data.keys(): + log_warning("Simulation data.") + robot_meta = data["robot_meta"] + arm_dofs = robot_meta["arm_dofs"][()] + actions = f[Modality.ACTIONS.value][()] + else: + from embodichain.data.data_engine.datasets.sim_real_unified_dict_dataset import ( + RobotRealDataRouter, + ) + + log_warning("Real data.") + actions, _, _, _, _ = RobotRealDataRouter( + robot_name=args.robot_name + ).realdata2simdata(f) + indices_generator = ActionIndicesGenerator(arm_dofs) + + key_names = indices_generator.global_mapping.mapping_from_name_to_indices.keys() + log_info(f"Arm dofs: {arm_dofs}", color="green") + indices_dict = {} + for key_name in key_names: + indices_dict[key_name] = indices_generator.get([key_name]) + draw_action_distribution(actions, indices_dict, output_path, smooth=args.smooth) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, help="Path to the data file.") + parser.add_argument( + "--output_path", + type=str, + help="Path to the output video file.", + default="./outputs", + ) + parser.add_argument( + "--smooth", + action="store_true", + default=False, + help="whether smooth joints.", + ) + parser.add_argument("--robot_name", default="DexforceW1", type=str) + args = parser.parse_args() + + main(args) From 0b65baa713321bafefaefc902cd024ca5f3316e6 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 14:14:08 +0800 Subject: [PATCH 09/49] Migrate data engine --- embodichain/data/data_engine/__init__.py | 5 + .../data/data_engine/compressed_hdf5.py | 545 ++++++++++++ .../data/data_engine/data_dict_extractor.py | 801 ++++++++++++++++++ .../datasets/sim_real_unified_dict_dataset.py | 696 +++++++++++++++ .../data/data_engine/indices_unifier.py | 395 +++++++++ embodichain/data/data_engine/online/engine.py | 513 +++++++++++ embodichain/data/data_engine/online/enum.py | 24 + .../data_engine/online/online_generator.py | 181 ++++ 8 files changed, 3160 insertions(+) create mode 100644 embodichain/data/data_engine/__init__.py create mode 100644 embodichain/data/data_engine/compressed_hdf5.py create mode 100644 embodichain/data/data_engine/data_dict_extractor.py create mode 100644 embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py create mode 100644 embodichain/data/data_engine/indices_unifier.py create mode 100644 embodichain/data/data_engine/online/engine.py create mode 100644 embodichain/data/data_engine/online/enum.py create mode 100644 embodichain/data/data_engine/online/online_generator.py diff --git a/embodichain/data/data_engine/__init__.py b/embodichain/data/data_engine/__init__.py new file mode 100644 index 0000000..6488c76 --- /dev/null +++ b/embodichain/data/data_engine/__init__.py @@ -0,0 +1,5 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- diff --git a/embodichain/data/data_engine/compressed_hdf5.py b/embodichain/data/data_engine/compressed_hdf5.py new file mode 100644 index 0000000..c050b76 --- /dev/null +++ b/embodichain/data/data_engine/compressed_hdf5.py @@ -0,0 +1,545 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_warning, log_error + +try: + import h5ffmpeg as hf + + has_h5ffmpeg = True +except Exception as e: + has_h5ffmpeg = False + log_warning("Fail to import h5ffmpeg.") + +import h5py +import numpy as np + +from typing import Dict, Any, List, Union, Optional +from embodichain.data.enum import ( + Modality, + PrivilegeType, +) +from tqdm import tqdm + +SCALE_FACTOR = 4e3 # Scale factor for depth data +FAR_CLIP = 4.0 # m + + +class CompressedVideoHDF5: + def __init__(self, save_path: str, chunks: int = 20) -> None: + """ + Initializes the data dictionary extractor with the specified save path and number of chunks. + Attempts to configure video encoding settings based on the detected GPU model using the h5ffmpeg library. + Supported GPUs include NVIDIA A800 and NVIDIA GeForce RTX 3060, with specific encoding configurations for each. + If the GPU is unsupported or an error occurs during initialization, a warning is logged and default configuration is used. + + Args: + save_path (str): Path where extracted data will be saved. + chunks (int, optional): Number of chunks to split the data into. Defaults to 20. + + Raises: + ValueError: If the detected GPU is not supported. + """ + self.save_path = save_path + self.chunks = chunks + + try: + import h5ffmpeg as hf + import torch + + name = torch.cuda.get_device_name() + + if "A800" in name or name == "NVIDIA A800-SXM4-80GB": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "3060" in name or name == "NVIDIA GeForce RTX 3060": + self.conf = { + Modality.GEOMAP.value: hf.h264_nvenc(), + Modality.IMAGES.value: hf.h264_nvenc(), + PrivilegeType.MASK.value: hf.h264_nvenc(), + } + elif "3090" in name or name == "NVIDIA GeForce RTX 3090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "4090" in name or name == "NVIDIA GeForce RTX 4090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "Orin" in name: + # FIXME: temporary solution for Orin GPU. Need to test and adjust parameters later for nvenc encoder. + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + else: + raise ValueError("Unsupported GPU: {}".format(name)) + + except Exception as e: + log_warning( + "{}. Please make sure h5ffmpeg is successfully installed.".format(e) + ) + self.conf = {} + + @staticmethod + def is_compressed_hdf5(data: Dict) -> bool: + images_group = data.get("observations", {}).get(Modality.IMAGES.value, {}) + has_compressed_keys = any( + (isinstance(k, str) and ("index" in k or "start" in k)) + for k in images_group.keys() + ) + return has_compressed_keys + + @staticmethod + def get_chunk_name(name: str, id: Union[int, str]) -> str: + """ + Generates a chunk name by concatenating the given name with the provided id, separated by an underscore. + Args: + name (str): The base name for the chunk. + id (Union[int, str]): The identifier to append to the name. + Returns: + str: The resulting chunk name in the format 'name_id'. + """ + + return name + "_{}".format(id) + + @staticmethod + def video_save( + f, + chunks: int, + data: Dict[str, np.ndarray], + key: str, + dtype=np.uint8, + conf: Dict = None, + ): + """ + Saves video data from multiple cameras into an HDF5 file, splitting the data into chunks for efficient storage. + Args: + f: An open HDF5 file handle where the video data will be saved. + data (Dict[str, np.ndarray]): Dictionary mapping camera names to their corresponding video data arrays. + key (str): Key under "observations" group in the HDF5 file to store the video data. + dtype (type, optional): Data type to convert the video frames to before saving (default: np.uint8). + conf (Dict, optional): Additional configuration parameters for HDF5 dataset creation. + Notes: + - Video data for each camera is processed and split into the specified number of chunks. + - Index and start datasets are created for each camera to map frame indices to chunk IDs and chunk start indices. + - Uses CompressedVideoHDF5 utility functions for data formatting and conversion. + - Progress is displayed using tqdm for each chunk being saved. + """ + import h5ffmpeg as hf + + f_images = f["observations"].create_group(key) + + for cam_name in data.keys(): + data_ = data[cam_name] + if len(data_) != 0: + data_ = CompressedVideoHDF5.to_bhw(data_) + + if dtype == np.uint16: + data_ = CompressedVideoHDF5.uint16_depth(data_) + else: + data_ = data_.astype(dtype) + + data_chunks = np.array_split(data_, chunks, axis=0) + data_chunk_ids = np.arange(data_.shape[0]) + data_chunk_ids_ = np.array_split(data_chunk_ids, chunks) + idtochunkid = np.zeros((data_.shape[0])) + chunkid2startid = np.zeros((chunks,)) + for chunkid, temp in enumerate(data_chunk_ids_): + chunkid2startid[chunkid] = min(temp) + for tempi in temp: + idtochunkid[tempi] = chunkid + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "index"), + data=idtochunkid, + ) + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "start"), + data=chunkid2startid, + ) + + for t, data_chunk in enumerate(tqdm(data_chunks)): + _ = f_images.create_dataset( + "{}/{}".format(cam_name, t), + data=data_chunk, + chunks=data_chunk.shape, + **conf, + ) + + @staticmethod + def uint16_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts a depth data array to a uint16 format after applying scaling and clipping. + Args: + data (np.ndarray): The input depth data as a NumPy array. + scale_factor (float, optional): The factor by which to scale the depth data. + Defaults to SCALE_FACTOR. + far_clip (float, optional): The maximum depth value (far clipping plane) + before scaling. Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled and clipped depth data as a NumPy array of type uint16. + """ + return (np.clip(data * scale_factor, 0, far_clip * scale_factor)).astype( + np.uint16 + ) + + @staticmethod + def float32_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts depth data to float32 and scales it by the given scale factor. + Args: + data (np.ndarray): The input depth data array. + scale_factor (float, optional): The factor by which to scale the depth values. Defaults to SCALE_FACTOR. + far_clip (float, optional): The far clipping distance (unused in this function). Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled depth data as a float32 numpy array. + """ + + return data.astype(np.float32) / scale_factor + + @staticmethod + def to_bhw(data: np.ndarray) -> np.ndarray: + """ + Reshapes a 4D numpy array from (vdepth, height, width, channels) to (vdepth, height, width * channels). + If the input is already a 3D array, returns it unchanged. + Args: + data (np.ndarray): Input array of shape (vdepth, height, width, channels) or (vdepth, height, width). + Returns: + np.ndarray: Reshaped array of shape (vdepth, height, width * channels) or the original array if 3D. + Raises: + Logs an error if the input array does not have 3 or 4 dimensions. + """ + + if len(data.shape) == 4: + vdepth, h, w, channels = ( + data.shape[0], + data.shape[1], + data.shape[2], + data.shape[3], + ) + return data.reshape(vdepth, h, w * channels) + elif len(data.shape) == 3: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + @staticmethod + def to_bhwc(data: np.ndarray): + """ + Converts a numpy array to BHWC (Batch, Height, Width, Channels) format. + If the input array has 3 dimensions, it reshapes the array to have a channel dimension of size 3. + If the input array already has 4 dimensions, it returns the array unchanged. + Otherwise, logs an error for unsupported shapes. + Args: + data (np.ndarray): Input numpy array to be converted. + Returns: + np.ndarray: Array in BHWC format. + Raises: + Logs an error if the input array shape is not supported. + """ + + if len(data.shape) == 3: + vdepth, h, w = data.shape + return data.reshape(vdepth, h, -1, 3) + elif len(data.shape) == 4: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + def dump( + self, + ret: Dict, + video_names: List[str] = [ + Modality.IMAGES.value, + PrivilegeType.MASK.value, + Modality.GEOMAP.value, + ], + dtypes: List = [np.uint8, np.uint8, np.uint16], + ): + """ + Dumps the provided data into an HDF5 file, saving specific video data with + compression and specified data types. + Args: + ret (Dict): The data dictionary containing observations and other metadata. + video_names (List[str], optional): A list of video names to extract from + the observations. Defaults to [Modality.IMAGES.value, PrivilegeType.MASK.value, Modality.GEOMAP.value]. + dtypes (List, optional): A list of data types corresponding to each video + name. Defaults to [np.uint8, np.uint8, np.uint16]. + Raises: + AssertionError: If the lengths of `video_names` and `dtypes` are not equal. + RuntimeError: If the configuration (`self.conf`) is empty, indicating that + `h5ffmpeg` is not installed or configured properly. + Notes: + - The method modifies the `ret` dictionary by temporarily removing the + specified video data during the HDF5 file creation process and then + restoring it afterward. + - The `hdfdict.dump` function is used to save the remaining data in the + dictionary, while the `CompressedVideoHDF5.video_save` function handles + the saving of video data with compression. + """ + + assert len(video_names) == len( + dtypes + ), "Inequal length of video names {} and dtypes {}.".format(video_names, dtypes) + import hdfdict + + if self.conf == {}: + raise RuntimeError( + "Please make sure h5ffmpeg is successfully installed before using `dump`." + ) + + pop_ret = {} + for video_name, dtype in zip(video_names, dtypes): + video_data = ret["observations"].pop(video_name) + pop_ret[video_name] = video_data + + # Open the file once and pass the open file object to hdfdict.dump so + # h5py doesn't try to truncate the same path while it is already open. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + for video_name, dtype in zip(video_names, dtypes): + CompressedVideoHDF5.video_save( + f, + self.chunks, + pop_ret[video_name], + video_name, + dtype=dtype, + conf=self.conf[video_name], + ) + + ret["observations"].update(pop_ret) + + @staticmethod + def decode_resources( + f: Dict, + ret: Dict, + name: str, + slice_id: int, + condition: callable, + function: callable, + padding: bool = True, + chunk_id: int = None, + ): + """ + Decodes and processes resources from a hierarchical data structure, applying + a condition and transformation function to the data, and optionally adding + zero-padding. + Args: + f (Dict): The input data dictionary containing observations and metadata. + ret (Dict): The output data dictionary where processed data will be stored. + name (str): The key name under "observations" to access specific data. + slice_id (int): The slice index used to retrieve the corresponding chunk ID. + condition (callable): A function that takes the data as input and returns + a boolean indicating whether the transformation function should be applied. + function (callable): A function to transform the data if the condition is met. + padding (bool, optional): Whether to add zero-padding to the data. Defaults to True. + chunk_id (int, optional): The chunk ID to use instead of deriving it from the slice ID. + Defaults to None. + Returns: + None: The function modifies the `ret` dictionary in place. + """ + + import time + + images = f["observations"][name] + + for cam_name in images.keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + + start_time = time.time() + sliceid2chunkid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "index") + ][:] + chunkid = int(sliceid2chunkid[slice_id]) if chunk_id is None else chunk_id + data_ = images[cam_name][str(chunkid)][:] + # log_warning("".format(time.time() - start_time) + if condition(data_): + data_ = function(data_) + + if padding: + chunkid2startid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "start") + ][:] + start_idx = chunkid2startid[chunkid] + zero_padding = np.zeros_like(data_)[0:1] + zero_padding = np.repeat(zero_padding, repeats=start_idx, axis=0) + ret["observations"][name][cam_name] = np.concatenate( + [zero_padding, data_], 0 + ) + else: + if ret["observations"][name][cam_name] is None: + ret["observations"][name][cam_name] = data_ + else: + ret["observations"][name][cam_name] = np.concatenate( + [ret["observations"][name][cam_name], data_], 0 + ) + + def safe_filter(self, f: Dict, slice_id: int = None) -> Dict: + """ + Filters and processes the input data dictionary based on the configuration + and specified slice ID. + Args: + f (Dict): The input data dictionary containing observations, including + images, masks, and geomap. + slice_id (int, optional): The specific slice ID to process. If None, + processes all chunks. Defaults to None. + Returns: + Dict: The filtered and processed data dictionary with updated + observations for images, masks, and geomap. + Notes: + - The method filters out camera names containing "index" or "start". + - It initializes the return dictionary with None values for images, + masks, and geomap for the filtered camera names. + - Depending on the `slice_id`, it either processes all chunks or a + specific slice using the `CompressedVideoHDF5.decode_resources` + method. + - The processed observations are updated in the input dictionary `f`. + """ + + if self.conf is {}: + return f + + cam_names = [] + for cam_name in f["observations"][Modality.IMAGES.value].keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + cam_names.append(cam_name) + + # Only build return structure for actually present modalities, avoid errors when real data lacks mask/geomap + present_modalities = [] + if Modality.IMAGES.value in f["observations"]: + present_modalities.append(Modality.IMAGES.value) + if PrivilegeType.MASK.value in f["observations"]: + present_modalities.append(PrivilegeType.MASK.value) + if Modality.GEOMAP.value in f["observations"]: + present_modalities.append(Modality.GEOMAP.value) + + ret = {"observations": {}} + for modality_key in present_modalities: + ret["observations"][modality_key] = { + cam_name: None for cam_name in cam_names + } + + if slice_id == None: + # For all chunks + for chunk_id_ in range(self.chunks): + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + None, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + chunk_id=chunk_id_, + padding=False, + ) + + else: + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + slice_id, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + ) + if Modality.IMAGES.value in present_modalities: + f["observations"][Modality.IMAGES.value] = ret["observations"][ + Modality.IMAGES.value + ] + if PrivilegeType.MASK.value in present_modalities: + f["observations"][PrivilegeType.MASK.value] = ret["observations"][ + PrivilegeType.MASK.value + ] + if Modality.GEOMAP.value in present_modalities: + f["observations"][Modality.GEOMAP.value] = ret["observations"][ + Modality.GEOMAP.value + ] + + return f diff --git a/embodichain/data/data_engine/data_dict_extractor.py b/embodichain/data/data_engine/data_dict_extractor.py new file mode 100644 index 0000000..e82f964 --- /dev/null +++ b/embodichain/data/data_engine/data_dict_extractor.py @@ -0,0 +1,801 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_warning, log_error + +try: + import h5ffmpeg as hf + + has_h5ffmpeg = True +except Exception as e: + has_h5ffmpeg = False + log_warning("Fail to import h5ffmpeg.") + +import h5py +import os +import random +import torch +import numpy as np + +from functools import cached_property +from typing import Dict, Any, List, Union, Optional +from embodichain.data.enum import ( + HandQposNormalizer, + Modality, + PrivilegeType, + JointType, + ActionMode, + EefType, + EndEffector, + ControlParts, + ArmName, +) +from embodichain.data.global_mapping import GlobalMapping +from embodichain.lab.sim.sensors import StereoCamera +from embodichain.lab.sim.objects import Robot +from embodichain.lab.gym.envs import BaseEnv, EmbodiedEnv +from embodichain.lab.gym.utils.gym_utils import map_qpos_to_eef_pose +from embodichain.utils.utility import get_right_name +from embodichain.lab.gym.robots.interface import LearnableRobot +from embodichain.lab.gym.utils.misc import is_binocularcam, _data_key_to_control_part +from embodichain.utils import logger +from embodichain.data.data_engine.indices_unifier import ( + StateUnifier, +) +from embodichain.data.data_engine.compressed_hdf5 import CompressedVideoHDF5 +from embodichain.data.enum import ( + SUPPORTED_PROPRIO_TYPES, + SUPPORTED_ACTION_TYPES, + SUPPORTED_EXTRA_VISION_TYPES, +) +from copy import deepcopy +from embodichain.lab.gym.envs.action_bank.configurable_action import ( + get_control_part_joint_ids, +) + +DATA_FORMATS = { + "observations": { + Modality.IMAGES.value: {}, + Modality.GEOMAP.value: {}, + PrivilegeType.MASK.value: {}, + PrivilegeType.EXTEROCEPTION.value: {}, + Modality.STATES.value: {}, + }, + Modality.ACTIONS.value: {}, +} + + +class ActStateStatistic: + def __init__(self, data_dict: Dict, min_len_steps: int) -> None: + self.data_dict = data_dict + self.min_len_steps = min_len_steps + + def prepare_state_and_action( + self, + ): + proprio = self.data_dict["observations"][Modality.STATES.value][:] + num_steps = proprio.shape[0] + # [Optional] We drop too-short episode + if num_steps < self.min_len_steps: + return False, None + # [Optional] We skip the first few still steps + EPS = 1e-2 + # Get the idx of the first qpos whose delta exceeds the threshold + proprio_delta = np.abs(proprio - proprio[0:1]) + indices = np.where(np.any(proprio_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + target_actions = self.data_dict[Modality.ACTIONS.value][:] + # Parse the state and action + state = proprio[first_idx - 1 :] + action = target_actions[first_idx - 1 :] + # Return the resulting sample + + return True, {Modality.STATES.value: state, Modality.ACTIONS.value: action} + + def statistic( + self, + ) -> Dict: + EPS = 1e-8 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + _, episode = self.prepare_state_and_action() + episode_cnt += 1 + + states = episode[Modality.STATES.value] + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": np.sqrt( + np.maximum( + (z_state_sum_sq / nz_state_cnt) + - (z_state_sum / state_cnt) ** 2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + ) + ).tolist(), + "state_min": state_min.tolist(), + "state_max": state_max.tolist(), + } + + return result + + +class DataDictExtractor: + def __init__( + self, + env: Union[BaseEnv, EmbodiedEnv], + save_path: str = None, + compression_opts: int = 9, + ): + self.env = env + self.save_path = save_path + self.data = {} + self.filtered_action_types = [] + self.filtered_proprio_types = [] + + # First check if control_parts exists and is non-empty. Only filter if valid, else use the original types. + control_parts = self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", None + ) + if control_parts and len(control_parts) > 0: + control_parts_set = set(control_parts) + self.filtered_proprio_types = [ + proprio_name + for proprio_name in SUPPORTED_PROPRIO_TYPES + if any(part in proprio_name for part in control_parts_set) + ] + self.filtered_action_types = [ + action_name + for action_name in SUPPORTED_ACTION_TYPES + if any(part in action_name for part in control_parts_set) + ] + + if ( + len(self.filtered_proprio_types) == 0 + or len(self.filtered_action_types) == 0 + ): + log_warning( + "No control parts found in the robot metadata. Using all supported proprio and action types." + ) + self.filtered_proprio_types = list(SUPPORTED_PROPRIO_TYPES) + self.filtered_action_types = list(SUPPORTED_ACTION_TYPES) + + # save all supported proprio and action types. + robot_meta_config = deepcopy(self.env.metadata["dataset"]["robot_meta"]) + robot_meta_config["observation"][ + Modality.STATES.value + ] = self.filtered_proprio_types + robot_meta_config[Modality.ACTIONS.value] = self.filtered_action_types + + self.state_unifier = StateUnifier(robot_meta=robot_meta_config) + self.compression_opts = compression_opts + + @cached_property + def robot_control_parts(self) -> List[str]: + """Get the robot's control parts. + + Note: + If control_parts is specified in the robot metadata, return those parts. + Otherwise, return all control parts. + + Returns: + List[str]: The robot's control parts. + """ + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + control_parts = robot_meta_config.get("control_parts", None) + if control_parts is None: + log_warning( + "Please make sure you have configurated the control parts. This branch may cause underlying error for training data." + ) + return [] + else: + return control_parts + + def _get_arm_control_parts(self) -> List[str]: + control_parts = self.robot_control_parts + arm_control_parts = [] + for part in control_parts: + if "arm" in part: + arm_control_parts.append(part) + return arm_control_parts + + def _has_exteroception(self) -> bool: + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + return PrivilegeType.EXTEROCEPTION.value in robot_meta_config["observation"] + + def extract( + self, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + data_dict: Dict = DATA_FORMATS, + save: bool = True, + ): + if save: + assert ( + self.save_path is not None + ), "Please provide a save path for the dataset." + data_dict = deepcopy(data_dict) + + self._init_data(data_dict) + + ret = {} + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + + if isinstance(self.env, BaseEnv): + for i, (obs, action) in enumerate(zip(obs_list, action_list)): + self._extract_vision_obs(obs, data_dict) + self._extract_proprioception(obs, data_dict) + self._extract_action(action, data_dict) + action = self._collate_action(data_dict) + proprio = self._collate_proprio(data_dict) + else: + for i, (obs, action) in enumerate(zip(obs_list, action_list)): + self._extract_vision_obs_v2(obs, data_dict) + self._extract_proprioception_v2(obs, data_dict) + self._extract_action_v2(action, data_dict) + action = self._collate_action(data_dict) + proprio = self._collate_proprio(data_dict) + + robot_meta = self._collate_metainfo() + + extra_vision_config = robot_meta_config["observation"]["vision"] + obs = {"observations": {}} + images = self.collate_sub_anns( + data_dict, extra_vision_config, Modality.IMAGES.value + ) + obs["observations"].update(proprio) + obs["observations"].update(images) + + extra_vision_names = list( + set([name for list in extra_vision_config.values() for name in list]) + ) + for extra_vision_name in extra_vision_names: + extra_vision_obs = self.collate_sub_anns( + data_dict, extra_vision_config, extra_vision_name + ) + obs["observations"].update(extra_vision_obs) + + ret.update(robot_meta) + ret.update(obs) + ret.update(action) + + statistics = ActStateStatistic( + ret, self.env.metadata["dataset"]["robot_meta"]["min_len_steps"] + ).statistic() + ret.update(statistics) + + if save: + if has_h5ffmpeg: + cvhdf5 = CompressedVideoHDF5(self.save_path) + all_video_names = [Modality.IMAGES.value] + [ + name + for name in extra_vision_names + if name != PrivilegeType.EXTEROCEPTION.value + ] + all_dtypes = [ + np.uint16 if name == Modality.GEOMAP.value else np.uint8 + for name in all_video_names + ] + cvhdf5.dump(ret, video_names=all_video_names, dtypes=all_dtypes) + else: + logger.log_info( + "h5ffmpeg is not installed, saving dataset without compression." + ) + import hdfdict + + # Open the file once and pass the file object to hdfdict.dump to + # avoid opening/truncating the same file path twice which causes + # "unable to truncate a file which is already open" errors on + # some platforms and HDF5 builds. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + + return ret + + def _init_data(self, data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for proprio_name in self.filtered_proprio_types: + data_dict["observations"][Modality.STATES.value][proprio_name] = [] + for action_name in self.filtered_action_types: + data_dict[Modality.ACTIONS.value][action_name] = [] + + for camera_name, extra_vision_list in extra_vision_config.items(): + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name] = [] + if is_stereo: + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ] = [] + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + data_dict["observations"][extra_vision_name][camera_name] = [] + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ] = [] + + def _extract_vision_obs(self, obs: Dict[str, Any], data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for camera_name, extra_vision_list in extra_vision_config.items(): + if camera_name in obs["sensor"]: + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name].append( + obs["sensor"][camera_name]["rgb"] + ) + if is_stereo: + # save rgb right + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ].append(obs["sensor"][camera_name]["rgb_right"]) + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + if extra_vision_name == PrivilegeType.EXTEROCEPTION.value: + if is_stereo: + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs[extra_vision_name][camera_name]["l"]) + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append(obs[extra_vision_name][camera_name]["r"]) + elif camera_name in obs.get(extra_vision_name, {}): + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs[extra_vision_name][camera_name]) + elif extra_vision_name == PrivilegeType.MASK.value: + # save semantic mask for monocular cameras + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["semantic_mask_l"].astype( + np.uint8 + ) + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name][ + "semantic_mask_r" + ].astype(np.uint8), + ) + elif extra_vision_name == Modality.GEOMAP.value: + if not is_stereo: + log_error( + f"Camera {camera_name} is not stereo, while '{extra_vision_name}' is in gym_config.dataset.robot_meta.vision, please check again." + ) + if "depth" in obs["sensor"][camera_name]: + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs["sensor"][camera_name]["depth"]) + else: + log_error( + f"obs['sensor'][{camera_name}] has no key named 'depth' while it's required in gym_config.dataset.robot_meta.vision, please check again." + ) + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + else: + logger.log_error( + f"Camera {camera_name} not found in observations, please check your sensor configuration in gym_config.json" + ) + + def _extract_vision_obs_v2(self, obs: Dict[str, Any], data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for camera_name, extra_vision_list in extra_vision_config.items(): + if camera_name in obs["sensor"]: + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name].append( + obs["sensor"][camera_name]["color"] + .squeeze(0)[:, :, :3] + .cpu() + .numpy() + ) + if is_stereo: + # save rgb right + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["color_right"] + .squeeze_(0)[:, :, :3] + .cpu() + .numpy() + ) + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + if extra_vision_name == PrivilegeType.EXTEROCEPTION.value: + if is_stereo: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name]["l"] + .cpu() + .numpy() + ) + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs[extra_vision_name][camera_name]["r"] + .cpu() + .numpy() + ) + elif camera_name in obs.get(extra_vision_name, {}): + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name].cpu().numpy() + ) + elif extra_vision_name == PrivilegeType.MASK.value: + # save semantic mask for monocular cameras + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["semantic_mask_l"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["semantic_mask_r"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + elif extra_vision_name == Modality.GEOMAP.value: + if not is_stereo: + log_error( + f"Camera {camera_name} is not stereo, while '{extra_vision_name}' is in gym_config.dataset.robot_meta.vision, please check again." + ) + if "depth" in obs["sensor"][camera_name]: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["depth"] + .squeeze_() + .numpy() + ) + else: + log_error( + f"obs['sensor'][{camera_name}] has no key named 'depth' while it's required in gym_config.dataset.robot_meta.vision, please check again." + ) + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + else: + logger.log_error( + f"Camera {camera_name} not found in observations, please check your sensor configuration in gym_config.json" + ) + + def _extract_action( + self, + action: Dict[str, Any], + data_dict: Dict, + ): + + agent: LearnableRobot = self.env.get_agent() + # extract qpos. + for key in data_dict[Modality.ACTIONS.value].keys(): + indices = agent.get_data_index(key, warning=False) + if len(indices) > 0: + action_data = action[JointType.QPOS.value][indices].copy() + action_data = HandQposNormalizer.normalize_hand_qpos( + action_data, key, agent=agent + ) + data_dict[Modality.ACTIONS.value][key].append(action_data) + qpos = action[JointType.QPOS.value] + action_eef_pose_dict = agent.map_env_qpos_to_eef_pose( + np.array([qpos]), to_dict=True + ) + for key, val in action_eef_pose_dict.items(): + data_dict[Modality.ACTIONS.value][key].append(val[0]) + + def _extract_action_v2( + self, + action: torch.Tensor, + data_dict: Dict, + ): + robot: Robot = self.env.robot + + for key in data_dict[Modality.ACTIONS.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = get_control_part_joint_ids(self.env, key) + qpos_data = ( + action[0, indices].cpu().numpy() + if isinstance(action, torch.Tensor) + else action[0, indices] + ) + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, part, robot=robot + ) + data_dict[Modality.ACTIONS.value][key].append(qpos_data) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, action, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict[Modality.ACTIONS.value][key].append( + val.squeeze_(0).cpu().numpy() + if isinstance(val, torch.Tensor) + else val.squeeze_(0) + ) + + def _extract_proprioception( + self, + obs: Dict[str, Any], + data_dict: Dict, + ): + agent: LearnableRobot = self.env.get_agent() + # extract qpos. + qpos = obs["agent"][agent.uid][JointType.QPOS.value] + for key in data_dict["observations"][Modality.STATES.value].keys(): + indices = agent.get_data_index(key, warning=False) + if len(indices) > 0: + qpos_data = qpos[ + indices + ].copy() # Deep copy to avoid modifying original data + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, key, agent=agent + ) + data_dict["observations"][Modality.STATES.value][key].append(qpos_data) + + eef_pose_dict: Dict = agent.map_env_qpos_to_eef_pose( + np.array([qpos]), to_dict=True + ) + for key, val in eef_pose_dict.items(): + data_dict["observations"][Modality.STATES.value][key].append(val[0]) + + def _extract_proprioception_v2( + self, + obs: Dict[str, Any], + data_dict: Dict, + ): + robot: Robot = self.env.robot + + qpos = obs["robot"][JointType.QPOS.value] + for key in data_dict["observations"][Modality.STATES.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = get_control_part_joint_ids(self.env, key) + qpos_data = qpos[0][indices].cpu().numpy() + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, part, robot=robot + ) + data_dict["observations"][Modality.STATES.value][key].append(qpos_data) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, qpos, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict["observations"][Modality.STATES.value][key].append( + val.squeeze_(0).cpu().numpy() + ) + + def _collate_proprio(self, data_dict: Dict) -> Dict: + proprio_dict = {} + for proprio_name in self.state_unifier.proprio_meta: + proprio = np.array( + data_dict["observations"][Modality.STATES.value][proprio_name] + ) + proprio_dict[proprio_name] = proprio + proprios = self.state_unifier.fill_in_state(proprio_dict) + return {Modality.STATES.value: proprios} + + def _collate_metainfo( + self, + ) -> Dict: + meta_info = { + "arm_dofs": self.env.metadata["dataset"]["robot_meta"].get("arm_dofs", 12), + "observation": self.env.metadata["dataset"]["robot_meta"].get( + "observation", {} + ), + "min_len_steps": self.env.metadata["dataset"]["robot_meta"].get( + "min_len_steps", 125 + ), + } + return { + "robot_meta": meta_info, + "instruction": { + "lang": self.env.metadata["dataset"]["instruction"].get("lang", "") + }, + } + + def _collate_action(self, data_dict: Dict) -> Dict: + action_data_dict = data_dict[Modality.ACTIONS.value] + for k, v in action_data_dict.items(): + action_data_dict[k] = np.array(v) + + action_dict = {} + action_dict.update(action_data_dict) + action = self.state_unifier.fill_in_action(action_dict) + return {Modality.ACTIONS.value: action} + + @staticmethod + def collate_sub_anns( + data_dict: Dict, + extra_vision_config: Dict, + key: str = Modality.IMAGES.value, + ) -> Dict: + ret = {key: {}} + for camera_name in extra_vision_config: + images_list = data_dict["observations"][key].pop(camera_name, None) + if images_list is None: + continue + if len(images_list) > 0: + ret[key][camera_name] = np.empty( + (len(images_list),) + images_list[0].shape, + dtype=images_list[0].dtype, + ) + for idx, image in enumerate(images_list): + ret[key][camera_name][idx] = image + else: + ret[key][camera_name] = np.array([]) + + del images_list + if get_right_name(camera_name) in data_dict["observations"][key]: + images_right_list = data_dict["observations"][key].pop( + get_right_name(camera_name), None + ) + if images_right_list is None: + continue + if len(images_right_list) > 0: + ret[key][get_right_name(camera_name)] = np.empty( + (len(images_right_list),) + images_right_list[0].shape, + dtype=images_right_list[0].dtype, + ) + for idx, image in enumerate(images_right_list): + ret[key][get_right_name(camera_name)][idx] = image + else: + ret[key][get_right_name(camera_name)] = np.array([]) + del images_right_list + + return ret + + +def fetch_imitation_dataset( + env: BaseEnv, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + id: str, + folder_name: str, + save: bool = True, +) -> Dict: + """ + Save imitation dataset for a single episode. + + Args: + env (BaseEnv): Environment instance. + obs_list (List[Dict]): List of observation dicts. + action_list (List[Dict]): List of action dicts. + id (str): Unique identifier for the episode. + folder_name (str): Folder name for saving the dataset. + + Returns: + dict: Contains data_path, id, current_episode, and extracted data. + """ + # Get dataset save path + dataset_path = env.metadata["dataset"].get("save_path", None) + if dataset_path is None: + from embodichain.data import database_demo_dir + + dataset_path = database_demo_dir + + # Create folder if first episode + dataset_save_path = os.path.join(dataset_path, folder_name) + if env.curr_episode == 0 and id: + os.makedirs(dataset_save_path, exist_ok=True) + + # Check robot dof validity + try: + if isinstance(env, BaseEnv): + agent: LearnableRobot = env.get_agent() + max_dofs = len(agent.get_data_index(agent.uid)) + assert ( + env.metadata["dataset"]["robot_meta"]["arm_dofs"] <= max_dofs + ), f"Control dof {env.metadata['dataset']['robot_meta']['arm_dofs']} must be less than {max_dofs}." + else: + robot: Robot = env.robot + assert ( + env.metadata["dataset"]["robot_meta"]["arm_dofs"] <= robot.dof + ), f"Control dof {env.metadata['dataset']['robot_meta']['arm_dofs']} must be less than {robot.dof}." + except Exception as e: + logger.log_error(f"Robot DOF check failed: {e}") + return None + + # Select data format + data_format = DATA_FORMATS + + # Extract and save data + if id is None: + ret = DataDictExtractor(env).extract( + obs_list, action_list, save=False, data_dict=data_format + ) + save_path = None + else: + save_path = os.path.join(dataset_save_path, id + ".hdf5") + logger.log_info(f"Save episode {env.curr_episode} to '{save_path}'") + ret = DataDictExtractor(env, save_path).extract( + obs_list, action_list, save=save, data_dict=data_format + ) + + # Update episode count + env.curr_episode += 1 + + # Return result dict + return { + "data_path": dataset_save_path, + "id": id, + "current_episode": env.curr_episode, + "data": ret, + "save_path": save_path, + } diff --git a/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py new file mode 100644 index 0000000..d4b3477 --- /dev/null +++ b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py @@ -0,0 +1,696 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import fnmatch +from embodichain.utils.logger import log_warning, log_info + +try: + import h5ffmpeg as hf +except Exception as e: + log_warning("Fail to import h5ffmpeg.") +import h5py +import numpy as np +from typing import Dict, Callable, List, Tuple +from embodichain.utils.utility import get_right_name, pad_to_chunk, convert_bytes +from embodichain.utils.logger import log_warning, log_info +from embodichain.data.enum import Proprioception, Image, Exteroception, ModalInput +from copy import deepcopy +from typing import Dict +from embodichain.utils.utility import timer +from embodichain.data.enum import ( + Modality, + PrivilegeType, + ArmEnum, + JointType, + CameraName, + TeleoperationData, + CobotMagicTeleoperationData, +) +from embodichain.data.data_engine.indices_unifier import ActionIndicesGenerator + +DEFAULT_ONLINE_DATASET_LEN = 10000 + + +class SimRealUnifiedDictDataset: + """Dataset class for unified simulation and real-world data. + + This class handles loading, parsing, and sampling from datasets that may be + either simulation or real-world HDF5 files. It supports both offline and online + data sources, and provides utilities for extracting and standardize state, + action, and image modalities. + + Args: + data_path (str): Path to the HDF5 dataset directory. + batch_size (int): Batch size for sampling. + chunk_size (int): Number of timesteps per sample. + state (List): List of state modalities to extract. + output (List): List of output modalities to extract. + data_meta (Dict): Metadata describing the dataset. + arm_type (ArmEnum): Type of robot arm. + img_history_size (int): Number of image frames in history. + state_history_len (int): Number of state frames in history. + precomp_lang_embed (bool, optional): Whether to use precomputed language embeddings. Defaults to True. + online_config (Dict, optional): Configuration for online data engine. Defaults to None. + camera_used (List[str], optional): List of camera names to use. Defaults to None. + indices_generator (ActionIndicesGenerator, optional): Generator for action/state indices. Defaults to None. + """ + + def __init__( + self, + data_path: str, + batch_size: int, + chunk_size: int, + state: List, + output: List, + data_meta: Dict, + arm_type: ArmEnum, + robot_name: str, + img_history_size: int, + state_history_len: int, + precomp_lang_embed: bool = True, + online_engine: Dict = None, + camera_used: List[str] = None, + indices_generator=None, + ) -> None: + """Initialize the SimRealUnifiedDictDataset.""" + # [Modify] The path to the HDF5 dataset directory + # Each HDF5 file contains one episode + + self.precomp_lang_embed = precomp_lang_embed + self.batch_size = batch_size + self.chunk_size = chunk_size + self.state = state + self.output = output + self.data_meta = data_meta + self.arm_type = arm_type + self.robot_name = robot_name + self.img_history_size = img_history_size + self.state_history_len = state_history_len + self.engine = online_engine + self.camera_used = camera_used + self.indices_generator = indices_generator + if self.camera_used is not None: + for cam in CameraName: + if cam.value not in camera_used: + log_warning( + "{} does not exist in {}".format(cam.value, camera_used) + ) + + if self.engine is not None: + self.DATASET_NAME = "online_whatever" + else: + log_info("Init offline vla dataset.", color="purple") + self.data_path = data_path + assert os.path.exists(self.data_path), "{} does not exist.".format( + self.data_path + ) + if os.path.isabs(self.data_path) is False: + self.data_path = os.path.join(os.getcwd(), self.data_path) + self.DATASET_NAME = os.path.basename(self.data_path) + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info( + f"Init dataset with size of: {len(self.file_paths)}", color="purple" + ) + + def update_data_size(self): + """Update the dataset size for validation datasets generated on the fly.""" + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info(f"Update dataset with size of: {len(self.file_paths)}", color="purple") + + def __len__(self): + """Return the number of episodes in the dataset. + + Returns: + int: Number of episodes. + """ + return ( + len(self.file_paths) if self.engine is None else DEFAULT_ONLINE_DATASET_LEN + ) + + def get_item(self, index: int = None, chunk_size: int = None): + """Get a training sample at a random timestep. + + Args: + index (int, optional): The index of the episode. If not provided, a random episode will be selected. + chunk_size (int, optional): Number of timesteps per sample. Defaults to self.chunk_size. + + Returns: + dict: A dictionary containing the training sample. + """ + chunk_size = self.chunk_size if chunk_size is None else chunk_size + while True: + if self.engine is None: + # offline + if index is None: + file_path = np.random.choice(self.file_paths) + else: + file_path = self.file_paths[index] + valid, sample = self.parse_hdf5_file(file_path, chunk_size) + else: + data_dict = self.engine.sample_data() + valid, sample = self.parse_dict(data_dict, chunk_size) + + if valid: + return sample + else: + if self.engine is None: + index = np.random.randint(0, len(self.file_paths)) + + @staticmethod + def parse_exteroception( + file: Dict, + step_id: int, + chunk_size: int, + camera_used: List[str] = [], + ) -> Exteroception: + """Parse exteroception data from the file. + + Args: + file (Dict): Data dictionary. + step_id (int): Starting timestep index. + chunk_size (int): Number of timesteps to extract. + camera_used (List[str], optional): List of cameras to use. + + Returns: + Exteroception: Parsed exteroception data. + """ + exteroception = [] + for cam in camera_used: + exteroception_full = file["observations"][ + PrivilegeType.EXTEROCEPTION.value + ][cam] + exteroception.append(exteroception_full[step_id : step_id + chunk_size]) + + exteroception = np.concatenate(exteroception, 1) + _, cs, kn, _ = exteroception.shape + exteroception = pad_to_chunk(exteroception, chunk_size) + return Exteroception( + data=exteroception.reshape(chunk_size, cs, kn, 2).transpose( + 1, 0, 2, 3 + ) # cs, chunk_size, kn, 2 + ) + + @staticmethod + def parse_img( + file: Dict, + step_id: int, + first_idx: int, + cam: str, + chunk_size: int, + key: str = Modality.IMAGES.value, + camera_used: List[str] = [], + np_ops: Callable = lambda x: x, + ) -> Image: + """Parse image data for a given camera. + + Args: + file (Dict): Data dictionary. + step_id (int): Current timestep index. + first_idx (int): First index for history. + cam (str): Camera name. + chunk_size (int): Number of timesteps to extract. + key (str, optional): Key for image modality. Defaults to Modality.IMAGES.value. + camera_used (List[str], optional): List of cameras to use. + np_ops (Callable, optional): Numpy operation to apply to images. + + Returns: + Image: Parsed image data. + """ + valid_len = min(step_id - (first_idx - 1) + 1, chunk_size) + cam_mask = np.array([False] * (chunk_size - valid_len) + [True] * valid_len) + if cam in camera_used: + temp = file["observations"][key][cam][0] + imgs = np.zeros((valid_len,) + temp.shape, dtype=temp.dtype) + for t, i in enumerate(range(max(step_id - chunk_size + 1, 0), step_id + 1)): + img = file["observations"][key][cam][i] + imgs[t] = img + imgs = np_ops(imgs) + imgs = pad_to_chunk(imgs, chunk_size=chunk_size) + mask = cam_mask.copy() + else: + imgs = np.zeros((chunk_size, 0, 0, 0)) + mask = np.zeros((chunk_size,), dtype=bool) + return Image(data=imgs, mask=mask, name=cam) + + def parse_hdf5_file(self, file_path, chunk_size: int) -> Dict[str, ModalInput]: + """Parse an HDF5 file and extract modalities. + + Args: + file_path (str): Path to the HDF5 file. + chunk_size (int): Number of timesteps to extract. + + Returns: + dict: Parsed modalities. + """ + import hdfdict + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + + with h5py.File(file_path, "r") as f: + data = hdfdict.load(f) + keyname = ( + JointType.QPOS.value + if SimRealUnifiedDictDataset.is_real_datasets(data) + else Modality.STATES.value + ) + step_id = SimRealUnifiedDictDataset.random_step_id( + data, chunk_size, keyname + ) + if not SimRealUnifiedDictDataset.is_real_datasets(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + else: + # Real data: if compressed structure is detected (containing *_index/*_start), also perform decoding filtering + try: + if CompressedVideoHDF5.is_compressed_hdf5(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + except Exception: + pass + ret = self.parse_dict(data, chunk_size, step_id) + + return ret + + @staticmethod + def random_step_id( + f: Dict, chunk_size: int, key: str = Modality.STATES.value + ) -> int: + """Randomly sample a timestep index. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + key (str, optional): Key for state modality. + + Returns: + int: Randomly selected timestep index. + """ + obs = f["observations"] + proprio = obs[key][:] + num_steps = proprio.shape[0] + # We randomly sample a timestep + first_idx = 1 + step_id = np.random.randint( + first_idx, np.maximum(first_idx + 1, num_steps - 1 - chunk_size) + ) + return step_id + + @staticmethod + def is_real_datasets(f: Dict): + """Check if the dataset is a real-world dataset. + + Args: + f (Dict): Data dictionary. + + Returns: + bool: True if real-world dataset, False if simulation. + """ + return "robot_meta" not in f.keys() + + def parse_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + if not SimRealUnifiedDictDataset.is_real_datasets(f): + log_warning("Using simulation hdf5 datasets.") + return self.parse_sim_dict(f, chunk_size, step_id) + else: + log_warning("Using real world offline hdf5 datasets.") + return self.parse_real_dict(f, chunk_size, step_id) + + def parse_real_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a real-world data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + ( + actions, + proprio, + meta, + camera_used_from_dualsys_to_real, + camera_used, + ) = RobotRealDataRouter(robot_name=self.robot_name).realdata2simdata( + f, + chunk_size, + given_camera_used=self.camera_used, + dataset_name=self.DATASET_NAME, + step_id=step_id, + ) + first_idx = 1 + parse_dict = self.parse_core(proprio, actions, step_id, chunk_size) + parse_dict.update({"meta": meta}) + for cam in self.camera_used: + if cam in camera_used: + parse_dict[cam] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + camera_used_from_dualsys_to_real[cam], + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used_from_dualsys_to_real[cam], + ) + else: + raise ValueError( + "cam name {} is not in all cam names {} for this datasets.".format( + cam, camera_used + ) + ) + return True, parse_dict + + @timer + def parse_sim_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a simulation data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + if step_id is None: + step_id = SimRealUnifiedDictDataset.random_step_id(f, chunk_size) + + obs = f["observations"] + metadata = dict(f["robot_meta"]) + first_idx = 1 + + proprio = obs[Modality.STATES.value][:] + num_steps = proprio.shape[0] + min_len_step = metadata["min_len_steps"] + # [Optional] We drop too-short episode + if num_steps < min_len_step: + return False, None + + # We randomly sample a timestep + + camera_used = ( + convert_bytes(list(metadata["observation"]["vision"].keys())) + if self.camera_used is None + else self.camera_used + ) + + # Assemble the meta + meta = { + "dataset_name": self.DATASET_NAME, + "#steps": num_steps, + "step_id": step_id, + "instruction": "", + "camera_used": camera_used, + "instruction": f["language_prompt"] + if f.get("language_prompt", None) + else "", + } + + assert ( + self.indices_generator.dof == metadata["arm_dofs"] + ), "Train dof {} but dataset dof {}.".format( + self.indices_generator.dof, metadata["arm_dofs"] + ) + parse_dict = self.parse_core( + proprio, f[Modality.ACTIONS.value], step_id, chunk_size + ) + parse_dict.update({"meta": meta}) + + for cam in camera_used: + cam_r = get_right_name(cam) + if cam_r in obs[Modality.IMAGES.value] and cam_r not in camera_used: + # insert camera name after cam + camera_used.insert(camera_used.index(cam) + 1, cam_r) + + for cam in camera_used: + parse_dict[cam] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used, + ) + if PrivilegeType.MASK.value in self.data_meta.get("privileges", []): + parse_dict[ + cam + "_{}".format(PrivilegeType.MASK.value) + ] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + PrivilegeType.MASK.value, + camera_used=camera_used, + ) + if PrivilegeType.EXTEROCEPTION.value in self.data_meta.get("privileges", []): + if obs[PrivilegeType.EXTEROCEPTION.value][camera_used[0]].shape[0] != 0: + parse_dict[ + PrivilegeType.EXTEROCEPTION.value + ] = SimRealUnifiedDictDataset.parse_exteroception( + f, + step_id, + chunk_size, + camera_used=camera_used, + ) + + if Modality.GEOMAP.value in self.data_meta.get("additional_modality", []): + if ( + hasattr(obs[Modality.GEOMAP.value][camera_used[0]], "shape") + and obs[Modality.GEOMAP.value][camera_used[0]].shape[0] != 0 + ): + parse_dict[Modality.GEOMAP.value] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + CameraName.HEAD.value, + self.img_history_size, + Modality.GEOMAP.value, + camera_used=camera_used, + np_ops=lambda x: np.tile(np.expand_dims(x, -1), [1, 1, 1, 3]), + ) + + # Return the resulting sample + # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0) + # E.g., return np.zeros((self.img_history_size, 0, 0, 0)) for the key "cam_left_wrist", + # if the left-wrist camera is unavailable on your robot + return True, parse_dict + + def parse_core( + self, proprio: np.ndarray, actions: np.ndarray, step_id: int, chunk_size: int + ): + """Parse and normalize state and action data. + + Args: + proprio (np.ndarray): Proprioceptive state data. + actions (np.ndarray): Action data. + step_id (int): Current timestep index. + chunk_size (int): Number of timesteps to extract. + + Returns: + dict: Dictionary containing normalized state, action, and statistics. + """ + # Parse the state and action + state = proprio[np.maximum(step_id - self.state_history_len, 0) : step_id] + state = np.concatenate( + [np.tile(state[0:1], [self.state_history_len - state.shape[0], 1]), state], + 0, + ) + self.indices_generator: ActionIndicesGenerator + global_mapping = self.indices_generator.global_mapping + state_indices = global_mapping.get_indices( + convert_bytes(self.state), + ) + state_indicator = np.zeros_like(state, dtype=np.int8) + state_indicator[:, state_indices] = 1 + state *= state_indicator + proprio *= state_indicator[0:1] + state_std = np.std(proprio, axis=0) + state_mean = np.mean(proprio, axis=0) + state_norm = np.sqrt(np.mean(proprio**2, axis=0)) + action_indices = self.indices_generator.get( + self.output, + ) + actions = deepcopy(actions[step_id : step_id + chunk_size]) + delta_qpos_indices = self.indices_generator.get_all_delta_qpos( + handness=self.arm_type + ) + qpos_indices = self.indices_generator.get_all_qpos(handness=self.arm_type) + # NOTE: Ops `cumsum` equal to action[:horizon]-action[0:1]. + # TODO: action = action_chunk - current_obs. + actions[:, delta_qpos_indices] = ( + actions[:, qpos_indices] - state[-1:, qpos_indices] + ) + actions = pad_to_chunk(actions, chunk_size=chunk_size) + + action_indicator = np.zeros_like(actions, dtype=np.int8) + action_indicator[:, action_indices] = 1 + actions *= action_indicator[0:1] + + parse_dict = { + "state_std": state_std, + "state_mean": state_mean, + "state_norm": state_norm, + Modality.STATES.value: Proprioception(data=state, mask=state_indicator), + Modality.ACTIONS.value: Proprioception(data=actions, mask=action_indicator), + PrivilegeType.PROGRESS.value: step_id / proprio.shape[0], + } + return parse_dict + + +class RobotRealDataRouter: + def __init__(self, robot_name: str): + from embodichain.data.enum import ( + ControlParts, + EndEffector, + JointType, + ) + + assert robot_name in [ + "CobotMagic", + "DexforceW1", + ], "Robot type {} not supported.".format(robot_name) + self.robot_name = robot_name + + if robot_name == "CobotMagic": + self._REAL_SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, + ] + self.qpos_index_dict = { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: CobotMagicTeleoperationData.LEFT_ARM_QPOS_INDICES.value, + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: CobotMagicTeleoperationData.RIGHT_ARM_QPOS_INDICES.value, + ControlParts.LEFT_EEF.value + + EndEffector.GRIPPER.value: CobotMagicTeleoperationData.LEFT_EEF_GRIPPER_INDICES.value, + ControlParts.RIGHT_EEF.value + + EndEffector.GRIPPER.value: CobotMagicTeleoperationData.RIGHT_EEF_GRIPPER_INDICES.value, + } + self.arm_dofs = 12 + self.camera_used_from_real_to_dualsys = { + CameraName.LEFT_WRIST.value: CameraName.LEFT_WRIST.value, + CameraName.RIGHT_WRIST.value: CameraName.RIGHT_WRIST.value, + CameraName.HEAD.value: CameraName.HEAD.value, + } + elif robot_name == "DexforceW1": + self._REAL_SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.HEAD.value + JointType.QPOS.value, + ControlParts.WAIST.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ] + self.qpos_index_dict = { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: TeleoperationData.LEFT_ARM_QPOS_INDICES.value, + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: TeleoperationData.RIGHT_ARM_QPOS_INDICES.value, + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.LEFT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.RIGHT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.HEAD.value + + JointType.QPOS.value: TeleoperationData.HEAD_QPOS_INDICES.value, + ControlParts.WAIST.value + + JointType.QPOS.value: TeleoperationData.WAIST_QPOS_INDICES.value, + } + self.arm_dofs = 14 + self.camera_used_from_real_to_dualsys = { + "cam_hand_left": CameraName.LEFT_WRIST.value, + "cam_hand_right": CameraName.RIGHT_WRIST.value, + "cam_high_left": CameraName.HEAD.value, + } + + def realdata2simdata( + self, + f: Dict, + chunk_size: int = -1, + given_camera_used: List[str] = [], + dataset_name: str = "", + step_id: int = None, + ): + + from embodichain.data.data_engine.indices_unifier import ( + StateUnifier, + ) + + if step_id is None: + step_id = VLADataset.random_step_id(f, chunk_size, "qpos") + obs = f["observations"] + proprio = obs["qpos"][:] + num_steps = proprio.shape[0] + camera_used_in_real = list(obs[Modality.IMAGES.value].keys()) + camera_used_from_dualsys_to_real = { + val: key for key, val in self.camera_used_from_real_to_dualsys.items() + } + # Now assume it is from W1. + camera_used = [ + self.camera_used_from_real_to_dualsys[cam] + for cam in camera_used_in_real + if cam in self.camera_used_from_real_to_dualsys + ] + + # Assemble the meta + meta = { + "dataset_name": dataset_name, + "#steps": num_steps, + "step_id": step_id, + "camera_used": [ + cam_name for cam_name in given_camera_used if cam_name in camera_used + ], + "instruction": f["language_prompt"] + if f.get("language_prompt", None) + else "", + } + # save all supported proprio and action types. + robot_meta_config = {"arm_dofs": self.arm_dofs, "observation": {}} + + robot_meta_config["observation"][ + Modality.STATES.value + ] = self._REAL_SUPPORTED_PROPRIO_TYPES + robot_meta_config[Modality.ACTIONS.value] = self._REAL_SUPPORTED_PROPRIO_TYPES + state_unifier = StateUnifier(robot_meta=robot_meta_config) + + qpos_dict = {} + for key, indices in self.qpos_index_dict.items(): + qpos_dict[key] = proprio[:, indices] + actions = state_unifier.fill_in_action(qpos_dict) + proprio = state_unifier.fill_in_state(qpos_dict) + return actions, proprio, meta, camera_used_from_dualsys_to_real, camera_used diff --git a/embodichain/data/data_engine/indices_unifier.py b/embodichain/data/data_engine/indices_unifier.py new file mode 100644 index 0000000..3cfd025 --- /dev/null +++ b/embodichain/data/data_engine/indices_unifier.py @@ -0,0 +1,395 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.data.global_indices import ( + GLOBAL_INDICES, + STATE_VEC_LEN, +) +from embodichain.data.global_mapping import GlobalMapping +import numpy as np +from typing import List, Dict, Tuple, Union +from embodichain.data.enum import ( + ArmEnum, + Modality, + JointType, + ActionMode, + EefType, + ControlParts, + EndEffector, + Modality, +) +from embodichain.utils.logger import log_info, log_warning + +DEFAULT_EMPTY_STATE = -1 + +__all__ = ["StateUnifier", "ActionIndicesGenerator"] + +"""Unified state utilities for EmbodiChain. + +This module provides helpers to construct and query a unified state/action +vector representation used across EmbodiChain environments and agents. + +Classes: + StateUnifier: Fill sparse per-modality state/action dictionaries into a + fixed-length unified state vector where unspecified entries are set + to a sentinel value (DEFAULT_EMPTY_STATE). + + ActionIndicesGenerator: Query index ranges in the unified vector for + common action/state groups (e.g. qpos, delta qpos, end-effector pose). + +Constants: + DEFAULT_EMPTY_STATE (int): Sentinel value used to mark unspecified + entries in the unified vector. +""" + + +class StateUnifier: + """Convert per-modality state/action arrays into a unified vector. + + The StateUnifier is constructed with ``robot_meta`` (the robot's + metadata) which should contain an ``observation`` mapping with keys for + modalities (e.g. ``Modality.STATES``) and an ``actions`` specification. + + Attributes: + metadata (dict): Robot metadata passed at construction. + arm_dofs (int): Degrees of freedom for the arm (default: 12). + indices_generator (ActionIndicesGenerator): Helper for action indices. + proprio_meta: Metadata list for proprioceptive modalities. + global_mapping (GlobalMapping): Mapping from names to unified indices. + output: Action output specification from metadata. + state_dim (int): Fixed length of the unified state vector. + """ + + def __init__(self, robot_meta: Dict) -> None: + assert "arm_dofs" in robot_meta + assert "observation" in robot_meta + assert Modality.ACTIONS.value in robot_meta + + self.arm_dofs = robot_meta["arm_dofs"] + self.indices_generator = ActionIndicesGenerator(self.arm_dofs) + self.proprio_meta = robot_meta["observation"][Modality.STATES.value] + self.global_mapping = GlobalMapping(self.arm_dofs) + self.output = robot_meta[Modality.ACTIONS.value] + + self.state_dim = STATE_VEC_LEN + + def fill_in_state( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified state vector from given values. + + Args: + values (np.ndarray or dict): If ``values`` is a numpy array it is + assumed to already be aligned to the unified layout and will + be placed into the output container. If it is a ``dict``, + keys should match entries from the robot metadata + ``observation[Modality.STATES]`` and values are numpy arrays + with a trailing dimension matching each state's width. + + Returns: + np.ndarray: An array with shape ``(..., STATE_VEC_LEN)`` containing + the unified state with unspecified entries set to + ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.global_mapping.get_indices(self.proprio_meta) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for val in values.values(): + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for state_name in self.proprio_meta: + state_indices = self.global_mapping.get_indices([state_name]) + if values[state_name].size != 0: + uni_vec[..., state_indices] = values[state_name] + + return uni_vec + + def fill_in_action( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified action vector from given action values. + + This mirrors :meth:`fill_in_state` but uses the metadata's action + output specification to determine which named outputs map into the + unified vector. + + Args: + values (np.ndarray or dict): Action values aligned to the unified + layout or a mapping from output names to numpy arrays. + + Returns: + np.ndarray: Unified vector shaped ``(..., STATE_VEC_LEN)`` with + unspecified entries filled with ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.indices_generator.get(self.output) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for key, val in values.items(): + + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for out_name in self.output: + state_indices = self.global_mapping.get_indices([out_name]) + if out_name in values and values[out_name].size != 0: + uni_vec[..., state_indices] = values[out_name] + return uni_vec + + +class ActionIndicesGenerator: + """Utility for generating index lists for action/state groups. + + The ActionIndicesGenerator wraps :class:`GlobalMapping` to provide + common queries like retrieving indices for all joint positions (qpos), + delta qpos (relative mode), end-effector transforms/poses, and + hand-specific selections (left/right/both). + + Args: + dof (int, optional): If provided, a :class:`GlobalMapping` is + constructed and reused for queries. + """ + + def __init__(self, dof: int = None): + self.global_mapping = None + self.dof = dof + if dof is not None: + self.global_mapping = GlobalMapping(dof) + + def get_all_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices covering all joint position entries. + + Args: + dof (int, optional): Degrees of freedom to construct a temporary + :class:`GlobalMapping` if the generator was not initialized + with a ``dof``. + handness (str): One of values from :class:`ArmEnum` specifying + which arm(s) to include. + + Returns: + List[int]: Ordered list of indices in the unified vector + corresponding to qpos entries for the requested arm + selection. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [qpos_name], [delta_qpos_name]) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + inv_handness = ControlParts.RIGHT_ARM.value + return self.get( + all_names, dof, [qpos_name], [delta_qpos_name, inv_handness + qpos_name] + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + inv_handness = ControlParts.LEFT_ARM.value + return self.get( + all_names, dof, [qpos_name], [delta_qpos_name, inv_handness + qpos_name] + ) + + def get_all_delta_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices for delta (relative) joint position entries. + + Args and return are the same as :meth:`get_all_qpos` but select the + ``ActionMode.RELATIVE`` named entries. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [delta_qpos_name], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + inv_handness = ControlParts.RIGHT_ARM.value + return self.get( + all_names, dof, [delta_qpos_name], [inv_handness + delta_qpos_name] + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + inv_handness = ControlParts.LEFT_ARM.value + return self.get( + all_names, dof, [delta_qpos_name], [inv_handness + delta_qpos_name] + ) + + def get_all_eef( + self, + dof: int = None, + eef_effector: str = "", + handness: str = ArmEnum.DUAL_ARM.value, + ) -> List[int]: + """Retrieves the indices of all end-effectors (EEF) based on the specified parameters. + + Args: + dof (int, optional): Degree of freedom to use for mapping. If None, uses default. + eef_effector (str, optional): Type of end-effector. Must be one of + EndEffector.DEXTROUSHAND.value, EndEffector.GRIPPER.value, or "" (empty string). + handness (str, optional): Specifies which arm(s) to consider. Must be one of + ArmEnum.DUAL_ARM.value, ArmEnum.LEFT_ARM_ONLY.value, or ArmEnum.RIGHT_ARM_ONLY.value. + + Returns: + List[int]: List of indices corresponding to the selected end-effectors. + + Raises: + AssertionError: If an invalid end-effector type is provided. + """ + assert eef_effector in [ + EndEffector.DEXTROUSHAND.value, + EndEffector.GRIPPER.value, + "", + ], "Invalid end-effector effector type {}.".format(eef_effector) + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get( + all_names, + dof, + [ + ControlParts.LEFT_EEF.value + eef_effector, + ControlParts.RIGHT_EEF.value + eef_effector, + ], + [], + ) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_EEF.value + return self.get( + all_names, + dof, + [handness + eef_effector], + [], + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_EEF.value + return self.get( + all_names, + dof, + [handness + eef_effector], + [], + ) + + def get_all_eef_pose( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices specifically for EEF pose entries. + + Args: + dof (int, optional): Degrees of freedom for mapping lookup. + handness (str): Which arm(s) to include (left/right/both). + + Returns: + List[int]: Indices corresponding to EEF poses. + """ + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [EefType.POSE.value], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + + def get_mapping(self, dof: int = None): + """Return the :class:`GlobalMapping` used by this generator. + + If a mapping was created during initialization (because ``dof`` was + provided), ensure any provided ``dof`` argument matches it. Otherwise + construct and return a temporary :class:`GlobalMapping` for the + requested ``dof``. + + Args: + dof (int, optional): Degrees of freedom to construct a mapping + if one was not provided at initialization. + + Returns: + GlobalMapping: Mapping instance for name->index lookups. + """ + if self.global_mapping is not None: + assert dof is None or dof == self.dof + global_mapping = self.global_mapping + else: + assert ( + dof is not None + ), "Dof must be set when dof is not provided in initialization." + global_mapping = GlobalMapping(dof) + return global_mapping + + def get( + self, + output: List[str], + dof: int = None, + white_list: List[str] = None, + black_list: List[str] = None, + ) -> List[int]: + """Select and return indices from ``output`` names applying optional + white/black list filters. + + Args: + output (List[str]): Names (keys) in a :class:`GlobalMapping` + whose indices should be collected. + dof (int, optional): Degrees of freedom used to construct a + temporary :class:`GlobalMapping` if needed. + white_list (List[str], optional): If provided, only include names + that contain any of these substrings. + black_list (List[str], optional): If provided, exclude names + that contain any of these substrings. + + Returns: + List[int]: Ordered list of unified-vector indices for the + selected names. + """ + + action_indices = [] + global_mapping = self.get_mapping(dof) + + for action_type in output: + if isinstance(white_list, list) and isinstance(black_list, list): + if any([temp in action_type for temp in white_list]) and all( + [temp not in action_type for temp in black_list] + ): + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + else: + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + + return action_indices # keep order. diff --git a/embodichain/data/data_engine/online/engine.py b/embodichain/data/data_engine/online/engine.py new file mode 100644 index 0000000..fbdba4b --- /dev/null +++ b/embodichain/data/data_engine/online/engine.py @@ -0,0 +1,513 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import time +import sys +import numpy as np +from threading import Thread +from typing import Dict, Tuple, Any, List, Callable, Optional +from copy import deepcopy +import threading +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +import torch +import torch.multiprocessing as mp +import copy +from embodichain.utils.logger import ( + log_info, + log_warning, + decorate_str_color, + log_debug, +) + +# Must call cuda init to prevent cuda error in subprocess. +torch._C._cuda_init() + +from dexsim.utility import NumpyRNG + +import threading +from multiprocessing import shared_memory +import pickle +from datetime import datetime +import zmq + +rng = NumpyRNG.get_rng() + +log_info_produce = lambda x: log_info(decorate_str_color(x, "cyan")) +log_info_consume = lambda x: log_info(decorate_str_color(x, "orange")) + +MAX_LOOP_TIMES = 40000 + + +def init_context(port): + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://localhost:{}".format(port)) + return socket + + +class DataPoolCont: + data: Any + count: int = 0 + tag: str + + @staticmethod + def from_list(data_pool: List[Dict]) -> List["DataPoolCont"]: + ret = [] + for data in data_pool: + dcnt = DataPoolCont() + dcnt.data = data + dcnt.count = 0 + dcnt.tag = str(datetime.now()).split(".")[0] + ret.append(dcnt) + return ret + + @staticmethod + def clean_data_pool_in_place( + data_pool: List["DataPoolCont"], clean_indices: List[int] + ): + if clean_indices is None: + data_pool = [] + else: + if len(clean_indices) > 0: + log_debug( + "Clean data pool with data indices {}, counts {}.".format( + clean_indices, + [data_pool[index].count for index in clean_indices], + ), + color="purple", + ) + for i in list(np.sort(clean_indices)[::-1]): + data_pool.pop(i) + + +def fetch_data( + queue_data: mp.Queue, data_pool: List[DataPoolCont], worker_info, debug: bool = True +) -> bool: + start_time = time.time() + try: + existing_shm = queue_data.get(timeout=5) + except Exception as error: + log_debug("Timeout! {}.".format(str(error)), color="red") + return False + log_debug( + "[Thread {}][Worker {}][Get] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + start_time = time.time() + scene_data = pickle.loads(existing_shm.buf[:]) + log_debug( + "[Thread {}][Worker {}][Pickle] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + if np.random.random() > 0.5 or queue_data.qsize() == 0: + start_time = time.time() + queue_data.put(existing_shm) # put back + log_debug( + "[Thread {}][Worker {}][Put] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + assert isinstance(scene_data, list), "Invalid data format {}.".format( + type(scene_data) + ) + start_time = time.time() + data = DataPoolCont.from_list(scene_data) + data_pool.extend(data) + + log_debug( + "[Thread {}][Worker {}][Other] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + return True + + +class RestockCriterion: + def __init__(self, data_pool_limit: int, buffer_size: int, max_sample_num: int): + self.data_pool_limit = data_pool_limit + self.buffer_size = buffer_size + self.max_sample_num = max_sample_num + + def restock_condition(self, data_pool: List, queue: mp.Queue) -> bool: + return len(data_pool) < self.data_pool_limit + + def expired_condition( + self, data_pool: List[DataPoolCont], inverse: bool = False + ) -> List[bool]: + + if len(data_pool) == 0: + return [] + + if inverse: + return [data.count <= self.max_sample_num for data in data_pool] + else: + return [data.count > self.max_sample_num for data in data_pool] + + +class OnlineEngine: + """Data manager for online data production and training. + + The objectives of this class are: + - Manage the fetch data in a separate thread. + - Perform data synchronization between the data production process and + the training process (main process). + - Provide data sampling interface for the training process, which is designed + to return a batch of synthetic data with the different scene id. + - Data lifecycle management. + + To achieve the above objectives, the following functions should be implemented: + - from_shm_thread (static method) + + Args: + insight_config (List[CfgNode]): The config of insight pipeline. + episode_limit (int, optional): The maximum number of frames in the data pool. Defaults to 24. + max_sample_num (int, optional): The maximum number of times that a data can be sampled. + Defaults to 2. + target_device (torch.device, optional): The target device of the data. Defaults to torch.device('cpu'). + annos_param (Dict[str, Any], optional): The parameters of the annotations. Defaults to None. + data_gen_func (Callable, optional): The data generation function. Defaults to None. + unique_scene_frame (int, optional): The number of unique scene frame to be sampled. Defaults to None. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + + def __init__( + self, + episode_limit: int = 24, + max_sample_num: int = 2, + port: int = 5555, + buffer_size: int = 10, + multiprocess: bool = False, + **kwargs, + ) -> None: + + self.episode_limit = episode_limit + self._max_sample_num = max_sample_num + self.port = port + + self._data_pool = [] + + self._duration = 0.01 + + self._context = mp.get_context("forkserver") + + self._queue_data = self._context.Queue() + self._queue_data.cancel_join_thread() + + self.buffer_size = buffer_size + + self._data_gen_proc = None + self._fetch_data_thread = None + self._restock_data_pool = None + + self._is_started = False + self._is_restocked = False + self._socket = init_context(port + 1 if multiprocess else port) + + self._restock_criterion = RestockCriterion( + data_pool_limit=episode_limit, + buffer_size=buffer_size, + max_sample_num=max_sample_num, + ) + self._lock = threading.RLock() + + def start( + self, + ) -> None: + """Start the data production process and the data synchronization thread. + + Args: + wait_for_limit (bool, optional): Whether to wait for the data pool to reach + the frame limit. Defaults to False. + """ + + self._signal_gen = self._context.Value("b", True) + self._signal_fetch = self._context.Value("b", True) + + self._fetch_data_thread = Thread( + target=self.from_shm_thread, + args=( + self._socket, + self._queue_data, + self._duration, + self.buffer_size, + ), + daemon=True, + ) + self._fetch_data_thread.start() + self._is_started = True + log_info( + "Now start the thread to fetch data from share memory.", color="purple" + ) + + def start_restock(self, static: bool = False): + if static: + self._restock_data_pool = Thread( + target=self.restock_data_pool_static, + args=( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ), + daemon=True, + ) + else: + self._restock_data_pool = Thread( + target=self.restock_data_pool, + daemon=True, + ) + + self._restock_data_pool.start() + self._is_restocked = True + + def stop(self) -> None: + if self.is_started: + self._is_started = False + self._signal_fetch.value = 2 + self._fetch_data_thread.join() + self.empty_queue(self._queue_data, self._context) + self.clean_data_pool_in_place() + self._signal_gen.value = 2 + else: + log_info( + "The data generation process has not been started.", color="purple" + ) + + @property + def is_started(self) -> bool: + return self._is_started + + @property + def data_size(self) -> int: + with self._lock: + return len(self._data_pool) + + @property + def queue_size(self) -> int: + return self._queue.qsize() + + @property + def unique_scene_frame(self) -> int: + return self._unique_scene_frame + + @staticmethod + def empty_queue(queue: mp.Queue, context: mp) -> None: + while queue.qsize() > 0: + try: + queue.get() + except Exception as e: + log_info("queue put invaild data format") + queue.close() + queue.join_thread() + queue = context.Queue() + break + return queue + + @staticmethod + def empty_share_memory(queue: mp.Queue) -> None: + while queue.qsize() > 0: + shm_name = queue.get() + shm = shared_memory.SharedMemory(shm_name) + shm.close() + shm.unlink() + + def restock_data_pool(self): + return OnlineEngine.restock_data_pool_static( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ) + + @staticmethod + def restock_data_pool_static( + data_pool: List[DataPoolCont], + queue_data: mp.Queue, + duration: float, + restock_criterion: RestockCriterion, + context, + thread_lock, + ): + counts = 0 + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + while True: + time.sleep(duration) + # always clean the data pool first. + + start_time = time.time() + with thread_lock: + # delete + clean_indices = list( + np.argwhere(restock_criterion.expired_condition(data_pool)).reshape( + -1 + ) + ) + DataPoolCont.clean_data_pool_in_place( + data_pool, + clean_indices, + ) + if len(clean_indices) > 0: + log_debug( + "[Thread {}][Delete][Cost {}s]".format( + threading.current_thread().ident, time.time() - start_time + ) + ) + + # after clean, we check whether to restock data. + while restock_criterion.restock_condition(data_pool, queue_data): + + prev_data_size = len(data_pool) + should_fetch = False + for i in range(worker_info.num_workers): + if queue_data.qsize() > 0 and worker_info.id == i: + should_fetch = True + if should_fetch: + start_time = time.time() + with thread_lock: + # add + fetch_data( + data_pool=data_pool, + queue_data=queue_data, + worker_info=worker_info, + ) + log_debug( + "[Thread {}][Worker {}][ToDataPool] Produce data: {}->{}. Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + prev_data_size, + len(data_pool), + time.time() - start_time, + ) + ) + counts = 0 + else: + counts += 1 + + if counts % MAX_LOOP_TIMES == 0 and counts != 0: + log_info("Can not find the shm after {} times.".format(counts)) + # queue_data = OnlineEngine.empty_queue(queue_data, context) + + @staticmethod + def from_shm_thread( + socket, + queue_data: mp.Queue, + duration: float = 0.001, + buffer_size: int = 10, + ) -> None: + """The data fetching thread for data synchronization. + + The queue_data_size is used to control the data fetching thread. + If queue_data_size < buffer_size, the data fetching thread will fetch data from the queue. + If queue_data_size >= buffer_size, the data fetching thread will stop fetch data. + + Args: + socket (zmq.Context): The socket send signal for connect fetch and generator. + queue_data (mp.Queue): This queue contains information about shared memory. + duration (float, optional): _description_. Defaults to 0.001. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + counts = 0 + while True: + time.sleep(duration) + counts += 1 + if queue_data.qsize() < buffer_size: + socket.send_string(ConsumerTeleEnum.SHAKEHAND.value) + message = socket.recv() + try: + message_str = message.decode() + except Exception as e: + log_debug(str(e), color="red") + message_str = "" + if message_str != ProducerTeleEnum.NOREADY.value: + log_debug("Receive data.", color="purple") + shm_name = pickle.loads(message).popleft() + existing_shm = shared_memory.SharedMemory(name=shm_name) + queue_data.put(existing_shm) + log_debug( + "[FromShmThread] Produce queue: {}->{};".format( + queue_data.qsize() - 1, queue_data.qsize() + ) + ) + else: + if counts % MAX_LOOP_TIMES == 0: + log_debug("Queue is full. Skip this stage.", "purple") + + def sample_data( + self, + ): + + if self._is_restocked: + pass + else: + log_debug("Now start the thread to restock data.", color="purple") + self.start_restock(static=False) + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + counts = 0 + while True: + time.sleep(self._duration) + if len(self._data_pool) > 0: + start_time = time.time() + with self._lock: + index = rng.integers(0, len(self._data_pool)) + data = deepcopy(self._data_pool[index]) + self._data_pool[index].count += 1 + log_debug( + "[SampleData, worker {}] Consume data {}: index {}; times: {}->{}; Show queue size: {}; Cost time: {}s.".format( + worker_info.id, + data.tag, + index, + data.count, + data.count + 1, + self._queue_data.qsize(), + np.round(time.time() - start_time, 4), + ) + ) + counts = 0 + return data.data + else: + counts += 1 + if counts % MAX_LOOP_TIMES == 0: + log_info("Data pool is always empty after {} times.".format(counts)) diff --git a/embodichain/data/data_engine/online/enum.py b/embodichain/data/data_engine/online/enum.py new file mode 100644 index 0000000..fe521bd --- /dev/null +++ b/embodichain/data/data_engine/online/enum.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from enum import Enum + + +class ConsumerTeleEnum(Enum): + SHAKEHAND = "Data is ready?" + CONSUME = "Fetch data!" + NOCONSUME = "Data_pool is full." + GOTDATA = "Feched data!" + NOGOTDATA = "Not fetching data." + + +class ProducerTeleEnum(Enum): + READY = "Yes" + NOREADY = "No ready" + FULL = "Data_pool is full" + FAIL = "Failed" + SEND = "Send!" + EMPTYSTR = "Empty String." diff --git a/embodichain/data/data_engine/online/online_generator.py b/embodichain/data/data_engine/online/online_generator.py new file mode 100644 index 0000000..b1364f4 --- /dev/null +++ b/embodichain/data/data_engine/online/online_generator.py @@ -0,0 +1,181 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import time +import zmq +import random +from multiprocessing import shared_memory +import pickle +from collections import deque +from typing import List +from threading import Thread +import multiprocessing as mp +import traceback +from embodichain.utils.logger import log_info, log_warning, log_error, log_debug +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +torch._C._cuda_init() + + +class OnlineGenerator: + """Callback collection for online training mode.""" + + def __init__( + self, port: int, max_limit_gb: int = 50, multiprocess: bool = False, **kwargs + ) -> None: + self.shm_val = None + max_limit = max_limit_gb * 1024**3 + self._context = mp.get_context("forkserver") + self.port = port + self.socket = self.init_context(self.port, multiprocess) + self._duration = 0.01 + self.queue = deque() + self.queue_memroy = deque() + self.max_limit = max_limit + + self.validation_config = kwargs.get("validation", {}) + + def get_validation_config(self): + return self.validation_config + + def init_context(self, port, multiprocess: bool = False): + context = zmq.Context() + socket = context.socket(zmq.REP) + if multiprocess: + socket.connect(f"tcp://127.0.0.1:{port}") + else: + socket.bind(f"tcp://*:{port}") + + return socket + + def generator(self, generate_func, loop_times: int = -1, **kwargs): + self.signal = self._context.Value("b", True) + + self._zmq_send = Thread( + target=self.zmq_send, args=(self.queue, self.signal), daemon=True + ) + self._zmq_send.start() + log_debug("Start zmq sending.") + scene_id = 0 + + # -1 means infinite loop + while scene_id < loop_times or loop_times == -1: + if self.signal.value == 1: + first_time = True + try: + t0 = time.time() + return_list = generate_func( + time_id=scene_id, **self.validation_config + ) + + # TODO: support multiple trajectories for each scene. + if len(return_list) > 1: + log_error( + "Only support one trajectory for each scene in online generation mode." + ) + + data_dict_list = [return_list[0]["data"]] + + if ( + scene_id == 0 + and self.validation_config.get("num_samples", 0) > 0 + and "data_path" in return_list[0] + ): + # create shared memory to store the validation dataset path, which will be accessed by training process. + import sys + + data_path = return_list[0]["data_path"] + + shared_name = self.validation_config.get( + "dataset_name", "val_data_path" + ) + log_info( + f"Create shared memory for validation data path: {shared_name}", + color="green", + ) + self.shm_val = shared_memory.SharedMemory( + name=shared_name, + create=True, + size=len(data_path.encode()) + sys.getsizeof(""), + ) + self.shm_val.buf[: len(data_path.encode())] = data_path.encode() + log_info( + f"Craete shared memory for validation data path: {data_path}" + ) + + log_info( + f"Generate scene {scene_id + 1} time cost: {time.time() - t0}" + ) + serialized_data = pickle.dumps(data_dict_list) + shm = shared_memory.SharedMemory( + create=True, size=len(serialized_data) + ) + self.queue.append(shm.name) + self.queue_memroy.append( + {"name": shm.name, "size": len(serialized_data)} + ) + shm.buf[: len(serialized_data)] = serialized_data + except Exception as e: + log_error(f"Error in data generation process: {e}.") + traceback.print_exc() + self._zmq_send.join() + break + scene_id += 1 + self.empty_memory() + elif self.signal.value == 0: + if first_time: + log_warning("zmq recive full signal, wait generator signal") + first_time = False + log_warning("Signal value is 0.") + time.sleep(self._duration) + continue + else: + log_error("Unknown signal, data generator stop") + break + + def zmq_send(self, queue, signal): + while True: + try: + message = self.socket.recv_string() + if message == ConsumerTeleEnum.SHAKEHAND.value: + if len(queue) > 0: + log_warning( + "Recieve {} and send [data] to consumer.".format(message) + ) + self.socket.send(pickle.dumps(queue)) + queue.clear() + else: + self.socket.send(ProducerTeleEnum.NOREADY.value.encode()) + signal.value = 1 + except Exception as e: + print(e) + traceback.print_exc() + break + + def empty_memory(self): + total_size = sum([x["size"] for x in self.queue_memroy]) + log_info(f"share memory size is {total_size/(1024**3)} GB") + while total_size >= self.max_limit: + shm_name = self.queue_memroy.popleft() + if shm_name["name"] in self.queue: + log_info(f"remove {shm_name['name']} from queue") + self.queue.remove(shm_name["name"]) + try: + shm = shared_memory.SharedMemory(shm_name["name"]) + except: + continue + shm.close() + shm.unlink() + total_size = sum([x["size"] for x in self.queue_memroy]) + + def __del__(self): + if self.shm_val: + self.shm_val.close() + self.shm_val.unlink() From 8bfeb7f028165b9396563708146bcd34b1878520 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 14:35:44 +0800 Subject: [PATCH 10/49] Fix: can not process 'func' --- embodichain/lab/gym/utils/gym_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 40ced87..b6aee4a 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -437,6 +437,9 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) + # load dataset config + env_cfg.dataset = config["env"].get("dataset", None) + # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. env_cfg.dataset = ComponentCfg() @@ -471,6 +474,7 @@ class ComponentCfg: "embodichain.lab.gym.envs.managers.randomization", "embodichain.lab.gym.envs.managers.record", "embodichain.lab.gym.envs.managers.events", + "embodichain.lab.gym.envs.managers.real2sim", ] # parser env events config From 8dba0da72b4cc0c6937b3a415da2377ef1ae53b5 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 14:36:34 +0800 Subject: [PATCH 11/49] Migrate indices and mapping --- embodichain/data/global_indices.py | 122 +++++++++++++++++++++++ embodichain/data/global_mapping.py | 151 +++++++++++++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 embodichain/data/global_indices.py create mode 100644 embodichain/data/global_mapping.py diff --git a/embodichain/data/global_indices.py b/embodichain/data/global_indices.py new file mode 100644 index 0000000..b7cb8dc --- /dev/null +++ b/embodichain/data/global_indices.py @@ -0,0 +1,122 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np + +GLOBAL_INDICES = { + # [0, 10): right arm joint positions + **{"arm_joint_{}_pos".format(i): i for i in range(10)}, + **{"right_arm_joint_{}_pos".format(i): i for i in range(10)}, + # [10, 15): right gripper joint positions + **{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + **{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + "gripper_open": 10, # alias of right_gripper_joint_0_pos + "right_gripper_open": 10, + # [15, 25): right arm joint velocities + **{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + **{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + # [25, 30): right gripper joint velocities + **{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + **{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + "gripper_open_vel": 25, # alias of right_gripper_joint_0_vel + "right_gripper_open_vel": 25, + # [30, 33): right end effector positions + "eef_pos_x": 30, + "right_eef_pos_x": 30, + "eef_pos_y": 31, + "right_eef_pos_y": 31, + "eef_pos_z": 32, + "right_eef_pos_z": 32, + # [33, 39): right end effector 6D pose + "eef_angle_0": 33, + "right_eef_angle_0": 33, + "eef_angle_1": 34, + "right_eef_angle_1": 34, + "eef_angle_2": 35, + "right_eef_angle_2": 35, + "eef_angle_3": 36, + "right_eef_angle_3": 36, + "eef_angle_4": 37, + "right_eef_angle_4": 37, + "eef_angle_5": 38, + "right_eef_angle_5": 38, + # [39, 42): right end effector velocities + "eef_vel_x": 39, + "right_eef_vel_x": 39, + "eef_vel_y": 40, + "right_eef_vel_y": 40, + "eef_vel_z": 41, + "right_eef_vel_z": 41, + # [42, 45): right end effector angular velocities + "eef_angular_vel_roll": 42, + "right_eef_angular_vel_roll": 42, + "eef_angular_vel_pitch": 43, + "right_eef_angular_vel_pitch": 43, + "eef_angular_vel_yaw": 44, + "right_eef_angular_vel_yaw": 44, + # [45, 50): reserved + # [50, 60): left arm joint positions + **{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)}, + # [60, 65): left gripper joint positions + **{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)}, + "left_gripper_open": 60, # alias of left_gripper_joint_0_pos + # [65, 75): left arm joint velocities + **{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)}, + # [75, 80): left gripper joint velocities + **{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)}, + "left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel + # [80, 83): left end effector positions + "left_eef_pos_x": 80, + "left_eef_pos_y": 81, + "left_eef_pos_z": 82, + # [83, 89): left end effector 6D pose + "left_eef_angle_0": 83, + "left_eef_angle_1": 84, + "left_eef_angle_2": 85, + "left_eef_angle_3": 86, + "left_eef_angle_4": 87, + "left_eef_angle_5": 88, + # [89, 92): left end effector velocities + "left_eef_vel_x": 89, + "left_eef_vel_y": 90, + "left_eef_vel_z": 91, + # [92, 95): left end effector angular velocities + "left_eef_angular_vel_roll": 92, + "left_eef_angular_vel_pitch": 93, + "left_eef_angular_vel_yaw": 94, + # [95, 100): reserved + # [100, 102): base linear velocities + "base_vel_x": 100, + "base_vel_y": 101, + # [102, 103): base angular velocities + "base_angular_vel": 102, + # [103, 115): dextrous hand joint positions + **{"left_hand_joint_{}_pos".format(i): i + 103 for i in range(6)}, + **{"right_hand_joint_{}_pos".format(i): i + 109 for i in range(6)}, + # [115, 119): torso joint positions + **{"torso_joint_{}_pos".format(i): i + 115 for i in range(4)}, + # [119, 121): head joint positions + **{"head_joint_{}_pos".format(i): i + 119 for i in range(2)}, + "waist": 115, + # [121, 123): head joint velocities + **{"head_joint_{}_vel".format(i): i + 121 for i in range(2)}, + "waist_vel": 123, + # [124, 128): reserved +} + + +STATE_VEC_LEN = 128 + + +def get_all_left_related_indices(including_end: bool = True): + if including_end: + return np.arange(50, 128, step=1) + else: + return np.arange(50, 100) + + +def get_all_right_related_indices(): + return np.arange(0, 50) diff --git a/embodichain/data/global_mapping.py b/embodichain/data/global_mapping.py new file mode 100644 index 0000000..0be40a0 --- /dev/null +++ b/embodichain/data/global_mapping.py @@ -0,0 +1,151 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.data.enum import ( + ControlParts, + ActionMode, + EndEffector, + JointType, + EefType, + is_dual_arms, +) +from embodichain.data.global_indices import GLOBAL_INDICES +import numpy as np +from typing import List + + +class GlobalMapping: + def __init__(self, dof: int): + self_attrs = GlobalMapping.__dict__ + num_arm = 2 if is_dual_arms(dofs=dof) else 1 + single_dof = dof // num_arm + function_dict = {} + for k, v in self_attrs.items(): + if isinstance(v, staticmethod) and "__" not in k: + function_dict.update(v.__func__(dof=single_dof, num_arm=num_arm)) + self.mapping_from_name_to_indices = function_dict + + @staticmethod + def get_qpos_indices(dof: int, num_arm, **kwrags): + + return { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.HEAD.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_pos".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + JointType.QPOS.value: [GLOBAL_INDICES["waist"]], + } + + @staticmethod + def get_gripper_open_state_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open"]], + ControlParts.RIGHT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open"]], + } + + @staticmethod + def get_hand_qpos_indices(num_arm: int, hand_dof: int = 6, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"left_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"right_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + } + + @staticmethod + def get_gripper_open_vel_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open_vel"]], + ControlParts.RIGHT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open_vel"]], + } + + @staticmethod + def get_delta_qpos_indices(dof: int, num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.HEAD.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_vel".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [GLOBAL_INDICES["waist_vel"]], + } + + @staticmethod + def get_eef_pose_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["left_eef_pos_x"], + GLOBAL_INDICES["left_eef_pos_y"], + GLOBAL_INDICES["left_eef_pos_z"], + GLOBAL_INDICES["left_eef_angle_0"], + GLOBAL_INDICES["left_eef_angle_1"], + GLOBAL_INDICES["left_eef_angle_2"], + GLOBAL_INDICES["left_eef_angle_3"], + GLOBAL_INDICES["left_eef_angle_4"], + GLOBAL_INDICES["left_eef_angle_5"], + ], + ControlParts.RIGHT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["right_eef_pos_x"], + GLOBAL_INDICES["right_eef_pos_y"], + GLOBAL_INDICES["right_eef_pos_z"], + GLOBAL_INDICES["right_eef_angle_0"], + GLOBAL_INDICES["right_eef_angle_1"], + GLOBAL_INDICES["right_eef_angle_2"], + GLOBAL_INDICES["right_eef_angle_3"], + GLOBAL_INDICES["right_eef_angle_4"], + GLOBAL_INDICES["right_eef_angle_5"], + ], + } + + def get_indices(self, state_meta: List[str]): + state_indices = [] + + for proprio_name in state_meta: + state_indices += self.mapping_from_name_to_indices[proprio_name] + + return state_indices + + def ret_all_state( + self, + ): + state_indices = [] + + for val in self.mapping_from_name_to_indices.values(): + state_indices += val + + return state_indices From 4195d30320f1e1db565d5c343c61f4a83c29f035 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:03:04 +0800 Subject: [PATCH 12/49] Migrate environments: pour water and rearrangement --- embodichain/lab/gym/envs/__init__.py | 8 ++ .../gym/envs/tasks/tableware/pour_water_v3.py | 78 ++++++++++++++++ .../envs/tasks/tableware/rearrangement_v3.py | 93 +++++++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py create mode 100644 embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py index 8825769..e4a16da 100644 --- a/embodichain/lab/gym/envs/__init__.py +++ b/embodichain/lab/gym/envs/__init__.py @@ -26,6 +26,14 @@ PourWaterEnv, ) from embodichain.lab.gym.envs.tasks.tableware.scoop_ice import ScoopIce +from embodichain.lab.gym.envs.tasks.tableware.pour_water_v3 import ( + PourWaterEnv3, + PourWaterAgentEnv3, +) +from embodichain.lab.gym.envs.tasks.tableware.rearrangement_v3 import ( + RearrangementEnv3, + RearrangementAgentEnv3, +) # Reinforcement learning environments from embodichain.lab.gym.envs.tasks.rl.push_cube import PushCubeEnv diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py new file mode 100644 index 0000000..20963f2 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py @@ -0,0 +1,78 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, Union, Optional, Sequence, Tuple, List + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import configclass, logger + +from embodichain.lab.gym.envs.tasks.tableware.base_agent_env import BaseAgentEnv + +__all__ = ["PourWaterEnv3"] + + +@register_env("PourWater-v3", max_episode_steps=600) +class PourWaterEnv3(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + action_config = kwargs.get("action_config", None) + if action_config is not None: + self.action_config = action_config + + def is_task_success(self, **kwargs) -> torch.Tensor: + """Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. + + Args: + **kwargs: Additional arguments for task-specific success criteria. + + Returns: + torch.Tensor: A boolean tensor indicating success for each environment in the batch. + """ + + bottle = self.sim.get_rigid_object("bottle") + cup = self.sim.get_rigid_object("cup") + + bottle_final_xpos = bottle.get_local_pose(to_matrix=True) + cup_final_xpos = cup.get_local_pose(to_matrix=True) + + bottle_ret = self._is_fall(bottle_final_xpos) + cup_ret = self._is_fall(cup_final_xpos) + + return ~(bottle_ret | cup_ret) + + def _is_fall(self, pose: torch.Tensor) -> torch.Tensor: + # Extract z-axis from rotation matrix (last column, first 3 elements) + pose_rz = pose[:, :3, 2] + world_z_axis = torch.tensor([0, 0, 1], dtype=pose.dtype, device=pose.device) + + # Compute dot product for each batch element + dot_product = torch.sum(pose_rz * world_z_axis, dim=-1) # Shape: (batch_size,) + + # Clamp to avoid numerical issues with arccos + dot_product = torch.clamp(dot_product, -1.0, 1.0) + + # Compute angle and check if fallen + angle = torch.arccos(dot_product) + return angle >= torch.pi / 4 + + +@register_env("PourWaterAgent-v3", max_episode_steps=600) +class PourWaterAgentEnv3(BaseAgentEnv, PourWaterEnv3): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + super()._init_agents(**kwargs) + + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): + obs, info = super().reset(seed=seed, options=options) + super().get_states() + return obs, info \ No newline at end of file diff --git a/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py new file mode 100644 index 0000000..2d28b6b --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py @@ -0,0 +1,93 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, Union, Optional, Sequence, Tuple, List + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import configclass, logger + +from embodichain.lab.gym.envs.tasks.tableware.base_agent_env import BaseAgentEnv + +__all__ = ["RearrangementEnv3"] + + +@register_env("Rearrangement-v3", max_episode_steps=600) +class RearrangementEnv3(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + action_config = kwargs.get("action_config", None) + if action_config is not None: + self.action_config = action_config + + def is_task_success(self) -> bool: + fork = self.sim.get_rigid_object("fork") + spoon = self.sim.get_rigid_object("spoon") + plate = self.sim.get_rigid_object("plate") + plate_pose = plate.get_local_pose(to_matrix=True) + # TODO: now only for 1 env + ( + spoon_place_target_x, + spoon_place_target_y, + spoon_place_target_z, + ) = self.affordance_datas["spoon_place_pose"][:3, 3] + ( + fork_place_target_x, + fork_place_target_y, + fork_place_target_z, + ) = self.affordance_datas["fork_place_pose"][:3, 3] + + spoon_pose = spoon.get_local_pose(to_matrix=True) + spoon_x, spoon_y, spoon_z = spoon_pose[0, :3, 3] + + fork_pose = fork.get_local_pose(to_matrix=True) + fork_x, fork_y, fork_z = fork_pose[0, :3, 3] + + tolerance = self.metadata.get("success_params", {}).get("tolerance", 0.02) + + # spoon and fork should with the x y range of tolerance related to plate. + return ~( + abs(spoon_x - spoon_place_target_x) > tolerance + or abs(spoon_y - spoon_place_target_y) > tolerance + or abs(fork_x - fork_place_target_x) > tolerance + or abs(fork_y - fork_place_target_y) > tolerance + ) + +@register_env("RearrangementAgent-v3", max_episode_steps=600) +class RearrangementAgentEnv3(BaseAgentEnv, RearrangementEnv3): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + super()._init_agents(**kwargs) + + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): + obs, info = super().reset(seed=seed, options=options) + super().get_states() + return obs, info + + def is_task_success(self): + fork = self.sim.get_rigid_object("fork") + spoon = self.sim.get_rigid_object("spoon") + plate = self.sim.get_rigid_object("plate") + + plate_pose = plate.get_local_pose(to_matrix=True) + spoon_place_target_y = plate_pose[0, 1, 3] - 0.16 + fork_place_target_y = plate_pose[0, 1, 3] + 0.16 + + spoon_pose = spoon.get_local_pose(to_matrix=True) + spoon_y = spoon_pose[0, 1, 3] + + fork_pose = fork.get_local_pose(to_matrix=True) + fork_y = fork_pose[0, 1, 3] + + tolerance = self.metadata.get("success_params", {}).get("tolerance", 0.02) + + # spoon and fork should with the y range of tolerance related to plate. + return (abs(spoon_y - spoon_place_target_y) <= tolerance or abs(fork_y - fork_place_target_y) <= tolerance) \ No newline at end of file From 96bb2b98d18e756bfed5cb1b4a920d927160ad15 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:07:50 +0800 Subject: [PATCH 13/49] Migrate API interfaces --- embodichain/toolkits/interfaces.py | 987 +++++++++++++++++++++++++++++ 1 file changed, 987 insertions(+) create mode 100644 embodichain/toolkits/interfaces.py diff --git a/embodichain/toolkits/interfaces.py b/embodichain/toolkits/interfaces.py new file mode 100644 index 0000000..565d7aa --- /dev/null +++ b/embodichain/toolkits/interfaces.py @@ -0,0 +1,987 @@ +from typing import List, Dict +from embodichain.lab.gym.structs import Object +import numpy as np +from embodichain.toolkits.toolkits import ToolkitsBase +from embodichain.utils.logger import log_info, log_warning, log_error +from copy import deepcopy +from embodichain.lab.gym.utils.misc import mul_linear_expand, is_qpos_flip, get_rotation_replaced_pose +from embodichain.toolkits.graspkit.pg_grasp.antipodal import GraspSelectMethod +from matplotlib import pyplot as plt +import torch +from tqdm import tqdm +from embodichain.lab.gym.motion_generation.action.arm_action import ArmAction +from embodichain.data.enum import ControlParts, EndEffector, JointType +from scipy.spatial.transform import Rotation as R +from embodichain.lab.sim.utility.workspace_analyzer_new import compute_xpos_reachability +from embodichain.utils.utility import encode_image +import ast +''' +--------------------------------------------Some useful functions---------------------------------------------------- +--------------------------------------------Some useful functions---------------------------------------------------- +--------------------------------------------Some useful functions---------------------------------------------------- +''' +def draw_axis(env, pose): + from embodichain.lab.sim.cfg import MarkerCfg + marker_cfg = MarkerCfg( + name='test', + marker_type="axis", + axis_xpos=pose, + axis_size=0.01, + axis_len=0.2, + arena_index=-1, # All arenas + ) + env.sim.draw_marker(cfg=marker_cfg) + env.sim.update() + +def get_arm_states(env, robot_name): + + left_arm_current_qpos, right_arm_current_qpos = env.get_current_qpos_agent() + left_arm_current_pose, right_arm_current_pose = env.get_current_xpos_agent() + left_arm_current_gripper_state, right_arm_current_gripper_state = env.get_current_gripper_state_agent() + + side = "right" if "right" in robot_name else "left" + is_left = True if side == "left" else False + select_arm = "left_arm" if is_left else "right_arm" + + arms = {"left": (left_arm_current_qpos, left_arm_current_pose, left_arm_current_gripper_state), + "right": (right_arm_current_qpos, right_arm_current_pose, right_arm_current_gripper_state)} + select_arm_current_qpos, select_arm_current_pose, select_arm_current_gripper_state = arms[side] + + return is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, select_arm_current_gripper_state + +def find_nearest_valid_pose(env, select_arm, pose, xpos_resolution=0.1): + # use the validator to choose the nearest valid pose + # delete the cache every time + if isinstance(pose, torch.Tensor): + pose = pose.detach().cpu().numpy() + ret, _ = compute_xpos_reachability(env.robot, select_arm, pose, xpos_resolution=xpos_resolution, + qpos_resolution=np.radians(60), cache_mode="disk", use_cached=False, + visualize=False) + ret = np.stack(ret, axis=0) + # find the nearest valid pose + xyz = pose[:3, 3] + ts = np.stack([M[:3, 3] for M in ret], axis=0) # shape (N,3) + dists = np.linalg.norm(ts - xyz[None, :], axis=1) + best_idx = np.argmin(dists) + nearest_valid_pose = ret[best_idx] + return torch.from_numpy(nearest_valid_pose) + +def get_qpos(env, is_left, select_arm, pose, qpos_seed, force_valid=False, name=''): + if force_valid: + try: + ret, qpos = env.get_arm_ik( + pose, is_left=is_left, qpos_seed=qpos_seed + ) + if not ret: + log_error(f"Generate {name} qpos failed.\n") + except Exception as e: + log_warning(f"Original {name} pose invalid, using nearest valid pose. ({e})\n") + pose = find_nearest_valid_pose(env, select_arm, pose) + + ret, qpos = env.get_arm_ik( + pose, is_left=is_left, qpos_seed=qpos_seed + ) + else: + ret, qpos = env.get_arm_ik( + pose, is_left=is_left, qpos_seed=qpos_seed + ) + if not ret: + log_error(f"Generate {name} qpos failed.\n") + + return qpos + +def get_offset_pose( + pose_to_change: torch.Tensor, + offset_value: float, + direction: str = "z", + mode: str = "intrinsic", +) -> torch.Tensor: + + device = pose_to_change.device + dtype = pose_to_change.dtype + + if isinstance(direction, str): + if direction == "x": + direction_vec = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=dtype) + elif direction == "y": + direction_vec = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=dtype) + elif direction == "z": + direction_vec = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=dtype) + else: + log_error(f"Invalid direction '{direction}'. Must be 'x', 'y', or 'z'.") + return pose_to_change + else: + direction_vec = torch.as_tensor(direction, device=device, dtype=dtype) + + direction_vec = direction_vec / torch.linalg.norm(direction_vec) + offset_matrix = torch.eye(4, device=device, dtype=dtype) + offset_matrix[:3, 3] = offset_value * direction_vec + + if mode == "extrinsic": + offset_pose = offset_matrix @ pose_to_change + elif mode == "intrinsic": + offset_pose = pose_to_change @ offset_matrix + else: + log_error(f"Invalid mode '{mode}'. Must be 'extrinsic' or 'intrinsic'.") + return pose_to_change + + return offset_pose + +def plan_trajectory(env, select_arm, qpos_list, sample_num, + select_arm_current_gripper_state, select_qpos_traj, ee_state_list_select): + traj_list, _ = ArmAction.create_discrete_trajectory( + agent=env.robot, + uid=select_arm, + qpos_list=qpos_list, + sample_num=sample_num, + qpos_seed=qpos_list[0], + is_use_current_qpos=False, + **getattr(env, "planning_config", {}), + ) + + select_qpos_traj.extend(traj_list) + ee_state_list_select.extend([select_arm_current_gripper_state] * len(traj_list)) + +def plan_gripper_trajectory(env, is_left, sample_num, execute_open, + select_arm_current_qpos, select_qpos_traj, ee_state_list_select): + open_state = env.open_state + close_state = env.close_state + + if execute_open: + ee_state_expand_select = np.array([close_state, open_state]) + env.set_current_gripper_state_agent(open_state, is_left=is_left) + else: + ee_state_expand_select = np.array([open_state, close_state]) + env.set_current_gripper_state_agent(close_state, is_left=is_left) + + ee_state_expand_select = mul_linear_expand(ee_state_expand_select, [sample_num], include_endpoint=True) + + select_qpos_traj.extend([select_arm_current_qpos] * sample_num) + ee_state_list_select.extend(ee_state_expand_select) + +def finalize_actions(select_qpos_traj, ee_state_list_select): + # mimic eef state + actions = np.concatenate([np.array(select_qpos_traj), np.array(ee_state_list_select), np.array(ee_state_list_select)], axis=-1) + return actions + +def extract_drive_calls(code_str: str) -> list[str]: + tree = ast.parse(code_str) + lines = code_str.splitlines() + + drive_blocks = [] + + for node in tree.body: + # Match: drive(...) + if ( + isinstance(node, ast.Expr) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "drive" + ): + # AST line numbers are 1-based + start = node.lineno - 1 + end = node.end_lineno + block = "\n".join(lines[start:end]) + drive_blocks.append(block) + + return drive_blocks + +''' +--------------------------------------------Atom action functions---------------------------------------------------- +--------------------------------------------Atom action functions---------------------------------------------------- +--------------------------------------------Atom action functions---------------------------------------------------- +''' +# TODO: write a move_to_pose atom action, the use this action to form other atom actions +def grasp( + robot_name: str, + obj_name: str, + pre_grasp_dis: float = 0.05, + env=None, + force_valid=False, + **kwargs +): + # Get target object + obj_uids = env.sim.get_rigid_object_uid_list() + if obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + else: + log_error(f"No matched object {obj_uids}.") + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0) + + # Open the gripper if currently closed + actions = None + select_arm_current_gripper_state = env.left_arm_current_gripper_state if 'left' in robot_name else env.right_arm_current_gripper_state + if select_arm_current_gripper_state <= env.open_state - 0.01: + actions = open_gripper(robot_name, env, **kwargs) + + # Retract the end-effector to avoid collision + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + select_arm_base_pose = env.left_arm_base_pose if is_left else env.right_arm_base_pose + base_to_eef_xy_dis = torch.norm(select_arm_base_pose[:2, 3] - select_arm_current_pose[:2, 3]) + base_to_obj_xy_dis = torch.norm(select_arm_base_pose[:2, 3] - target_obj_pose[:2, 3]) + dis_eps = kwargs.get('dis_eps', 0.05) + select_arm_init_pose = env.left_arm_init_xpos if is_left else env.right_arm_init_xpos + if base_to_eef_xy_dis > base_to_obj_xy_dis and not torch.allclose(select_arm_current_pose, select_arm_init_pose, rtol=1e-5, atol=1e-8): + delta = base_to_eef_xy_dis - (base_to_obj_xy_dis - dis_eps) + back_actions = move_by_relative_offset(robot_name=robot_name, dx=0.0, dy=0.0, dz=-delta, env=env, force_valid=force_valid, mode="intrinsic", sample_num=15, **kwargs) + actions = np.concatenate([actions, back_actions], axis=0) if actions is not None else back_actions + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Move the end-effector to a good place for starting grasping to avoid bad poses + select_arm_retract_pose = deepcopy(env.left_arm_init_xpos if is_left else env.right_arm_init_xpos) + select_arm_retract_pose = get_offset_pose(select_arm_retract_pose, 0.15, "z", "intrinsic") + select_arm_retract_qpos = get_qpos(env, is_left, select_arm, select_arm_retract_pose, env.left_arm_init_qpos if is_left else env.right_arm_init_qpos, + force_valid=force_valid, name='retract_to_good_pose') + qpos_list_back_to_retract = [select_arm_current_qpos, select_arm_retract_qpos] + sample_num = 30 + + plan_trajectory( + env, + select_arm, + qpos_list_back_to_retract, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + select_arm_current_qpos = select_arm_retract_qpos + select_arm_current_pose = select_arm_retract_pose + + # Rotate the arm base to face the object for better grasping + delta_xy = target_obj_pose[:2, 3] - select_arm_base_pose[:2, 3] + dx, dy = delta_xy[0], delta_xy[1] + aim_horizontal_angle = np.arctan2(dy, dx) + delta_angle = abs(select_arm_current_qpos[0] - aim_horizontal_angle) + select_arm_aim_qpos = deepcopy(select_arm_current_qpos) + select_arm_aim_qpos[0] = aim_horizontal_angle + + # Get best grasp pose from affordance data + grasp_pose_object = env.init_obj_info.get(obj_name)['grasp_pose_obj'] + if grasp_pose_object[0, 2] > 0.5: # whether towards x direction TODO: make it robust + # Align the object pose's z-axis with the arm's aiming direction + target_obj_pose = torch.tensor(get_rotation_replaced_pose(np.array(target_obj_pose), float(select_arm_aim_qpos[0]), "z","intrinsic")) + best_pickpose = target_obj_pose @ grasp_pose_object + grasp_pose = deepcopy(best_pickpose) + grasp_pose_pre1 = deepcopy(grasp_pose) + grasp_pose_pre1 = get_offset_pose(grasp_pose_pre1, -pre_grasp_dis, "z", 'intrinsic') + + # Solve IK for pre-grasp and grasp poses + grasp_qpos_pre1 = get_qpos(env, is_left, select_arm, grasp_pose_pre1, select_arm_aim_qpos, force_valid=force_valid, name='grasp pre1') + grasp_qpos = get_qpos(env, is_left, select_arm, grasp_pose, grasp_qpos_pre1, force_valid=force_valid, name='grasp') + + # Update env state to final grasp pose + env.set_current_qpos_agent(grasp_qpos, is_left=is_left) + env.set_current_xpos_agent(grasp_pose, is_left=is_left) + + # ------------------------------------ Traj 0: init → aim ------------------------------------ + qpos_list_init_to_aim = [select_arm_current_qpos, select_arm_aim_qpos] + # base_sample_num = 10 + # base_angle = 0.08 + # sample_num = max(int(delta_angle / base_angle * base_sample_num), 2) + + sample_num = 10 + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_aim, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ------------------------------------ Traj 1: aim → pre-grasp ------------------------------------ + qpos_list_aim_to_pre1 = [select_arm_aim_qpos, grasp_qpos_pre1] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_aim_to_pre1, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ------------------------------------ Traj 2: pre-grasp → grasp ------------------------------------ + qpos_list_pre1_to_grasp = [grasp_qpos_pre1, grasp_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_pre1_to_grasp, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + traj_actions = finalize_actions(select_qpos_traj, ee_state_list_select) + actions = traj_actions if actions is None else np.concatenate([actions, traj_actions], axis=0) + + # ------------------------------------ Close gripper ------------------------------------ + close_gripper_actions = close_gripper(robot_name, env, **kwargs) + actions = np.concatenate([actions, close_gripper_actions], axis=0) + + log_info(f"Total generated trajectory number for grasp: {len(actions)}.", color="green") + + return actions + +# def place_on_table( +# robot_name: str, +# obj_name: str, +# x: float = None, +# y: float = None, +# pre_place_dis: float = 0.08, +# env=None, +# force_valid=False, +# **kwargs +# ): +# +# # ---------------------------------------- Prepare ---------------------------------------- +# select_qpos_traj = [] +# ee_state_list_select = [] +# +# is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ +# select_arm_current_gripper_state = get_arm_states(env, robot_name) +# +# grasp_pose_object = env.init_obj_info.get(obj_name).get('grasp_pose_obj') +# init_obj_pose = env.init_obj_info.get(obj_name).get('pose') +# +# select_arm_base_pose = env.left_arm_base_pose if is_left else env.right_arm_base_pose +# delta_xy = init_obj_pose[:2, 3] - select_arm_base_pose[:2, 3] +# aim_horizontal_angle = np.arctan2(delta_xy[1], delta_xy[0]) +# +# # Align the object pose's z-axis with the arm's aiming direction +# init_obj_pose = torch.tensor(get_rotation_replaced_pose(np.array(init_obj_pose), float(aim_horizontal_angle), "z","intrinsic")) +# +# place_pose = init_obj_pose @ grasp_pose_object +# place_pose[0, 3] = x +# place_pose[1, 3] = y +# place_pose[2, 3] += kwargs.get('eps', 0.02) +# +# pre_place_pose = deepcopy(place_pose) +# pre_place_pose[2, 3] += pre_place_dis +# +# # Solve IK for pre-place and place poses +# place_qpos_pre1 = get_qpos(env, is_left, select_arm, pre_place_pose, select_arm_current_qpos, force_valid=force_valid, name='place pre1') +# place_qpos = get_qpos(env, is_left, select_arm, place_pose, place_qpos_pre1, force_valid=force_valid, name='place') +# +# # Update env states +# env.set_current_qpos_agent(place_qpos, is_left=is_left) +# env.set_current_xpos_agent(place_pose, is_left=is_left) +# +# # ------------------------------------ Traj 0: current → pre-place ------------------------------------ +# qpos_list_current_to_preplace = [select_arm_current_qpos, place_qpos_pre1] +# sample_num = 30 +# +# plan_trajectory( +# env, +# select_arm, +# qpos_list_current_to_preplace, +# sample_num, +# select_arm_current_gripper_state, +# select_qpos_traj, +# ee_state_list_select +# ) +# +# # ------------------------------------ Traj 1: pre-place → place ------------------------------------ +# qpos_list_preplace_to_place = [place_qpos_pre1, place_qpos] +# sample_num = 20 +# +# plan_trajectory( +# env, +# select_arm, +# qpos_list_preplace_to_place, +# sample_num, +# select_arm_current_gripper_state, +# select_qpos_traj, +# ee_state_list_select +# ) +# +# # ---------------------------------------- Final ---------------------------------------- +# traj_actions = finalize_actions(select_qpos_traj, ee_state_list_select) +# +# open_actions = open_gripper(robot_name, env, **kwargs) +# +# actions = np.concatenate([traj_actions, open_actions], axis=0) +# +# log_info(f"Total generated trajectory number for place on table: {len(actions)}.", color="green") +# +# return actions + +def place_on_table( + robot_name: str, + obj_name: str, + x: float = None, + y: float = None, + pre_place_dis: float = 0.08, + env=None, + force_valid=False, + **kwargs +): + + init_obj_height = env.init_obj_info.get(obj_name).get('height') + height = init_obj_height + kwargs.get('eps', 0.03) + + traj_actions = move_to_absolute_position(robot_name, x=x, y=y, z=height, env=env, force_valid=force_valid, **kwargs) + open_actions = open_gripper(robot_name, env, **kwargs) + + actions = np.concatenate([traj_actions, open_actions], axis=0) + + log_info(f"Total generated trajectory number for place on table: {len(actions)}.", color="green") + + return actions + +def move_relative_to_object( + robot_name: str, + obj_name: str, + x_offset: float = 0, + y_offset: float = 0, + z_offset: float = 0, + env=None, + force_valid=False, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Resolve target object + obj_uids = env.sim.get_rigid_object_uid_list() + if obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + else: + log_error("No matched object.") + + # Get object base pose (4x4 matrix) + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0) + + # Construct target pose (preserve orientation) + move_target_pose = deepcopy(select_arm_current_pose) + move_target_pose[:3, 3] = target_obj_pose[:3, 3] + move_target_pose[0, 3] += x_offset + move_target_pose[1, 3] += y_offset + move_target_pose[2, 3] += z_offset + + # Solve IK for target pose + move_target_qpos = get_qpos(env, is_left, select_arm, move_target_pose, select_arm_current_qpos, force_valid=force_valid, name='move relative to object') + + # Update env states + env.set_current_qpos_agent(move_target_qpos, is_left=is_left) + env.set_current_xpos_agent(move_target_pose, is_left=is_left) + + # ------------------------------------ Traj 1: init → target ------------------------------------ + qpos_list_init_to_target = [select_arm_current_qpos, move_target_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_target, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info(f"Total generated trajectory number for move relative to object: {len(actions)}.", color="green") + + return actions + +def move_to_absolute_position( + robot_name: str, + x: float = None, + y: float = None, + z: float = None, + env=None, + force_valid=False, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Start from current pose, then selectively update xyz + move_pose = deepcopy(select_arm_current_pose) + + current_xyz = move_pose[:3, 3].clone() + + target_xyz = current_xyz.clone() + if x is not None: + target_xyz[0] = x + if y is not None: + target_xyz[1] = y + if z is not None: + target_xyz[2] = z + + move_pose[:3, 3] = target_xyz + + # Try IK on target pose + move_qpos = get_qpos(env, is_left, select_arm, move_pose, select_arm_current_qpos, force_valid=force_valid, name='move to absolute position') + + # Update env states + env.set_current_qpos_agent(move_qpos, is_left=is_left) + env.set_current_xpos_agent(move_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_move = [select_arm_current_qpos, move_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_move, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info(f"Total generated trajectory number for move to absolute position: {len(actions)}.", color="green") + + return actions + +def move_by_relative_offset( + robot_name: str, + dx: float = 0.0, + dy: float = 0.0, + dz: float = 0.0, + mode: str = 'extrinsic', + env=None, + force_valid=False, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + move_pose = deepcopy(select_arm_current_pose) + + # Apply relative offsets (dx, dy, dz always floats) + move_pose = get_offset_pose(move_pose, dx, "x", mode) + move_pose = get_offset_pose(move_pose, dy, "y", mode) + move_pose = get_offset_pose(move_pose, dz, "z", mode) + + # Solve IK + move_qpos = get_qpos(env, is_left, select_arm, move_pose, select_arm_current_qpos, force_valid=force_valid, name='move by relative offset') + + # Update environment states + env.set_current_qpos_agent(move_qpos, is_left=is_left) + env.set_current_xpos_agent(move_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_move = [select_arm_current_qpos, move_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_move, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info(f"Total generated trajectory number for move by relative offset: {len(actions)}.", color="green") + + return actions + +def back_to_initial_pose( + robot_name: str, + env=None, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + # Get arm states + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # Retrieve the initial joint configuration of this arm + target_qpos = env.left_arm_init_qpos if is_left else env.right_arm_init_qpos + target_qpos = torch.as_tensor(target_qpos, dtype=select_arm_current_qpos.dtype) + + # ---------------------------------------- Pose ---------------------------------------- + # Pre-back pose: move along tool z by a small offset (use intrinsic frame) + pre_back_pose = deepcopy(select_arm_current_pose) + pre_back_pose = get_offset_pose(pre_back_pose, -0.08, "z", "intrinsic") + + # IK for pre-back + pre_back_qpos = get_qpos(env, is_left, select_arm, pre_back_pose, select_arm_current_qpos, force_valid=kwargs.get("force_valid", False), name="pre back pose") + + # Update env states (move to target pose) + target_pose = env.get_arm_fk(qpos=target_qpos, is_left=is_left) + env.set_current_qpos_agent(target_qpos, is_left=is_left) + env.set_current_xpos_agent(target_pose, is_left=is_left) + + # ------------------------------------ Traj: init → pre back_pose ------------------------------------ + qpos_list_init_to_preback = [select_arm_current_qpos, pre_back_qpos] + sample_num = 20 + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_preback, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ------------------------------------ Traj: init → initial_pose ------------------------------------ + qpos_list_preback_to_target = [pre_back_qpos, target_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_preback_to_target, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions( + select_qpos_traj, + ee_state_list_select + ) + + log_info(f"Total generated trajectory number for back to initial pose: {len(actions)}.", color="green") + + return actions + +def rotate_eef( + robot_name: str, + degree: float = 0, + env=None, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Compute new joint positions + rotated_qpos = deepcopy(select_arm_current_qpos) + rotated_qpos[5] += np.deg2rad(degree) + + # Optional: limit checking (commented out by default) + # joint5_limit = env.get_joint_limits(select_arm)[5] + # if rotated_qpos[5] < joint5_limit[0] or rotated_qpos[5] > joint5_limit[1]: + # log_warning("Rotated qpos exceeds joint limits.\n") + + # Compute FK for new pose + rotated_pose = env.get_arm_fk( + qpos=rotated_qpos, + is_left=is_left, + ) + + # Update environment state + env.set_current_qpos_agent(rotated_qpos, is_left=is_left) + env.set_current_xpos_agent(rotated_pose, is_left=is_left) + + # ------------------------------------ Traj 1: init → rotated ------------------------------------ + qpos_list_init_to_rotated = [select_arm_current_qpos, rotated_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_rotated, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info(f"Total generated trajectory number for rotate eef: {len(actions)}.", color="green") + + return actions + +def orient_eef( + robot_name: str, + direction: str = 'front', # 'front' or 'down' + env=None, + force_valid=False, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + # Get arm state + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Generate replacement rotation matrix + replaced_rotation_matrix = np.eye(4) + if direction == 'front': + rotation_matrix = R.from_euler("xyz", [180, -90, 0], degrees=True).as_matrix() + replaced_rotation_matrix[:3, :3] = rotation_matrix @ replaced_rotation_matrix[:3, :3] + elif direction == 'down': + rotation_matrix = R.from_euler("x", 180, degrees=True).as_matrix() + replaced_rotation_matrix[:3, :3] = rotation_matrix @ replaced_rotation_matrix[:3, :3] + else: + log_error("Rotation direction must be 'front' or 'down'.") + + rotation_replaced_pose = deepcopy(select_arm_current_pose) + rot_torch = torch.as_tensor( + replaced_rotation_matrix[:3, :3], + dtype=rotation_replaced_pose.dtype, + device=rotation_replaced_pose.device + ) + rotation_replaced_pose[:3, :3] = rot_torch + + # Solve IK for the new pose + replace_target_qpos = get_qpos(env, is_left, select_arm, rotation_replaced_pose, select_arm_current_qpos, force_valid=force_valid, name='replaced-rotation') + + # ---------------------------------------- Update env ---------------------------------------- + env.set_current_qpos_agent(replace_target_qpos, is_left=is_left) + env.set_current_xpos_agent(rotation_replaced_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_rotated = [select_arm_current_qpos, replace_target_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_rotated, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions( + select_qpos_traj, + ee_state_list_select + ) + + log_info(f"Total generated trajectory number for orient eef: {len(actions)}.", color="green") + + return actions + +def close_gripper( + robot_name: str, + env=None, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Traj ---------------------------------------- + sample_num = kwargs.get("sample_num", 15) + execute_open = False # False → closing motion + + plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions( + select_qpos_traj, + ee_state_list_select + ) + + log_info(f"Total generated trajectory number for close gripper: {len(actions)}.", color="green") + + return actions + +def open_gripper( + robot_name: str, + env=None, + **kwargs +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ + select_arm_current_gripper_state = get_arm_states(env, robot_name) + + # ---------------------------------------- Traj ---------------------------------------- + sample_num = kwargs.get("sample_num", 15) + execute_open = True # True → opening motion + + plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions( + select_qpos_traj, + ee_state_list_select + ) + + log_info(f"Total generated trajectory number for open gripper: {len(actions)}.", color="green") + + return actions + +def drive( + left_arm_action = None, + right_arm_action = None, + env=None, + **kwargs, +): + + if left_arm_action is not None and right_arm_action is not None: + len_left = len(left_arm_action) + len_right = len(right_arm_action) + + if len_left < len_right: + diff = len_right - len_left + padding = np.repeat(left_arm_action[-1:], diff, axis=0) + left_arm_action = np.concatenate([left_arm_action, padding], axis=0) + elif len_right < len_left: + diff = len_left - len_right + padding = np.repeat(right_arm_action[-1:], diff, axis=0) + right_arm_action = np.concatenate([right_arm_action, padding], axis=0) + + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + actions = np.zeros((len(right_arm_action), len(env.init_qpos))) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + elif left_arm_action is None and right_arm_action is not None: + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + left_arm_action = finalize_actions(env.left_arm_current_qpos, env.left_arm_current_gripper_state) + left_arm_action = np.repeat(left_arm_action[None, :], len(right_arm_action), axis=0) + + actions = np.zeros((len(right_arm_action), len(env.robot.get_qpos().squeeze(0))), dtype=np.float32) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + elif right_arm_action is None and left_arm_action is not None: + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + right_arm_action = finalize_actions(env.right_arm_current_qpos, env.right_arm_current_gripper_state) + right_arm_action = np.repeat(right_arm_action[None, :], len(left_arm_action), axis=0) + + actions = np.zeros((len(left_arm_action), len(env.robot.get_qpos().squeeze(0))), dtype=np.float32) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + else: + log_error("At least one arm action should be provided.") + + actions = torch.from_numpy(actions).to(dtype=torch.float32).unsqueeze(1) + actions = list(actions.unbind(dim=0)) + for i in tqdm(range(len(actions))): + action = actions[i] + obs, reward, terminated, truncated, info = env.step(action) + return actions + +def save_observations( + step_id: int = 0, + step_name: str = None, + env=None, + **kwargs, +): + # When using feedback script + log_dir = kwargs.get("log_dir") + if log_dir: + save_dir = log_dir / "camera_images" + + # Prepare subfolder: {id}_generate_num/episode{current_check_num} + gen_id = kwargs.get("id", "unknown_id") + episode_id = kwargs.get("current_check_num", 0) + + sub_dir = save_dir / f"{gen_id}_generate_num" / f"episode{episode_id}" + sub_dir.mkdir(parents=True, exist_ok=True) + + # Encode image to Base64 + base64_image = encode_image(env.get_obs_for_agent()["rgb"]) + + # Decode Base64 back to raw image bytes + import base64 + img_bytes = base64.b64decode(base64_image) + + # Ensure step_name is not None + step_name = step_name if step_name is not None else "unnamed_step" + + # Save the decoded image + output_path = sub_dir / f"step{step_id}_{step_name}.png" + with open(output_path, "wb") as f: + f.write(img_bytes) + + # Print save info + log_info(f"[save_observations] Saved image to: {output_path}") + + # When only running the script (no feedback script) + else: + pass \ No newline at end of file From 74eee5e717f39942a65a9368ee679a453ea97870 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:13:39 +0800 Subject: [PATCH 14/49] Migrate object --- embodichain/lab/gym/structs/object.py | 311 ++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 embodichain/lab/gym/structs/object.py diff --git a/embodichain/lab/gym/structs/object.py b/embodichain/lab/gym/structs/object.py new file mode 100644 index 0000000..08a8785 --- /dev/null +++ b/embodichain/lab/gym/structs/object.py @@ -0,0 +1,311 @@ +import numpy as np +import open3d as o3d +import os +import hashlib + +from typing import List, Dict, Union + +from embodichain.data import get_data_path +from embodichain.utils import logger + +from embodichain.toolkits.processor.function.mesh_processor.base import ( + MeshProcessorList, +) +from embodichain.toolkits.graspkit.pg_grasp import ( + AntipodalGenerator, +) + + +def generate_pickpose_sampler( + file_name: str, mesh: o3d.t.geometry.TriangleMesh, params: dict +) -> None: + logger.log_info(f"Generating object mesh {file_name} pick poses.") + if os.path.exists(file_name): + try: + if os.path.exists(file_name) and not file_name.endswith(".dae"): + pickpose_sampler = AntipodalGenerator( + mesh, + **params, + unique_id=hashlib.md5(file_name.encode()).hexdigest(), + ) + elif file_name.endswith(".dae"): + pickpose_sampler = None + else: + logger.log_warning( + f"Failed to build AntipodalGenerator because {file_name} is invalid!" + ) + except Exception as e: + logger.log_warning(f"Failed to build AntipodalGenerator: {str(e)}") + else: + logger.log_warning( + f"Failed to build AntipodalGenerator cause {file_name} unvalid!" + ) + return pickpose_sampler + + +# TODO: We should refactor the Object to support Group object design. +class Object: + name: str = "Undefined" + description: str = "Undefined" + pick_poses: List[np.ndarray] = None + pose: np.ndarray + parts: List["Object"] + articulations: Dict[str, np.ndarray] + mesh: o3d.geometry.TriangleMesh + scale: Union[List, np.ndarray] = [1, 1, 1] + mesh_file: str # to be depracted. + active_state: bool = False + unit: str = "m" + pickpose_sampler: AntipodalGenerator = None + pickpose_sampler_params: dict = None + # for object group + folder_path: str + mesh_processor: MeshProcessorList = None + + def get_mesh_file(self): + if hasattr(self, "mesh_file"): + return self.mesh_file + else: + obj_cad_files = [ + file + for file in os.listdir(self.folder_path) + if file.startswith("mesh_processed_") is False + ] + + target_file = np.random.choice(obj_cad_files) + + return self.select_mesh_file_from_folder(target_file) + + def select_mesh_file_from_folder(self, target_file: str): + from embodichain.toolkits.processor.component import TriangleComponent + from embodichain.toolkits.processor.entity import MeshEntity + + cache_fpath = os.path.join(self.folder_path, f"mesh_processed_{target_file}") + if os.path.exists(cache_fpath) is False: + tri_component = TriangleComponent.from_fpath( + os.path.join(self.folder_path, target_file) + ) + mesh_entity = MeshEntity("mesh", tri_component) + mesh = self.mesh_processor.apply([mesh_entity])[0] + mesh.save_mesh(cache_fpath) + + if self.pickpose_sampler_params is not None: + mesh = o3d.t.io.read_triangle_mesh(cache_fpath) + self.pickpose_sampler = generate_pickpose_sampler( + cache_fpath, mesh, self.pickpose_sampler_params + ) + + return cache_fpath + + @staticmethod + def from_folder(path: str, obj_data: dict) -> "Object": + from embodichain.toolkits.processor.function.mesh_processor import ( + build_mesh_processors, + ) + + obj = Object() + obj.folder_path = path + obj.description = obj_data.get("description", "Undefined") + obj.name = obj_data.get("name", "Undefined") + obj.unit = obj_data.get("unit", "m") + obj.pickpose_sampler_params = obj_data.get("auto_pickpose_generator", None) + obj.pose = obj_data.get("pose", np.eye(4)) + mesh_processor_config = obj_data.get("mesh_processor", None) + if mesh_processor_config is not None: + obj.mesh_processor = build_mesh_processors(mesh_processor_config) + + return obj + + @staticmethod + def from_mesh(path: str, downsample: bool = False, local_dir: str = "") -> "Object": + r"""Create an Object instance from a mesh file. + + Args: + path (str): The file path to the mesh key, value in obj_related[k].items(): + setattr(obj, key, value) + objs.append(obj) + + # Order the objects based file. + downsample (bool, optional): Whether to downsample the mesh. Defaults to False. + + Returns: + Object: An `Object` instance containing the mesh data. + """ + obj = Object() + + if not os.path.exists(path): + if local_dir is not None and local_dir != "": + path = os.path.join(local_dir, path) + else: + path = get_data_path(path) + + mesh = o3d.io.read_triangle_mesh(path) + if downsample: + obj.mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=4000) + else: + obj.mesh = mesh + obj.mesh_file = path + return obj + + @staticmethod + def from_urdf(path: str) -> List["Object"]: + r"""Create a list of Object instances from a URDF file. + + This method reads a URDF (Unified Robot Description Format) file, extracts + the geometry data, and returns a list of `Object` instances representing + the visual elements described in the URDF. + + Args: + path (str): The file path to the URDF file. + + Returns: + List[Object]: A list of `Object` instances representing the visual elements. + """ + import pinocchio + import copy + + data_path = copy.deepcopy(path) + if not os.path.exists(data_path): + data_path = get_data_path(path) + package_dirs = [os.path.dirname(data_path)] + + model, collision_model, visual_model = pinocchio.buildModelsFromUrdf( + data_path, package_dirs=package_dirs + ) + + urdf_dir = os.path.dirname(data_path) + objs = [] + + # Parse the geometry data from URDF + for geom in visual_model.geometryObjects.tolist(): + if hasattr(geom, "meshPath"): + mesh_path = geom.meshPath + if not os.path.isabs(mesh_path): + mesh_path = os.path.join(urdf_dir, mesh_path) + obj = Object.from_mesh(mesh_path) + obj.name = geom.name + objs.append(obj) + + return objs + + @staticmethod + def _save_mesh_or_urdf(obj: "Object", new_file_name: str): + r"""Save the mesh or URDF file with error handling. + + Args: + obj (Object): The object containing the mesh or URDF data. + new_file_name (str): The new file path where the data will be saved. + """ + try: + if new_file_name.endswith(".urdf"): + obj.save_as_urdf(new_file_name) + else: + o3d.io.write_triangle_mesh(new_file_name, obj.mesh) + except Exception as e: + logger.log_error(f"Failed to save the file {new_file_name}: {str(e)}") + + @staticmethod + def _generate_new_filename(file_path: str, extension: str) -> str: + r"""Generate a new filename with a specified extension. + + Args: + file_path (str): The original file path. + extension (str): The new extension to append to the filename. + + Returns: + str: The generated file path with the new extension. + """ + _, file_extension = os.path.splitext(file_path) + return os.path.join( + os.path.dirname(file_path), + os.path.basename(file_path).split(".")[0] + f"_{extension}{file_extension}", + ) + + @staticmethod + def _apply_common_settings(obj: "Object", obj_data: dict, file_path: str): + r"""Apply common settings such as unit conversion and pose generation to the object. + + Args: + obj (Object): The object to which settings will be applied. + obj_data (dict): Dictionary containing object-specific configuration data. + file_path (str): Object file path. + """ + if "unit" in obj_data and obj_data["unit"] == "mm": + obj.mesh.scale(1e-3, center=np.zeros((3))) + obj.unit = "m" + obj_data.pop("unit") + new_file_name = Object._generate_new_filename(obj.mesh_file, extension="m") + Object._save_mesh_or_urdf(obj, new_file_name) + obj.mesh_file = new_file_name + + obj.scale = obj_data.get("scale", [1, 1, 1]) + from dexsim.utility.meshproc import scale_trianglemesh + + obj.mesh = scale_trianglemesh(obj.mesh, obj.scale) + + if "auto_pickpose_generator" in obj_data: + mesh = o3d.t.geometry.TriangleMesh.from_legacy(obj.mesh) + obj.pickpose_sampler = generate_pickpose_sampler( + obj.mesh_file, mesh, obj_data["auto_pickpose_generator"] + ) + + for key, value in obj_data.items(): + setattr(obj, key, value) + + @staticmethod + def from_config( + path: Union[str, Dict], downsample: bool = False, local_dir: str = "" + ) -> List["Object"]: + r"""Create a list of Object instances from a configuration file. + + Args: + path (Union[str, Dict]): The file path to the configuration file or a dictionary containing the configuration. + downsample (bool, optional): Whether to downsample the mesh. Defaults to False. + + Returns: + List[Object]: A list of `Object` instances as specified in the configuration file. + """ + from embodichain.utils.utility import load_json + + if isinstance(path, str): + config = load_json(path) + else: + config = path + + obj_related = config["obj_list"] + objs = [] + + for k, obj_data in obj_related.items(): + if obj_data.get("mesh_file", None) is not None: + file = obj_data.pop("mesh_file") + file_ext = os.path.splitext(file)[-1].lower() + + if local_dir is not None and local_dir != "": + data_file_path = os.path.join(local_dir, file) + else: + data_file_path = get_data_path(file) + + if file_ext == ".urdf": + urdf_objs = Object.from_urdf(file) + for obj in urdf_objs: + Object._apply_common_settings(obj, obj_data, data_file_path) + objs.extend(urdf_objs) + else: + obj = Object.from_mesh(file, downsample, local_dir=local_dir) + Object._apply_common_settings(obj, obj_data, data_file_path) + objs.append(obj) + else: + folder_path = obj_data.get("folder_path", None) + if folder_path is None: + logger.log_error( + f"Object configuration {k} does not contain a valid mesh file or folder path." + ) + obj = Object.from_folder(folder_path, obj_data) + objs.append(obj) + + # TODO: to be improved. + if len(objs) == len(obj_related): + order = [int(k) for k in obj_related.keys()] + return [objs[i] for i in np.argsort(order)] + else: + return objs From 3421022efd2f73034008e42f62feec1677c8caae Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:13:48 +0800 Subject: [PATCH 15/49] Migrate object --- embodichain/lab/gym/structs/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 embodichain/lab/gym/structs/__init__.py diff --git a/embodichain/lab/gym/structs/__init__.py b/embodichain/lab/gym/structs/__init__.py new file mode 100644 index 0000000..a240d03 --- /dev/null +++ b/embodichain/lab/gym/structs/__init__.py @@ -0,0 +1 @@ +from .object import Object From fa8f0fe3e23f7f704fd0abd4f0a23b0ff9175f5b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:20:10 +0800 Subject: [PATCH 16/49] Migrate necessory files in toolkits --- .../toolkits/processor/component/__init__.py | 7 + .../toolkits/processor/component/component.py | 311 ++++++++++++++++++ .../toolkits/processor/entity/__init__.py | 8 + .../toolkits/processor/entity/entity_base.py | 217 ++++++++++++ .../toolkits/processor/entity/meshobject.py | 122 +++++++ .../function/mesh_processor/__init__.py | 2 + .../processor/function/mesh_processor/base.py | 42 +++ .../function/mesh_processor/processor.py | 163 +++++++++ .../toolkits/processor/types/__init__.py | 7 + embodichain/toolkits/processor/types/types.py | 7 + embodichain/toolkits/toolkits.py | 18 + 11 files changed, 904 insertions(+) create mode 100644 embodichain/toolkits/processor/component/__init__.py create mode 100644 embodichain/toolkits/processor/component/component.py create mode 100644 embodichain/toolkits/processor/entity/__init__.py create mode 100644 embodichain/toolkits/processor/entity/entity_base.py create mode 100644 embodichain/toolkits/processor/entity/meshobject.py create mode 100644 embodichain/toolkits/processor/function/mesh_processor/__init__.py create mode 100644 embodichain/toolkits/processor/function/mesh_processor/base.py create mode 100644 embodichain/toolkits/processor/function/mesh_processor/processor.py create mode 100644 embodichain/toolkits/processor/types/__init__.py create mode 100644 embodichain/toolkits/processor/types/types.py create mode 100644 embodichain/toolkits/toolkits.py diff --git a/embodichain/toolkits/processor/component/__init__.py b/embodichain/toolkits/processor/component/__init__.py new file mode 100644 index 0000000..774cdb9 --- /dev/null +++ b/embodichain/toolkits/processor/component/__init__.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .component import * diff --git a/embodichain/toolkits/processor/component/component.py b/embodichain/toolkits/processor/component/component.py new file mode 100644 index 0000000..040aa23 --- /dev/null +++ b/embodichain/toolkits/processor/component/component.py @@ -0,0 +1,311 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import abc +import hashlib +import numpy as np +import open3d as o3d +from typing import List, Dict, Any +from dataclasses import dataclass, field, fields, is_dataclass, asdict +from scipy.spatial.transform import Rotation + +from embodichain.utils.cfg import CfgNode +import dexsim.utility as dexutils +from embodichain.toolkits.processor.types import CFG_DEF_TYPE_KEYS +from dexsim.kit.meshproc.generate_thicker_acd import get_pc_thickness + + +class BaseMetaClass(type): + register_class = {} + + def __new__(cls, name: str, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "EntityComponent": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class EntityComponent(metaclass=BaseMetaClass): + """An abstract class for entity components. + + EntityComponent 只是单纯的数据类,没有其他功能行为。 因此我们可以用 `@dataclass` 来装饰他们。 + + 在使用 `@dataclass` 装饰器时,有几点需要注意: + 1. 使用 eq=False 避免使用 dataclass 默认的 __eq__ 方法,而是使用基类的 __eq__ 方法, + 特别是当数据类中有 np.ndarray 类型的变量 比较时会出现错误 + 2. 使用 fronzen=True 避免多个 entity 共享同一个组件实例的时候, + 修改某个组件的属性会影响到其他 entity 的问题。这样做的话,一旦一个组件被初始化,其属性就不能再被修改了,只能通过 new 方法创建新的组件实例。 + 当一个 Component 可能会被多个 entity 共享时,就应该使用 frozen=True 来避免这个问题。 + """ + + def __eq__(self, other): + """ + 用于检查两个相同 EntityComponent 是否相等。 + + 重写 dataclass 的 __eq__ 方法,是增加 EntityComponent 中有 np.ndarray 类型变量的比较。 + + Args: + other (object): The object to compare with. + Returns: + bool: True if the two instances are equal, False otherwise. + """ + if isinstance(other, self.__class__): + for k, v in self.__dict__.items(): + if k not in other.__dict__: + return False + if isinstance(v, np.ndarray): + if not np.array_equal(v, other.__dict__[k]): + return False + elif v != other.__dict__[k]: + return False + return True + return False + + def new(self, **kwargs): + """ + 根据传入的参数更新组件的属性,并返回一个新的组件实例。 + + 一个组件可能会被多个 entity 共享,所以在更新组件的属性时,如果需要对不同的 entity 有不同的属性值, + 使用该接口进行参数更新,而不是直接修改组件的属性。 + + Args: + **kwargs: Keyword arguments to set the attributes of the new instance. + + Raises: + ValueError: If any of the keyword arguments are not valid fields of the class. + + Returns: + An instance of the class with the given keyword arguments set as attributes. + + """ + for k, v in kwargs.items(): + if not hasattr(self, k): + raise ValueError(f"{k} is not a valid field") + # TODO check the type of value + # setattr(self, k, v) + for field in fields(self): + if field.name not in kwargs: + kwargs[field.name] = getattr(self, field.name) + return type(self)(**kwargs) + + def save(self) -> Dict: + """ + 保存当前组件实例的数据。 + + 该函数首先检查当前实例是否为 dataclass,如果是,就使用 `asdict` 函数来保存数据。如果不是,就抛出一个 `NotImplementedError` 异常。 + 即当前我们仅支持 dataclass 类型的组件实例。 + + Returns: + Dict: The data of the current instance as a dictionary. + + Raises: + NotImplementedError: If the current instance is not a dataclass. + """ + if not is_dataclass(self): + raise NotImplementedError + return asdict(self) + + @classmethod + def from_config(cls, cfg: CfgNode): + if not is_dataclass(cls): + raise NotImplementedError + data_fields = fields(cls) + if not isinstance(cfg, dict): + if len(data_fields) > 1: + raise ValueError(f"Config should be a dict, but got {cfg}.") + return cls(cfg) + else: + params = {} + for key, val in cfg.items(): + params[key.lower()] = val + return cls(**params) + + +def build_component_from_config(cfg: CfgNode, name: str = None) -> EntityComponent: + type_key = None + if name is None: + for tk in CFG_DEF_TYPE_KEYS: + if tk in cfg: + if type_key is not None: + raise ValueError( + f"Config should only contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = tk + if type_key is None: + raise ValueError( + f"Config should contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = cfg.pop(type_key) + else: + type_key = name + register_class = EntityComponent.register_class + if isinstance(type_key, str): + type_key = type_key.upper() + if type_key not in register_class: + raise ValueError(f"Class {type_key} is not registered") + else: + raise TypeError(f"Class type {type_key} is not a string") + try: + return register_class[type_key].from_config(cfg) + except Exception as e: + raise ValueError(f"Failed to build component {type_key}, {e}") + + +@dataclass(eq=False) +class AxisAlignedBoundingBox(EntityComponent): + min_bound: np.ndarray + max_bound: np.ndarray + + def is_close( + self, other: "AxisAlignedBoundingBox", threshold: float = 1e-3 + ) -> bool: + return np.allclose( + self.min_bound, other.min_bound, atol=threshold + ) and np.allclose(self.max_bound, other.max_bound, atol=threshold) + + +@dataclass(eq=False) +class OrientedBoundingBox(EntityComponent): + center: np.ndarray + extent: np.ndarray + R: Rotation + + +@dataclass(eq=False, frozen=True) +class TriangleComponent(EntityComponent): + vertices: np.ndarray + triangles: np.ndarray + triangle_uvs: np.ndarray = np.empty((0, 3, 2)) + vertex_uvs: np.ndarray = np.empty((0, 2)) + vertex_colors: np.ndarray = np.empty((0, 3)) # or 4 + vertex_normals: np.ndarray = np.empty((0, 3)) + texture: np.ndarray = np.empty((0, 0, 3)) # hwc + mesh_fpath: str = None + optional_params: Dict[str, Any] = field(default_factory=dict) + + def md5_hash(self) -> str: + md5 = hashlib.md5() + hash_attr_keys = ["vertices", "triangles", "mesh_fpath"] + for key in hash_attr_keys: + val = getattr(self, key) + if val is None: + continue + if isinstance(val, np.ndarray): + md5.update(val.tobytes()) + elif isinstance(val, str): + md5.update(val.encode()) + + return md5.hexdigest() + + @classmethod + def from_config(cls, cfg: CfgNode): + if "MESH_FPATH" not in cfg: + raise ValueError(f"Config should contains key MESH_FPATH, but got {cfg}.") + mesh_fpath = cfg.MESH_FPATH + mesh = cls.__from_fpath(mesh_fpath) + + optional_params = {} + if "OPTIONAL_PARAMS" in cfg and cfg.OPTIONAL_PARAMS is not None: + for key, val in cfg.OPTIONAL_PARAMS.items(): + optional_params[key.lower()] = val + return cls( + vertices=mesh.vertices, + triangles=mesh.triangles, + triangle_uvs=mesh.triangle_uvs, + vertex_uvs=mesh.vertex_uvs, + vertex_colors=mesh.vertex_colors, + vertex_normals=mesh.vertex_normals, + texture=mesh.texture, + mesh_fpath=mesh_fpath, + optional_params=optional_params, + ) + + @classmethod + def from_fpath(cls, mesh_fpath: str): + mesh = cls.__from_fpath(mesh_fpath) + + return cls( + vertices=mesh.vertices, + triangles=mesh.triangles, + triangle_uvs=mesh.triangle_uvs, + vertex_uvs=mesh.vertex_uvs, + vertex_colors=mesh.vertex_colors, + vertex_normals=mesh.vertex_normals, + texture=mesh.texture, + mesh_fpath=mesh_fpath, + ) + + @staticmethod + def __from_fpath(mesh_fpath: str): + from dexsim.kit.meshproc.mesh_io import load_mesh + + mesh = load_mesh(mesh_fpath) + if isinstance(mesh, List): + msg = f"Mesh file {mesh_fpath} contains multiple meshes. Only the first mesh will be used." + dexutils.log_warning(msg) + mesh = mesh[0] + # vertices = mesh.vertices + # triangles = mesh.triangles + if mesh.vertex_uvs is None: + mesh.vertex_uvs = np.empty((0, 2)) + if mesh.triangle_uvs is None: + mesh.triangle_uvs = np.empty((0, 3, 2)) + if mesh.vertex_colors is None: + mesh.vertex_colors = np.empty((0, 3)) + if mesh.vertex_normals is None: + mesh.vertex_normals = np.empty((0, 3)) + if mesh.texture is None: + mesh.texture = np.empty((0, 0, 3)) + return mesh + + def get_thickness(self): + mesh_o3d = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(self.vertices), + o3d.utility.Vector3iVector(self.triangles), + ) + surface_pc_o3d = mesh_o3d.sample_points_uniformly(number_of_points=3000) + surface_pc = np.array(surface_pc_o3d.points) + thickness, standard_pose = get_pc_thickness(surface_pc) + return thickness + + +@dataclass() +class VisualComponent(EntityComponent): + is_visual: bool = True + + +@dataclass(eq=False) +class ScaleComponent(EntityComponent): + scale: np.ndarray = np.array([1.0, 1.0, 1.0]) + + def __post_init__(self): + if self.scale is not None: + self.scale = np.array(self.scale) + + +@dataclass(eq=False) +class SpatializationComponenet(EntityComponent): + location: np.ndarray = np.array([0, 0, 0], dtype=np.float32) + rotation: Rotation = Rotation.from_matrix(np.eye(3)) + + def __post_init__(self): + if self.location is not None: + self.location = np.array(self.location, dtype=np.float32) + if self.rotation is not None and not isinstance(self.rotation, Rotation): + self.rotation = Rotation.from_euler("xyz", self.rotation, degrees=True) + + def save(self) -> Dict: + return_dict = asdict(self) + return_dict["rotation"] = self.rotation.as_euler("xyz", degrees=True) + return return_dict + + def get_pose(self) -> np.ndarray: + pose = np.eye(4) + if self.location is not None: + pose[:3, 3] = self.location + if self.rotation is not None: + pose[:3, :3] = self.rotation.as_matrix() + return pose diff --git a/embodichain/toolkits/processor/entity/__init__.py b/embodichain/toolkits/processor/entity/__init__.py new file mode 100644 index 0000000..a3a7bea --- /dev/null +++ b/embodichain/toolkits/processor/entity/__init__.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .entity_base import EntityBase, build_entity_from_config +from .meshobject import MeshEntity diff --git a/embodichain/toolkits/processor/entity/entity_base.py b/embodichain/toolkits/processor/entity/entity_base.py new file mode 100644 index 0000000..a6a0951 --- /dev/null +++ b/embodichain/toolkits/processor/entity/entity_base.py @@ -0,0 +1,217 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from typing import Union, Dict, Any + +# from dataclasses import asdict, is_dataclass + +from embodichain.utils.cfg import CfgNode + +from embodichain.toolkits.processor.component import EntityComponent +from embodichain.toolkits.processor.types import CFG_DEF_TYPE_KEYS + + +class EntityMetaClass(type): + register_class = {} + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "EntityBase": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class EntityBase(metaclass=EntityMetaClass): + """ + The base class for all entities in the scene. + + Args: + name (str): The name of the entity. + + Attributes: + name (str): The name of the entity. + _components (Dict[type, EntityComponent]): A dictionary of components of type `EntityComponent`. + """ + + def __init__(self, name: str) -> None: + self.name = name + + self._components: Dict[type, EntityComponent] = {} + self._custom_properties: Dict[str, Any] = {} + + @classmethod + def kind(cls) -> str: + """ + The kind of the entity. + + Returns: + str: A string representing the name of the class. + """ + return cls.__name__ + + def get_name(self) -> str: + """ + Get the name of the entity. + + Returns: + str: the name of the entity + """ + return self.name + + def set_custom_property(self, **kwargs): + for key, value in kwargs.items(): + self._custom_properties[key] = value + + def has_custom_property(self, *key: str): + for k in key: + if k not in self._custom_properties: + return False + return True + + def get_custom_property(self, key: str): + if key not in self._custom_properties: + return None + return self._custom_properties[key] + + def get_custom_properties(self) -> Dict[str, Any]: + return self._custom_properties + + def remove_custom_property(self, *key: str): + for k in key: + if k in self._custom_properties: + del self._custom_properties[k] + + def add_component(self, *component: EntityComponent): + """ + Adds one or more components to the entity. + + Args: + *component (EntityComponent): One or more components to be added. + + Raises: + AssertionError: If any of the components are not instances of EntityComponent or not direct subclasses of EntityComponent. + + Description: + This method adds one or more components to the entity. If a component of the same type already exists, it will be replaced. + + Note: + Currently, the method only accepts components that are instances of EntityComponent and direct subclasses of EntityComponent. + + Example: + >>> entity = Entity() + >>> entity.add_component(ScaleComponent(), SpatializationComponenet()) + """ + # The existing component of the same type will be replaced. + for comp in component: + assert isinstance( + comp, EntityComponent + ), f"Component {type(comp)} must be an instance of EntityComponent" + assert ( + EntityComponent in comp.__class__.__bases__ + ), f"Component {type(comp)} must be a direct subclass of EntityComponent" + # we only accept dataclass type component + # assert is_dataclass(comp) + self._components[type(comp)] = comp + + def has_component(self, comp_type: type) -> bool: + """ + Check if the entity has a component of a specific type. + + Args: + comp_type (type): The type of component to check for. + + Returns: + bool: True if the entity has a component of the specified type, False otherwise. + """ + return comp_type in self._components + + def get_component(self, comp_type: type) -> Union[None, EntityComponent]: + """ + Get the component of the specified type from the entity. + + Args: + comp_type (type): The type of component to retrieve. + + Returns: + Union[None, EntityComponent]: The component of the specified type, or None if it does not exist. + """ + return self._components.get(comp_type, None) + + def remove_component(self, comp_type: type): + """ + Remove a component of the specified type from the entity. + + Parameters: + comp_type (type): The type of component to remove. + + Raises: + ValueError: If the component of the specified type does not exist. + + Returns: + None + """ + if comp_type in self._components: + del self._components[comp_type] + else: + raise ValueError(f"{comp_type} component does not exist") + + def save(self) -> Dict: + """ + Saves the components of the entity to a dictionary. + + Returns: + Dict: A dictionary containing the names of the component types as keys and the saved components as values. + """ + results = {} + for comp_type, comp in self._components.items(): + # # only save the components that are decorated by dataclass + # if is_dataclass(comp): + results[comp_type.__name__] = comp.save() + return results + + @classmethod + def from_config(cls, cfg: Union[str, CfgNode]) -> "EntityBase": + if isinstance(cfg, CfgNode): + if "NAME" not in cfg: + raise ValueError(f"Config should contains key NAME, but got {cfg}.") + name = cfg.pop("NAME") + elif isinstance(cfg, str): + name = cfg + else: + raise TypeError(f"Config should be a string or CfgNode, but got {cfg}.") + return cls(name) + + +def build_entity_from_config( + cfg: Union[str, CfgNode], entity_type: str = None +) -> EntityBase: + if isinstance(cfg, str) and entity_type is None: + err_msg = f"'entity_type' is required when 'cfg' is a string." + raise ValueError(err_msg) + type_key = None + if entity_type is None: + for tk in CFG_DEF_TYPE_KEYS: + if tk in cfg: + if type_key is not None: + raise ValueError( + f"Config should only contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = tk + if type_key is None: + raise ValueError( + f"Config should contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = cfg.pop(type_key) + else: + type_key = entity_type + register_class = EntityBase.register_class + if isinstance(type_key, str): + type_key = type_key.upper() + if type_key not in register_class: + raise ValueError(f"Class {type_key} is not registered") + else: + raise TypeError(f"Class type {type_key} is not a string") + return register_class[type_key].from_config(cfg) diff --git a/embodichain/toolkits/processor/entity/meshobject.py b/embodichain/toolkits/processor/entity/meshobject.py new file mode 100644 index 0000000..6552646 --- /dev/null +++ b/embodichain/toolkits/processor/entity/meshobject.py @@ -0,0 +1,122 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +import open3d as o3d + +from embodichain.toolkits.processor.component import EntityComponent +from embodichain.toolkits.processor.component import ( + OrientedBoundingBox, + ScaleComponent, + SpatializationComponenet, + TriangleComponent, + AxisAlignedBoundingBox, + TriangleComponent, + VisualComponent, +) +from .entity_base import EntityBase + + +class MeshEntity(EntityBase): + def __init__( + self, + name: str, + triangle_comp: TriangleComponent = None, + spatialization_comp: SpatializationComponenet = None, + visual_comp: VisualComponent = None, + ) -> None: + super().__init__(name) + if triangle_comp is not None: + self.add_component(triangle_comp) + if visual_comp is None: + visual_comp = VisualComponent() + self.add_component(visual_comp) + + # init with default component + if spatialization_comp is None: + spatialization_comp = SpatializationComponenet() + # add the initial component + self.add_component(spatialization_comp) + + def is_visible(self) -> bool: + # if not self.has_component(TriangleComponent): + # return False + visual_comp = self.get_component(VisualComponent) + if visual_comp: + visual = visual_comp.is_visual + else: + visual = False + return visual + + def add_component(self, *component: EntityComponent): + for comp in component: + if isinstance(comp, TriangleComponent): + self.triangle_comp = comp + # remove the old bounding box component + if self.has_component(AxisAlignedBoundingBox): + self.remove_componenet(AxisAlignedBoundingBox) + if self.has_component(OrientedBoundingBox): + self.remove_componenet(OrientedBoundingBox) + # add default visual component + if not self.has_component(VisualComponent): + self.add_component(VisualComponent()) + super().add_component(*component) + + def get_axis_aligned_bounding_box(self) -> o3d.geometry.AxisAlignedBoundingBox: + triangle_comp: TriangleComponent = self.get_component(TriangleComponent) + scale_comp = self.get_component(ScaleComponent) + vertices = triangle_comp.vertices + scale = np.array([1, 1, 1]) + if scale_comp is not None: + scale = scale_comp.scale + o3d_mesh = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(vertices * scale), + o3d.utility.Vector3iVector(triangle_comp.triangles), + ) + spatial_comp: SpatializationComponenet = self.get_component( + SpatializationComponenet + ) + o3d_mesh.transform(spatial_comp.get_pose()) + aabbox_o3d = o3d_mesh.get_axis_aligned_bounding_box() + return aabbox_o3d + + # def get_oriented_bounding_box(self) -> o3d.geometry.OrientedBoundingBox: + # if not self.has_component(OrientedBoundingBox): + # # self.add_component(OrientedBoundingBox(self.triangle_comp.vertices)) + # o3d_mesh = o3d.geometry.TriangleMesh(self.triangle_comp.vertices, + # self.triangle_comp.triangles) + # bbox = o3d_mesh.get_oriented_bounding_box() + # obb = OrientedBoundingBox(bbox.get_center(), bbox.get_extent(), + # bbox.get_rotation()) + # self.add_component(obb) + # obbox = self.get_component(OrientedBoundingBox) + # return obbox + + def get_o3d_mesh( + self, add_scale: bool = True, add_transform: bool = True + ) -> o3d.geometry.TriangleMesh: + triangle_comp: TriangleComponent = self.get_component(TriangleComponent) + scale_comp = self.get_component(ScaleComponent) + vertices = triangle_comp.vertices + scale = np.array([1, 1, 1]) + if add_scale and scale_comp is not None: + scale = scale_comp.scale + o3d_mesh = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(vertices * scale), + o3d.utility.Vector3iVector(triangle_comp.triangles), + ) + if add_transform: + spatial_comp: SpatializationComponenet = self.get_component( + SpatializationComponenet + ) + o3d_mesh.transform(spatial_comp.get_pose()) + return o3d_mesh + + def save_mesh(self, file_path: str): + from dexsim.kit.meshproc.mesh_io import save_mesh + + tri_comp: TriangleComponent = self.get_component(TriangleComponent) + save_mesh(file_path, **tri_comp.save()) diff --git a/embodichain/toolkits/processor/function/mesh_processor/__init__.py b/embodichain/toolkits/processor/function/mesh_processor/__init__.py new file mode 100644 index 0000000..145396a --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/__init__.py @@ -0,0 +1,2 @@ +from .base import MeshProcessor, build_mesh_processors +from .processor import * diff --git a/embodichain/toolkits/processor/function/mesh_processor/base.py b/embodichain/toolkits/processor/function/mesh_processor/base.py new file mode 100644 index 0000000..2b4d331 --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/base.py @@ -0,0 +1,42 @@ +import abc +from typing import List, Dict + +from embodichain.utils.cfg import CfgNode + +from embodichain.toolkits.processor.entity import MeshEntity + + +class MeshProcessorMetaClass(type): + register_class = {} + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "MeshProcessor": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class MeshProcessor(metaclass=MeshProcessorMetaClass): + def __init__(self, **kwargs): + pass + + @abc.abstractmethod + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + pass + + +class MeshProcessorList(List[MeshProcessor]): + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for processor in self: + meshes = processor.apply(meshes) + return meshes + + +def build_mesh_processors(config: CfgNode) -> MeshProcessorList: + processors = MeshProcessorList() + for name, cfg in config.items(): + if cfg is None: + cfg = {} + processor = MeshProcessor.register_class[name.upper()](**cfg) + processors.append(processor) + return processors diff --git a/embodichain/toolkits/processor/function/mesh_processor/processor.py b/embodichain/toolkits/processor/function/mesh_processor/processor.py new file mode 100644 index 0000000..dc851f1 --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/processor.py @@ -0,0 +1,163 @@ +import open3d as o3d +import numpy as np +from typing import List, Tuple + +from dexsim.kit.meshproc import face_uv_to_vert_uv +from dexsim.kit.meshproc.compute_uv import get_mesh_auto_uv +from dexsim.kit.meshproc import ( + simplification_decimation, + remesh_isotropic_explicit, +) +import dexsim.utility as dexsutil + +from embodichain.toolkits.processor.entity import MeshEntity +from embodichain.toolkits.processor.component import TriangleComponent + +from .base import MeshProcessor + + +class ComputeUV(MeshProcessor): + def __init__( + self, + compute_vertex_uvs: bool = True, + max_triangle_count: int = 10000, + remesh: bool = False, + ): + """Compute UV coordinates. + + Args: + compute_vertex_uvs (bool, optional): Compute texture uvs or triangle uvs, if True, compute texture uvs. Defaults to True. + max_triangle_count (int, optional): If the number of faces is larger than this value and there is no uvs, simplification will be applied. + It will cost more time to compute uvs if the number of faces is large .Defaults to 10000. + remesh (bool, optional): If set to True, remesh will be applied, and the uvs will be re-computed. Defaults to False. + """ + self.compute_vertex_uvs = compute_vertex_uvs + self.max_triangle_count = max_triangle_count + self.remesh = remesh + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for mesh in meshes: + tri_comp: TriangleComponent = mesh.get_component(TriangleComponent) + has_uvs = tri_comp.vertex_uvs.size > 0 or tri_comp.triangle_uvs.size > 0 + + mesh_o3d = mesh.get_o3d_mesh(add_scale=False, add_transform=False) + mesh_o3dt = o3d.t.geometry.TriangleMesh.from_legacy(mesh_o3d) + # if the number of faces is larger than max_triangle_count and there is no uvs, simplification will be applied + if not has_uvs: + if tri_comp.triangles.shape[0] > self.max_triangle_count: + # simplification + is_success, mesh_o3dt = simplification_decimation( + mesh_o3dt, sample_triangle_num=self.max_triangle_count + ) + if not is_success: + dexsutil.log_warning("failed to do simplification.") + # remesh need to apply after simplification + if self.remesh: + is_success, mesh_o3dt = remesh_isotropic_explicit( + mesh_o3dt, is_visual=False + ) + # has_uvs = False # need to recompute uvs + if self.compute_vertex_uvs: + if tri_comp.vertex_uvs.size == 0 or self.remesh: + if tri_comp.triangle_uvs.size > 0: + vertex_uvs = face_uv_to_vert_uv( + tri_comp.triangles, + tri_comp.triangle_uvs, + len(tri_comp.vertices), + ) + else: + _, vertex_uvs = get_mesh_auto_uv(mesh_o3dt) + tri_comp = tri_comp.new(vertex_uvs=vertex_uvs) + mesh.add_component(tri_comp) + else: + dexsutil.log_error("Not implemented for compute triangle uvs.") + return meshes + + +class MeshNormalize(MeshProcessor): + def __init__( + self, + set_origin: str = "center", + scale: float = 1.0, + unify_longest_side: bool = False, + ): + """Normalize the mesh to a standard size and origin. + + Args: + set_origin (str, optional): Set the origin location of the mesh to it's center or it's bottom center. + Choices=["center", "bottom"]. Defaults to 'center'. + scale (float, optional): Scale factor for the mesh . Defaults to 1.0. + unify_longest_side (float, optional): If True, the longest side of the mesh will be scaled to the scale factor. + Defaults to False. + """ + assert set_origin in [ + "center", + "bottom", + ], f"Invalid value for set_origin: {set_origin}" + self.set_origin = set_origin + self.scale = scale + self.unify_longest_side = unify_longest_side + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for mesh in meshes: + tri_comp: TriangleComponent = mesh.get_component(TriangleComponent) + vertices = tri_comp.vertices + # set center of the mesh to the origin + if self.set_origin == "center": + center = np.mean(vertices, axis=0) + elif self.set_origin == "bottom": + center_xy = np.mean(vertices[:, :2], axis=0) + center = np.array([center_xy[0], center_xy[1], np.min(vertices[:, 2])]) + else: + raise ValueError(f"Invalid value for set_origin: {self.set_origin}") + vertices -= center # in-place operation + + # scale the mesh + if self.unify_longest_side: + max_length = np.max(vertices, axis=0) - np.min(vertices, axis=0) + scale = self.scale / np.max(max_length) + else: + scale = self.scale + vertices *= scale + return meshes + + +class MeshAlign(MeshProcessor): + def __init__( + self, + method: str = "obb", + symmetry_axis: int = 0, + is_larger_positive: bool = True, + ): + assert method in ["obb", "svd"], f"Invalid value for method: {method}" + self.method = method + self.symmetry_axis = symmetry_axis + self.is_larger_positive = is_larger_positive + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + from dexsim.kit.meshproc import cad_standardlize_svd, cad_standardlize_obb + + for mesh in meshes: + mesh_o3d = mesh.get_o3d_mesh(add_scale=False, add_transform=False) + if self.method == "obb": + is_success, mesh_o3dt = cad_standardlize_obb( + mesh_o3d, + is_use_mesh_clean=False, + is_cad_eliminate_symmetry=True, + symmetry_axis=self.symmetry_axis, + is_larger_positive=self.is_larger_positive, + ) + elif self.method == "svd": + is_success, mesh_o3dt = cad_standardlize_obb( + mesh_o3d, + is_use_mesh_clean=False, + is_cad_eliminate_symmetry=True, + symmetry_axis=self.symmetry_axis, + is_larger_positive=self.is_larger_positive, + ) + vertices = mesh_o3dt.vertex.positions.numpy() + triangles = mesh_o3dt.triangle.indices.numpy() + tri_comp = mesh.get_component(TriangleComponent) + tri_comp = tri_comp.new(vertices=vertices, triangles=triangles) + mesh.add_component(tri_comp) + return meshes diff --git a/embodichain/toolkits/processor/types/__init__.py b/embodichain/toolkits/processor/types/__init__.py new file mode 100644 index 0000000..b6b342b --- /dev/null +++ b/embodichain/toolkits/processor/types/__init__.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .types import * diff --git a/embodichain/toolkits/processor/types/types.py b/embodichain/toolkits/processor/types/types.py new file mode 100644 index 0000000..ed48ce2 --- /dev/null +++ b/embodichain/toolkits/processor/types/types.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +CFG_DEF_TYPE_KEYS = ["type", "_TYPE", "NAME", "name", "TYPE"] diff --git a/embodichain/toolkits/toolkits.py b/embodichain/toolkits/toolkits.py new file mode 100644 index 0000000..9b2ca34 --- /dev/null +++ b/embodichain/toolkits/toolkits.py @@ -0,0 +1,18 @@ +from abc import ABCMeta, abstractmethod +import os +import cv2 +from embodichain.utils.utility import load_json + + +class ToolkitsBase(metaclass=ABCMeta): + @classmethod + def from_config(cls, path: str): + assert ( + os.path.basename(path).split(".")[-1] == "json" + ), "only json file is supported." + config = load_json(path) + return config["ToolKits"][cls.__name__] + + @abstractmethod + def call(self, **kwargs): + pass From 302608a451bcbf5b204aaa386ec1e2cce76ee2c1 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:32:36 +0800 Subject: [PATCH 17/49] Migrate motion generation part --- .../gym/motion_generation/action/action.py | 77 ++ .../motion_generation/action/arm_action.py | 710 ++++++++++++++++++ .../planner/toppra_planner.py | 258 +++++++ .../gym/motion_generation/planner/utils.py | 42 ++ 4 files changed, 1087 insertions(+) create mode 100644 embodichain/lab/gym/motion_generation/action/action.py create mode 100644 embodichain/lab/gym/motion_generation/action/arm_action.py create mode 100644 embodichain/lab/gym/motion_generation/planner/toppra_planner.py create mode 100644 embodichain/lab/gym/motion_generation/planner/utils.py diff --git a/embodichain/lab/gym/motion_generation/action/action.py b/embodichain/lab/gym/motion_generation/action/action.py new file mode 100644 index 0000000..74f5b55 --- /dev/null +++ b/embodichain/lab/gym/motion_generation/action/action.py @@ -0,0 +1,77 @@ +import numpy as np + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.gym.envs import BaseEnv +from embodichain.utils import logger + + +class Action(ABC): + r"""Base class for action terms. + + The action term is responsible for processing the raw actions sent to the environment + and applying them to the asset managed by the term. The action term is comprised of two + operations: + + """ + + env = None + scene = None + + def __init__(self, env, **kwargs) -> None: + self.env: BaseEnv = env + self.scene = self.env.scene + + # def reset(self, env_ids: Sequence[int] | None = None) -> None: + # r"""Resets the manager term. + + # Args: + # env_ids: The environment ids. Defaults to None, in which case + # all environments are considered. + # """ + # pass + + # @abstractmethod + # def process_actions(self, actions: torch.Tensor): + # r"""Processes the actions sent to the environment. + + # Note: + # This function is called once per environment step by the manager. + + # Args: + # actions: The actions to process. + # """ + # raise NotImplementedError + + # @abstractmethod + # def apply_actions(self): + # r"""Applies the actions to the asset managed by the term. + + # Note: + # This is called at every simulation step by the manager. + # """ + # raise NotImplementedError + + def __call__(self, *args) -> Any: + """Returns the value of the term required by the manager. + + In case of a class implementation, this function is called by the manager + to get the value of the term. The arguments passed to this function are + the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`). + + .. attention:: + To be consistent with memory-less implementation of terms with functions, it is + recommended to ensure that the returned mutable quantities are cloned before + returning them. For instance, if the term returns a tensor, it is recommended + to ensure that the returned tensor is a clone of the original tensor. This prevents + the manager from storing references to the tensors and altering the original tensors. + + Args: + *args: Variable length argument list. + + Returns: + The value of the term. + """ + raise NotImplementedError diff --git a/embodichain/lab/gym/motion_generation/action/arm_action.py b/embodichain/lab/gym/motion_generation/action/arm_action.py new file mode 100644 index 0000000..fc85e4f --- /dev/null +++ b/embodichain/lab/gym/motion_generation/action/arm_action.py @@ -0,0 +1,710 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, List, Tuple, Union +from embodichain.utils import logger +from embodichain.lab.gym.motion_generation.action.action import Action +from embodichain.lab.gym.motion_generation.planner.utils import ( + TrajectorySampleMethod, +) +from embodichain.lab.gym.motion_generation.planner.toppra_planner import ( + ToppraPlanner, +) + + +class ArmAction(Action): + r"""Initialize the ArmAction class.""" + + def __init__(self, env, robot_uid, **kwargs) -> None: + super().__init__(env, **kwargs) + self.agent_uid = robot_uid + if "LeftManipulator" == robot_uid or "RightManipulator" == robot_uid: + self.agent = self.scene.get_robot("DualManipulator") + else: + self.agent = self.scene.get_robot(self.agent_uid) + + init_qpos = self.agent.get_init_qpos(self.agent_uid) + self.init_ee_xpos = self.agent.get_fk(qpos=init_qpos, uid=self.agent_uid) + self.init_base_xpos = self.agent.get_base_xpos(self.agent_uid) + + self.drive_controller = self.agent.drive_controllers[self.agent_uid] + + def move( + self, + xpos_list: np.ndarray, + is_linear: bool = False, + is_wait: bool = True, + is_world_coordinates=False, + **kwargs, + ): + r"""Move the robot to a specified position. + + Args: + xpos_list (np.ndarray): List of target positions. + is_linear (bool): If True, move in a linear path. + is_wait (bool): If True, wait until the movement is completed. + is_world_coordinates (bool): If True, interpret positions in world coordinates. + kwargs (dict): Additional arguments. + + Returns: + bool: True if movement is successful, else False. + """ + if hasattr(self.agent, "move"): + res = self.agent.move( + xpos_list, + is_linear=is_linear, + is_wait=is_wait, + is_world_coordinates=is_world_coordinates, + ) + + return res + else: + return False + + def move_in_joints( + self, + qpos_list: np.ndarray, + is_linear: bool = False, + is_wait: bool = True, + **kwargs, + ): + r"""Move the robot joints to specified positions. + + Args: + qpos_list (np.ndarray): List of target joint positions. + is_linear (bool): If True, move joints in a linear path. + is_wait (bool): If True, wait until the movement is completed. + kwargs (dict): Additional arguments. + + Returns: + bool: True if movement is successful, else False. + """ + if hasattr(self.agent, "move_in_joints"): + res = self.agent.move_in_joints( + qpos_list, is_linear=is_linear, is_wait=is_wait + ) + return res + else: + return False + + def apply_transform( + self, xpos: np.ndarray, lift_vector: np.ndarray, is_local: bool = False + ): + """Apply a lift to the given pose in either local or world coordinates. + + Args: + pick_xpos (np.ndarray): The original 4x4 transformation matrix. + lift_vector (np.ndarray): A 3-element vector representing the lift in [x, y, z] directions. + is_local (bool): If True, apply the lift in local coordinates; + if False, apply in world coordinates. + + Returns: + np.ndarray: The new 4x4 transformation matrix after applying the lift. + """ + xpos = np.array(xpos) + lift_vector = np.array(lift_vector) + # Assert to ensure xpos is a 4x4 matrix and lift_vector has three components + assert xpos.shape == (4, 4), "Target pose must be a 4x4 matrix." + assert lift_vector.shape == ( + 3, + ), "Lift vector must have three components [x, y, z]." + + # Create a copy of the xpos + new_xpos = deepcopy(xpos) + + # Create a translation matrix for lifting in world coordinates + translation_matrix = np.array( + [ + [1, 0, 0, lift_vector[0]], + [0, 1, 0, lift_vector[1]], + [0, 0, 1, lift_vector[2]], + [0, 0, 0, 1], + ] + ) + + if is_local: + # Apply lift in local coordinates + new_xpos = new_xpos @ translation_matrix + else: + # Apply the translation in the world coordinate system + new_xpos = translation_matrix @ new_xpos + + return new_xpos + + @staticmethod + def create_discrete_trajectory( + agent, + uid, + xpos_list: List[np.ndarray] = None, + is_use_current_qpos: bool = True, + is_linear: bool = False, + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.QUANTITY, + sample_num: Union[float, int] = 20, + qpos_seed: np.ndarray = None, + **kwargs, + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + r"""Generate a discrete trajectory between waypoints using cartesian or joint space interpolation. + + This method supports two trajectory planning approaches: + 1. Linear interpolation: Fast, uniform spacing, no dynamics constraints + 2. ToppraPlanner: Smooth, considers velocity/acceleration limits, realistic motion + + Args: + agent: The robot agent instance + uid: Unique identifier for the robot agent + xpos_list: List of waypoints as 4x4 transformation matrices + is_use_current_qpos: Whether to use current joint angles as IK seed + is_linear: If True, use cartesian linear interpolation, else joint space + sample_method: Sampling method for ToppraPlanner (QUANTITY or TIME) + sample_num: Number of interpolated points for final trajectory + qpos_seed: Initial joint configuration for IK solving + **kwargs: Additional arguments: + - qpos_list: Optional list of joint configurations + + Returns: + A tuple containing: + - List[np.ndarray]: Joint space trajectory as a list of joint configurations + - List[np.ndarray]: Cartesian space trajectory as a list of 4x4 matrices + """ + from scipy.spatial.transform import Rotation, Slerp + import numpy as np + + def interpolate_xpos( + current_xpos: np.ndarray, target_xpos: np.ndarray, num_samples: int + ) -> list[np.ndarray]: + if num_samples < 2: + num_samples = 2 + + slerp = Slerp( + [0, 1], + Rotation.from_matrix([current_xpos[:3, :3], target_xpos[:3, :3]]), + ) + interpolated_poses = [] + for s in np.linspace(0, 1, num_samples): + interp_rot = slerp(s).as_matrix() + interp_trans = (1 - s) * current_xpos[:3, 3] + s * target_xpos[:3, 3] + interp_pose = np.eye(4) + interp_pose[:3, :3] = interp_rot + interp_pose[:3, 3] = interp_trans + interpolated_poses.append(interp_pose) + return interpolated_poses + + def calculate_point_allocations( + xpos_list, step_size=0.002, angle_step=np.pi / 90 + ): + point_allocations = [] + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + if isinstance(start_pose, torch.Tensor): + start_pose = start_pose.squeeze().cpu().numpy() + if isinstance(end_pose, torch.Tensor): + end_pose = end_pose.squeeze().cpu().numpy() + + pos_dist = np.linalg.norm(end_pose[:3, 3] - start_pose[:3, 3]) + pos_points = max(1, int(pos_dist / step_size)) + + angle_diff = Rotation.from_matrix( + start_pose[:3, :3].T @ end_pose[:3, :3] + ) + angle = abs(angle_diff.as_rotvec()).max() + rot_points = max(1, int(angle / angle_step)) + + num_points = max(pos_points, rot_points) + point_allocations.append(num_points) + + return point_allocations + + def create_qpos_dict(position: np.ndarray, dof: int) -> Dict: + """Create qpos dictionary with zero velocity and acceleration""" + return { + "position": position.tolist() + if isinstance(position, np.ndarray) + else position, + "velocity": [0.0] * dof, + "acceleration": [0.0] * dof, + } + + if hasattr(agent, "get_dof"): + agent_dof = agent.get_dof(uid) + elif hasattr(agent, "control_parts"): + agent_dof = len(agent.control_parts[uid]) + + # TODO(@Jietao Chen): max_constraints should be read from URDF file + max_constraints = kwargs.get("max_constraints", None) + if max_constraints is None: + max_constraints = { + "velocity": [0.2] * agent_dof, + "acceleration": [0.5] * agent_dof, + } + planner = ToppraPlanner(agent_dof, max_constraints) + + out_qpos_list = [] + out_xpos_list = [] + + # Handle input arguments + qpos_list = kwargs.get("qpos_list", None) + if qpos_list is not None: + qpos_list = np.asarray(qpos_list) + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in qpos_list] + elif hasattr(agent, "compute_fk"): + qpos_list = ( + torch.tensor(qpos_list) + if not isinstance(qpos_list, torch.Tensor) + else qpos_list + ) + xpos_list = [ + agent.compute_fk(qpos=q, name=uid, to_matrix=True) + for q in qpos_list + ] + else: + logger.log_warning("Agent does not support FK computation") + + if is_use_current_qpos: + current_qpos = agent.get_current_qpos(uid) + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + current_xpos = agent.get_fk(uid=uid, qpos=current_qpos) + elif hasattr(agent, "compute_fk"): + current_xpos = agent.compute_fk( + qpos=current_qpos, name=uid, to_matrix=True + ) + else: + logger.log_warning("Agent does not support FK computation") + return [], [] + + pos_diff = np.linalg.norm(current_xpos[:3, 3] - xpos_list[0][:3, 3]) + rot_diff = np.linalg.norm(current_xpos[:3, :3] - xpos_list[0][:3, :3]) + + if pos_diff > 0.001 or rot_diff > 0.01: + xpos_list = np.concatenate( + [current_xpos[None, :, :], xpos_list], axis=0 + ) + if qpos_list is not None: + qpos_list = np.concatenate( + [current_qpos[None, :], qpos_list], axis=0 + ) + + if qpos_seed is None and qpos_list is not None: + qpos_seed = qpos_list[0] + + # Input validation + if xpos_list is None or len(xpos_list) < 2: + logger.log_warning("xpos_list must contain at least 2 points") + return [], [] + + # Calculate point allocations for interpolation + interpolated_point_allocations = calculate_point_allocations( + xpos_list, step_size=0.002, angle_step=np.pi / 90 + ) + + # Generate trajectory + interpolate_qpos_list = [] + if is_linear or qpos_list is None: + # Linear cartesian interpolation + for i in range(len(xpos_list) - 1): + interpolated_poses = interpolate_xpos( + xpos_list[i], xpos_list[i + 1], interpolated_point_allocations[i] + ) + + for xpos in interpolated_poses: + # TODO: It will use computed ik in the future + if hasattr(agent, "get_ik"): + success, qpos = agent.get_ik(xpos, qpos_seed=qpos_seed, uid=uid) + elif hasattr(agent, "compute_ik"): + success, qpos = agent.compute_ik( + pose=xpos, joint_seed=qpos_seed, name=uid + ) + else: + logger.log_warning("Agent does not support IK computation") + + if not success: + logger.log_warning(f"IK solving failed for pose {xpos}") + return [], [] + interpolate_qpos_list.append(qpos) + qpos_seed = qpos + else: + # Joint space interpolation + interpolate_qpos_list = qpos_list + + # Create trajectory dictionary + current_qpos_dict = create_qpos_dict(interpolate_qpos_list[0], agent_dof) + target_qpos_dict_list = [ + create_qpos_dict(pos, agent_dof) for pos in interpolate_qpos_list[1:] + ] + + # Plan trajectory + res, out_qpos_list, *_ = planner.plan( + current_qpos_dict, + target_qpos_dict_list, + sample_method=sample_method, + sample_interval=sample_num, + ) + if not res: + logger.log_warning("Failed to plan trajectory with ToppraPlanner") + return [], [] + + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + out_xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in out_qpos_list] + elif hasattr(agent, "compute_fk"): + out_qpos_list = ( + torch.tensor(out_qpos_list) + if not isinstance(out_qpos_list, torch.Tensor) + else out_qpos_list + ) + out_xpos_list = [ + agent.compute_fk(qpos=q, name=uid, to_matrix=True) + for q in out_qpos_list + ] + else: + logger.log_warning("Agent does not support FK computation") + + return out_qpos_list, out_xpos_list + + @staticmethod + def estimate_trajectory_sample_count( + agent, + uid, + xpos_list: List[np.ndarray] = None, + qpos_list: List[np.ndarray] = None, + step_size: float = 0.01, + angle_step: float = np.pi / 90, + **kwargs, + ) -> int: + """Estimate the number of trajectory sampling points required. + + This function estimates the total number of sampling points needed to generate + a trajectory based on the given waypoints and sampling parameters. It can be + used to predict computational load and memory requirements before actual + trajectory generation. + + Args: + agent: Robot agent instance + uid: Unique identifier for the robot agent + xpos_list: List of 4x4 transformation matrices representing waypoints + qpos_list: List of joint positions (optional) + is_linear: Whether to use linear interpolation + step_size: Maximum allowed distance between consecutive points (in meters) + angle_step: Maximum allowed angular difference between consecutive points (in radians) + **kwargs: Additional parameters for further customization + + Returns: + int: Estimated number of trajectory sampling points + """ + + def rotation_matrix_to_angle(self, rot_matrix: np.ndarray) -> float: + cos_angle = (np.trace(rot_matrix) - 1) / 2 + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return np.arccos(cos_angle) + + # Input validation + if xpos_list is None and qpos_list is None: + return 0 + + # If joint position list is provided but end effector position list is not, + # convert through forward kinematics + if qpos_list is not None and xpos_list is None: + if len(qpos_list) < 2: + return 1 if len(qpos_list) == 1 else 1 + try: + if hasattr(agent, "get_fk_batch"): + xpos_list = agent.get_fk_batch(uid=uid, qpos_list=qpos_list) + else: + xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in qpos_list] + except Exception as e: + logger.log_warning(f"Forward kinematics failed: {e}") + return 0 + + if xpos_list is None or len(xpos_list) == 0: + return 1 + + if len(xpos_list) == 1: + return 1 + + total_samples = 1 # Starting point + angle_step_inv = 1.0 / angle_step + + total_pos_dist = 0.0 + total_angle = 0.0 + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + pos_diff = end_pose[:3, 3] - start_pose[:3, 3] + total_pos_dist += np.linalg.norm(pos_diff) + + try: + rot_matrix = start_pose[:3, :3].T @ end_pose[:3, :3] + angle = rotation_matrix_to_angle(rot_matrix) + total_angle += angle + except Exception: + pass + + pos_samples = max(1, int(total_pos_dist / step_size)) + rot_samples = max(1, int(total_angle / angle_step)) + + total_samples = max(pos_samples, rot_samples) + + return max(2, total_samples) + + def create_action_dict_list( + self, + xpos_list: List[np.ndarray], + qpos_list: List[np.ndarray], + ee_state: float = 0.0, + ) -> List[Dict]: + """Constructs a list of actions based on the given end effector poses on agent base coordinates and joint positions. + + Args: + xpos_list (List[np.ndarray]): A list of end effector poses. + qpos_list (List[np.ndarray]): A list of joint positions. + ee_state (float, optional): The state of the end effector (e.g., open or closed). Defaults to 0.0. + + Returns: + List[Dict]: A list of actions, where each action contains: + - "ef_pose": The end effector pose at the step. + - "qpos": The joint positions corresponding to the step. + - "ee_state": The state of the end effector (e.g., open or closed). + """ + # Check if xpos_list or qpos_list is None + if xpos_list is None or qpos_list is None: + return [] + + # Check if xpos_list and qpos_list have the same length + if len(xpos_list) != len(qpos_list): + logger.log_warning("The xpos_list and qpos_list must have the same length.") + return [] + + action_list = [ + { + "ef_pose": xpos_list[i], + "qpos": qpos_list[i], + "ee_state": ee_state, + } + for i in range(len(xpos_list)) + ] + + return action_list + + def create_back_action_list( + self, + start_xpos: np.ndarray = None, + is_move_linear: bool = False, + qpos_seed: np.ndarray = None, + lift_height: float = 0.25, + reference_xpos: np.ndarray = None, + back_distance_z: float = 0.02, + traj_num: int = 20, + **kwargs, + ) -> List[Dict]: + r"""Generate a list of actions for the robot to move back to its initial joint position after completing a task. + + Args: + start_xpos (np.ndarray, optional): The starting position of the end effector (EE) in agent base coordinates, + represented as a 4x4 transformation matrix. If None, the agent's current EE + position is used. Defaults to None. + is_move_linear (bool, optional): True for linear movement and False for joint space interpolation. Defaults to False. + qpos_seed (np.ndarray, optional): Qpos seed for solving Inverse Kinematics (IK). Defaults to None, which uses the current Qpos. + lift_height (float, optional): The vertical distance the EE should be lifted. Defaults to 0.25 meters. + reference_xpos (np.ndarray, optional): An optional reference position used to compute the back path. If None, the path will be a simple lift and return. Defaults to None. + back_distance_z (float, optional): Distance to offset reference_xpos in the -z direction. Defaults to 0.02. + traj_num (int, optional): The number of discrete steps (trajectory points) to generate for the move back action. + More steps result in a smoother trajectory. Defaults to 20. + **kwargs: Additional parameters for further customization. + + Returns: + List[Dict]: A list of actions, where each action is represented as a dictionary containing: + - "ef_pose": The end effector pose at the step. + - "qpos": The joint positions corresponding to the step. + - "ee_state": The state of the end effector (e.g., open or closed). + + Note: + - The initial joint position ('init_qpos') is the home configuration of the robot's joints, representing the agent's Qpos + at the start or in a safe/rest position. It serves as a fallback in case no valid IK solutions are found for the lifting position, + ensuring that the robot can still return to its last known configuration before the task. + """ + if start_xpos is None: + start_xpos = self.agent.get_current_xpos( + name=self.agent_uid, is_world_coordinates=False + ) + + if reference_xpos is None: + lift_xpos = self.apply_transform( + xpos=start_xpos, lift_vector=[0.0, 0.0, lift_height], is_local=False + ) + back_path = [start_xpos, lift_xpos] + else: + z_back_xpos = self._compute_offset_xpos( + start_xpos, reference_xpos, back_distance_z + ) + lift_xpos = self.apply_transform( + xpos=z_back_xpos, lift_vector=[0.0, 0.0, lift_height], is_local=False + ) + back_path = [start_xpos, z_back_xpos, lift_xpos] + + back_qpos_path = [] + + for p in back_path: + res, qpos = self.drive_controller.get_ik(p, qpos_seed) + if res: + back_qpos_path.append(qpos) + + init_qpos = self.agent.get_current_qpos(self.agent_uid) + if back_qpos_path: + back_qpos_path.append(init_qpos) + else: + back_qpos_path = [init_qpos, init_qpos] + + qpos_list, xpos_list = self.drive_controller.create_discrete_trajectory( + qpos_list=back_qpos_path, + is_use_current_qpos=False, + is_linear=is_move_linear, + sample_num=traj_num, + qpos_seed=qpos_seed, + ) + + action_list = self.create_action_dict_list( + xpos_list=xpos_list, + qpos_list=qpos_list, + ee_state=self.end_effector.open_state, + ) + if not action_list: + logger.log_warning( + "Create approach action list failed. Please check the approach path!" + ) + + return action_list + + def supplyment_action_data( + self, action_list: Dict, connected_qpos: np.ndarray = None + ) -> Dict: + r"""Supplement the action data for a DualManipulator agent. + + This function checks if the agent is a DualManipulator and determines the + appropriate end effector index based on the end effector's unique identifier. + It retrieves the current open states of the end effectors and updates the + provided action list with new joint positions and end effector states. + + Args: + action_list (Dict): A list of actions to be modified, where each action + contains 'qpos' for joint positions and 'ee_state' for end effector states. + connected_qpos (connected_qpos): + + Returns: + Dict: The modified action list with updated joint positions and end + ` effector states. If the agent is not a DualManipulator or if the end + effector UID is invalid, the original action list is returned. + """ + if self.agent.__class__.__name__ != "DualManipulator": + return action_list + + # TODO: Does the number here really correspond to the ee obtained by step_action? + if "Left" in self.end_effector_uid: + ee_idx = 0 + elif "Right" in self.end_effector_uid: + ee_idx = 1 + else: + logger.log_warning("There is no left or right gripper, no processing.") + return action_list + + all_ee_state = np.array([]) + # TODO: Here we assume that the results are obtained in the order of left and then right. + ee_list = self.env.get_end_effector() + for ee in ee_list: + ee_open_state = ee.get_open_state() + all_ee_state = np.append(all_ee_state, ee_open_state) + + if connected_qpos is None: + current_qpos = self.agent.get_current_qpos("DualManipulator") + else: + current_qpos = connected_qpos + + target_joint_ids = self.agent.get_joint_ids(self.agent_uid) + + left_current_xpos = self.agent.get_current_xpos("LeftManipulator") + right_current_xpos = self.agent.get_current_xpos("RightManipulator") + + all_xpos = np.array([left_current_xpos, right_current_xpos]) + + for action in action_list: + new_qpos = np.copy(current_qpos) + new_qpos[target_joint_ids] = action["qpos"] + action["qpos"] = new_qpos + + all_xpos[ee_idx] = action["ef_pose"] + action["ef_pose"] = all_xpos + + new_ee_state = np.copy(all_ee_state) + new_ee_state[ee_idx] = action["ee_state"] + action["ee_state"] = new_ee_state + + return action_list + + def merge_action_data( + self, left_action_list: Dict, right_action_list: Dict + ) -> Dict: + r"""Merge action data from left and right action lists. + + This function is designed to combine action data from two separate action + lists (left and right) into a single unified action list. The implementation + details for merging the action lists will depend on the specific requirements + of the application. + + Args: + left_action_list (Dict): A dictionary containing actions for the left end effector. + right_action_list (Dict): A dictionary containing actions for the right end effector. + + Returns: + Dict: A merged dictionary containing combined action data from both + left and right action lists. The exact structure of the returned + dictionary will depend on the merging logic implemented in this method. + """ + merged_action_list = [] + if self.agent.__class__.__name__ != "DualManipulator": + return merged_action_list + + current_qpos = self.agent.get_current_qpos("DualManipulator") + # Assuming both action lists have the same length + for left_action, right_action in zip(left_action_list, right_action_list): + merged_action = {} + + # Get joint IDs for left and right actions + left_joint_ids = self.agent.get_joint_ids("LeftManipulator") + right_joint_ids = self.agent.get_joint_ids("RightManipulator") + + # Initialize new qpos and ee_state for the merged action + new_qpos = np.zeros( + len(current_qpos) + ) # Assuming total count includes both left and right + + # Set joint positions based on left action + new_qpos[left_joint_ids] = left_action["qpos"][left_joint_ids] + + # Set joint positions based on right action + new_qpos[right_joint_ids] = right_action["qpos"][right_joint_ids] + + # Set end effector states + new_ee_state = [] # Assuming two end effectors: left and right + new_ee_state.extend(left_action["ee_state"]) + new_ee_state.extend(right_action["ee_state"]) + + # Construct the merged action + merged_action["qpos"] = new_qpos + merged_action["ee_state"] = new_ee_state + + # Append the merged action to the list + merged_action_list.append(merged_action) + + return merged_action_list diff --git a/embodichain/lab/gym/motion_generation/planner/toppra_planner.py b/embodichain/lab/gym/motion_generation/planner/toppra_planner.py new file mode 100644 index 0000000..0fc2b79 --- /dev/null +++ b/embodichain/lab/gym/motion_generation/planner/toppra_planner.py @@ -0,0 +1,258 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +from embodichain.lab.gym.motion_generation.planner.utils import ( + TrajectorySampleMethod, +) + +from typing import TYPE_CHECKING, Union + +try: + import toppra as ta + import toppra.constraint as constraint +except ImportError: + raise ImportError("toppra not installed. Install with `pip install toppra==0.6.3`") + +ta.setup_logging(level="WARN") + + +class ToppraPlanner: + def __init__(self, DOFs, max_constraints): + r"""Initialize the TOPPRA trajectory planner. + + Args: + DOFs: Number of degrees of freedom + max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints + """ + + self.DOFs = DOFs + self.time_step = 0.01 + self.max_constraints = max_constraints + + # Create TOPPRA constraints + self.vlims = np.array([[-v, v] for v in max_constraints["velocity"]]) + self.alims = np.array([[-a, a] for a in max_constraints["acceleration"]]) + + def plan( + self, + current_state: dict, + target_states: list[dict], + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.TIME, + sample_interval: Union[float, int] = 0.01, + ): + r"""Execute trajectory planning. + + Args: + current_state: Dictionary containing 'position', 'velocity', 'acceleration' for current state + target_states: List of dictionaries containing target states + + Returns: + Tuple of (success, positions, velocities, accelerations, times, duration) + """ + if not isinstance(sample_interval, (float, int)): + raise TypeError( + f"sample_interval must be float/int, got {type(sample_interval)}" + ) + if sample_method == TrajectorySampleMethod.TIME and sample_interval <= 0: + raise ValueError("Time interval must be positive") + elif sample_method == TrajectorySampleMethod.QUANTITY and sample_interval < 2: + raise ValueError("At least 2 sample points required") + + # Check waypoints + if len(current_state["position"]) != self.DOFs: + print(f"Current wayponit does not align") + return False, None, None, None, None, None + for target in target_states: + if len(target["position"]) != self.DOFs: + print(f"Target Wayponits does not align") + return False, None, None, None, None, None + + if ( + len(target_states) == 1 + and np.sum( + np.abs( + np.array(target_states[0]["position"]) + - np.array(current_state["position"]) + ) + ) + < 1e-3 + ): + print(f"Only two same waypoints, do not plan") + return ( + True, + np.array([current_state["position"], target_states[0]["position"]]), + np.array([[0.0] * self.DOFs, [0.0] * self.DOFs]), + np.array([[0.0] * self.DOFs, [0.0] * self.DOFs]), + 0, + 0, + ) + + # Build waypoints + waypoints = [np.array(current_state["position"])] + for target in target_states: + waypoints.append(np.array(target["position"])) + waypoints = np.array(waypoints) + + # Create spline interpolation + # NOTE(fsh):适合密集的点 + ss = np.linspace(0, 1, len(waypoints)) + + # NOTE(fsh):适合稀疏的点,密集点容易不满足CubicSpline严格递增的条件 + # len_total = 0 + # len_from_start = [0] + # for i in range(len(waypoints)-1): + # len_total += np.sum(np.abs(waypoints[i+1] - waypoints[i])) + # len_from_start.append(len_total) + # ss = np.array([cur/len_total for cur in len_from_start]) + + path = ta.SplineInterpolator(ss, waypoints) + + # Set constraints + pc_vel = constraint.JointVelocityConstraint(self.vlims) + pc_acc = constraint.JointAccelerationConstraint(self.alims) + + # Create TOPPRA instance + instance = ta.algorithm.TOPPRA( + [pc_vel, pc_acc], + path, + parametrizer="ParametrizeConstAccel", + gridpt_min_nb_points=max(100, 10 * len(waypoints)), + ) + # NOTES:合理设置gridpt_min_nb_points对加速度约束很重要 + + # Compute parameterized trajectory + jnt_traj = instance.compute_trajectory() + if jnt_traj is None: + # raise RuntimeError("Unable to find feasible trajectory") + print(f"Unable to find feasible trajectory") + return False, None, None, None, None, None + + duration = jnt_traj.duration + # Sample trajectory points + if duration <= 0: + raise ValueError(f"Duration must be positive, got {duration}") + if sample_method == TrajectorySampleMethod.TIME: + n_points = max(2, int(np.ceil(duration / sample_interval)) + 1) + ts = np.linspace(0, duration, n_points) + else: + ts = np.linspace(0, duration, num=int(sample_interval)) + + positions = [] + velocities = [] + accelerations = [] + + for t in ts: + positions.append(jnt_traj.eval(t)) + velocities.append(jnt_traj.evald(t)) + accelerations.append(jnt_traj.evaldd(t)) + + return ( + True, + np.array(positions), + np.array(velocities), + np.array(accelerations), + ts, + duration, + ) + + def is_satisfied_constraint(self, velocities, accelerations) -> bool: + r"""Check if the trajectory satisfies velocity and acceleration constraints. + + Args: + velocities: array + accelerations: array + """ + # NOTE(fsh):密集点过多的情况下,当前实现容易求解无法严格满足约束,会有一定的越界 + vlims = self.vlims * (1 + 0.1) # 允许10%误差 + alims = self.alims * (1 + 0.25) # 允许25%误差 + + vel_check = np.all((velocities >= vlims[:, 0]) & (velocities <= vlims[:, 1])) + acc_check = np.all( + (accelerations >= alims[:, 0]) & (accelerations <= alims[:, 1]) + ) + + # 超限情况 + if not vel_check: + vel_exceed_info = [] + min_vel = np.min(velocities, axis=0) + max_vel = np.max(velocities, axis=0) + for i in range(self.DOFs): + exceed_percentage = 0 + if min_vel[i] < self.vlims[i, 0]: + exceed_percentage = (min_vel[i] - self.vlims[i, 0]) / self.vlims[ + i, 0 + ] + if max_vel[i] > self.vlims[i, 1]: + temp = (max_vel[i] - self.vlims[i, 1]) / self.vlims[i, 1] + if temp > exceed_percentage: + exceed_percentage = temp + vel_exceed_info.append(exceed_percentage * 100) + print(f"Velocity exceed info: {vel_exceed_info} percentage") + + if not acc_check: + acc_exceed_info = [] + min_acc = np.min(accelerations, axis=0) + max_acc = np.max(accelerations, axis=0) + for i in range(self.DOFs): + exceed_percentage = 0 + if min_acc[i] < self.alims[i, 0]: + exceed_percentage = (min_acc[i] - self.alims[i, 0]) / self.alims[ + i, 0 + ] + if max_acc[i] > self.alims[i, 1]: + temp = (max_acc[i] - self.alims[i, 1]) / self.alims[i, 1] + if temp > exceed_percentage: + exceed_percentage = temp + acc_exceed_info.append(exceed_percentage * 100) + print(f"Acceleration exceed info: {acc_exceed_info} percentage") + + return vel_check and acc_check + + def plot_trajectory(self, positions, velocities, accelerations): + r"""Plot trajectory data. + + Args: + positions: Position array + velocities: Velocity array + accelerations: Acceleration array + """ + import matplotlib.pyplot as plt + + time_steps = np.arange(positions.shape[0]) * self.time_step + fig, axs = plt.subplots(3, 1, figsize=(10, 8)) + + for i in range(self.DOFs): + axs[0].plot(time_steps, positions[:, i], label=f"Joint {i+1}") + axs[1].plot(time_steps, velocities[:, i], label=f"Joint {i+1}") + axs[2].plot(time_steps, accelerations[:, i], label=f"Joint {i+1}") + + axs[1].plot( + time_steps, + [self.vlims[0][0]] * len(time_steps), + "k--", + label="Max Velocity", + ) + axs[1].plot(time_steps, [self.vlims[0][1]] * len(time_steps), "k--") + axs[2].plot( + time_steps, + [self.alims[0][0]] * len(time_steps), + "k--", + label="Max Accleration", + ) + axs[2].plot(time_steps, [self.alims[0][1]] * len(time_steps), "k--") + + axs[0].set_title("Position") + axs[1].set_title("Velocity") + axs[2].set_title("Acceleration") + + for ax in axs: + ax.set_xlabel("Time [s]") + ax.legend() + ax.grid() + + plt.tight_layout() + plt.show() diff --git a/embodichain/lab/gym/motion_generation/planner/utils.py b/embodichain/lab/gym/motion_generation/planner/utils.py new file mode 100644 index 0000000..8e67195 --- /dev/null +++ b/embodichain/lab/gym/motion_generation/planner/utils.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from enum import Enum +from typing import Union + + +class TrajectorySampleMethod(Enum): + r"""Enumeration for different trajectory sampling methods. + + This enum defines various methods for sampling trajectories, + providing meaningful names for different sampling strategies. + """ + TIME = "time" + """Sample based on time intervals.""" + + QUANTITY = "quantity" + """Sample based on a specified number of points.""" + + DISTANCE = "distance" + """Sample based on distance intervals.""" + + @classmethod + def from_str( + cls, value: Union[str, "TrajectorySampleMethod"] + ) -> "TrajectorySampleMethod": + if isinstance(value, cls): + return value + try: + return cls[value.upper()] + except KeyError: + valid_values = [e.name for e in cls] + raise ValueError( + f"Invalid version '{value}'. Valid values are: {valid_values}" + ) + + def __str__(self): + """Override string representation for better readability.""" + return self.value.capitalize() From 5baed9b33c46b44ebbde993daad2f494d8244c6f Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:37:24 +0800 Subject: [PATCH 18/49] migrate lab.sim.utility --- embodichain/lab/sim/utility/sim_utils copy.py | 285 +++ .../lab/sim/utility/workspace_analyzer_new.py | 1617 +++++++++++++++++ 2 files changed, 1902 insertions(+) create mode 100644 embodichain/lab/sim/utility/sim_utils copy.py create mode 100644 embodichain/lab/sim/utility/workspace_analyzer_new.py diff --git a/embodichain/lab/sim/utility/sim_utils copy.py b/embodichain/lab/sim/utility/sim_utils copy.py new file mode 100644 index 0000000..93d6e48 --- /dev/null +++ b/embodichain/lab/sim/utility/sim_utils copy.py @@ -0,0 +1,285 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import dexsim +import open3d as o3d + +from typing import List, Union, Optional + +from dexsim.types import DriveType, ArticulationFlag, LoadOption, RigidBodyShape +from dexsim.engine import Articulation +from dexsim.environment import Env, Arena +from dexsim.models import MeshObject + +from embodichain.lab.sim.cfg import ArticulationCfg, RigidObjectCfg, SoftObjectCfg +from embodichain.lab.sim.shapes import MeshCfg, CubeCfg, SphereCfg +from embodichain.utils import logger +from dexsim.kit.meshproc import get_mesh_auto_uv +import numpy as np + + +def get_dexsim_arenas() -> List[dexsim.environment.Arena]: + """Get all arenas in the default dexsim world. + + Returns: + List[dexsim.environment.Arena]: A list of arenas in the default world, or an empty list if no world is found. + """ + world = dexsim.default_world() + if world is None: + logger.log_warning(f"No default world found. Returning empty arena list.") + return [] + + env = world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + return [env] + return arenas + + +def get_dexsim_arena_num() -> int: + """Get the number of arenas in the default dexsim world. + + Returns: + int: The number of arenas in the default world, or 0 if no world is found. + """ + arenas = get_dexsim_arenas() + return len(arenas) + + +def get_dexsim_drive_type(drive_type: str) -> DriveType: + """Get the dexsim drive type from a string. + + Args: + drive_type (str): The drive type as a string. + + Returns: + DriveType: The corresponding DriveType enum. + """ + if drive_type == "force": + return DriveType.FORCE + elif drive_type == "acceleration": + return DriveType.ACCELERATION + else: + logger.error(f"Invalid dexsim drive type: {drive_type}") + + +def set_dexsim_articulation_cfg(arts: List[Articulation], cfg: ArticulationCfg) -> None: + """Set articulation configuration for a list of dexsim articulations. + + Args: + arts (List[Articulation]): List of dexsim articulations to configure. + cfg (ArticulationCfg): Configuration object containing articulation settings. + """ + + def get_drive_type(drive_pros): + if isinstance(drive_pros, dict): + return drive_pros.get("drive_type", None) + return getattr(drive_pros, "drive_type", None) + + drive_pros = getattr(cfg, "drive_pros", None) + drive_type = get_drive_type(drive_pros) if drive_pros is not None else None + + if drive_type == "force": + drive_type = DriveType.FORCE + elif drive_type == "acceleration": + drive_type = DriveType.ACCELERATION + else: + logger.log_error(f"Unknow drive type {drive_type}") + + for i, art in enumerate(arts): + art.set_physical_attr(cfg.attrs.attr()) + art.set_articulation_flag(ArticulationFlag.FIX_BASE, cfg.fix_base) + art.set_articulation_flag( + ArticulationFlag.DISABLE_SELF_COLLISION, cfg.disable_self_collision + ) + art.set_solver_iteration_counts( + min_position_iters=cfg.min_position_iters, + min_velocity_iters=cfg.min_velocity_iters, + ) + link_names = art.get_link_names() + for name in link_names: + physical_body = art.get_physical_body(name) + inertia = physical_body.get_mass_space_inertia_tensor() + inertia = np.maximum(inertia, 1e-4) + physical_body.set_mass_space_inertia_tensor(inertia) + + if i == 0 and cfg.compute_uv: + render_body = art.get_render_body(name) + if render_body: + render_body.set_projective_uv() + + # TODO: will crash when exit if not explicitly delete. + # This may due to the destruction of render body order when exiting. + del render_body + + +def is_rt_enabled() -> bool: + """Check if Ray Tracing rendering backend is enabled in the default dexsim world. + + Returns: + bool: True if Ray Tracing rendering is enabled, False otherwise. + """ + config = dexsim.get_world_config() + + return config.renderer == dexsim.types.Renderer.FASTRT + + +def create_cube( + envs: List[Union[Env, Arena]], size: List[float], uid: str = "cube" +) -> List[MeshObject]: + """Create cube objects in the specified environments or arenas. + + Args: + envs (List[Union[Env, Arena]]): List of environments or arenas to create cubes in. + size (List[float]): Size of the cube as [length, width, height] in meters. + uid (str, optional): Unique identifier for the cube objects. Defaults to "cube". + + Returns: + List[MeshObject]: List of created cube mesh objects. + """ + cubes = [] + for i, env in enumerate(envs): + cube = env.create_cube(size[0], size[1], size[2]) + cube.set_name(f"{uid}_{i}") + cubes.append(cube) + return cubes + + +def create_sphere( + envs: List[Union[Env, Arena]], + radius: float, + resolution: int = 20, + uid: str = "sphere", +) -> List[MeshObject]: + """Create sphere objects in the specified environments or arenas. + + Args: + envs (List[Union[Env, Arena]]): List of environments or arenas to create spheres in. + radius (float): Radius of the sphere in meters. + resolution (int, optional): Resolution of the sphere mesh. Defaults to 20. + uid (str, optional): Unique identifier for the sphere objects. Defaults to "sphere". + + Returns: + List[MeshObject]: List of created sphere mesh objects. + """ + spheres = [] + for i, env in enumerate(envs): + sphere = env.create_sphere(radius, resolution) + sphere.set_name(f"{uid}_{i}") + spheres.append(sphere) + return spheres + + +def load_mesh_objects_from_cfg( + cfg: RigidObjectCfg, env_list: List[Arena], cache_dir: Optional[str] = None +) -> List[MeshObject]: + """Load mesh objects from configuration. + + Args: + cfg (RigidObjectCfg): Configuration for the rigid object. + env_list (List[Arena]): List of arenas to load the objects into. + + cache_dir (Optional[str], optional): Directory for caching convex decomposition files. Defaults to None + Returns: + List[MeshObject]: List of loaded mesh objects. + """ + obj_list = [] + body_type = cfg.to_dexsim_body_type() + if isinstance(cfg.shape, MeshCfg): + + option = LoadOption() + option.rebuild_normals = cfg.shape.load_option.rebuild_normals + option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent + option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal + option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent + option.smooth = cfg.shape.load_option.smooth + + cfg: RigidObjectCfg + max_convex_hull_num = cfg.max_convex_hull_num + fpath = cfg.shape.fpath + + compute_uv = cfg.shape.compute_uv + + for i, env in enumerate(env_list): + if max_convex_hull_num > 1: + obj = env.load_actor_with_coacd( + fpath, + duplicate=True, + attach_scene=True, + option=option, + cache_path=cache_dir, + actor_type=body_type, + max_convex_hull_num=max_convex_hull_num, + ) + else: + obj = env.load_actor( + fpath, duplicate=True, attach_scene=True, option=option + ) + obj.add_rigidbody(body_type, RigidBodyShape.CONVEX) + obj.set_name(f"{cfg.uid}_{i}") + obj_list.append(obj) + + if compute_uv: + vertices = obj.get_vertices() + triangles = obj.get_triangles() + + o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) + _, uvs = get_mesh_auto_uv( + o3d_mesh, np.array(cfg.shape.project_direction) + ) + obj.set_uv_mapping(uvs) + + elif isinstance(cfg.shape, CubeCfg): + from embodichain.lab.sim.utility.sim_utils import create_cube + + obj_list = create_cube(env_list, cfg.shape.size, uid=cfg.uid) + for obj in obj_list: + obj.add_rigidbody(body_type, RigidBodyShape.BOX) + + elif isinstance(cfg.shape, SphereCfg): + from embodichain.lab.sim.utility.sim_utils import create_sphere + + obj_list = create_sphere( + env_list, cfg.shape.radius, cfg.shape.resolution, uid=cfg.uid + ) + for obj in obj_list: + obj.add_rigidbody(body_type, RigidBodyShape.SPHERE) + else: + logger.log_error( + f"Unsupported rigid object shape type: {type(cfg.shape)}. Supported types: MeshCfg, CubeCfg, SphereCfg." + ) + return obj_list + + +def load_soft_object_from_cfg( + cfg: SoftObjectCfg, env_list: List[Arena] +) -> List[MeshObject]: + obj_list = [] + + option = LoadOption() + option.rebuild_normals = cfg.shape.load_option.rebuild_normals + option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent + option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal + option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent + option.smooth = cfg.shape.load_option.smooth + option.share_mesh = False + + for i, env in enumerate(env_list): + obj = env.load_actor( + fpath=cfg.shape.fpath, duplicate=True, attach_scene=True, option=option + ) + obj.add_softbody(cfg.voxel_attr.attr(), cfg.physical_attr.attr()) + if cfg.shape.compute_uv: + vertices = obj.get_vertices() + triangles = obj.get_triangles() + + o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) + _, uvs = get_mesh_auto_uv(o3d_mesh, cfg.shape.project_direction) + obj.set_uv_mapping(uvs) + obj.set_name(f"{cfg.uid}_{i}") + obj_list.append(obj) + return obj_list diff --git a/embodichain/lab/sim/utility/workspace_analyzer_new.py b/embodichain/lab/sim/utility/workspace_analyzer_new.py new file mode 100644 index 0000000..8443d0f --- /dev/null +++ b/embodichain/lab/sim/utility/workspace_analyzer_new.py @@ -0,0 +1,1617 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import gc +import os +import time +import numpy as np +import open3d as o3d +import torch +import dexsim + +from dataclasses import dataclass +from typing import List, Tuple, Optional, Union, Dict, Sequence +from itertools import product, islice +from tqdm import tqdm + +from embodichain.utils import logger +from embodichain.lab.sim.objects import Robot +from scipy.spatial.transform import Rotation as R + + +@dataclass +class JointConfig: + """Joint configuration parameters""" + + range: Tuple[float, float] # Joint motion range + samples: int # Number of samples + + +@dataclass +class JointSamplingConfig: + """Joint space sampling configuration""" + + joints: List[JointConfig] # List of joint configurations + + +def batched(iterable, n): + """Yield successive n-sized batches from iterable.""" + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +class WorkspaceAnalyzer: + def __init__( + self, + robot: Robot, + name: str, + joint_ranges: np.ndarray, + resolution: float = np.radians(35), + ): + self.robot = robot + self.solver = self.robot.get_solver(name) + self.control_part = name + self.resolution = resolution + self.joint_ranges = np.array(joint_ranges) + self.device = "cpu" + + self._sampling_configs = self._init_sampling_configs() + + self.control_part_base_xpos = self.robot.get_control_part_base_pose( + name=name, to_matrix=True + ) + + def _get_fk_result(self, qpos: np.ndarray) -> Tuple[bool, np.ndarray]: + r"""Calculate forward kinematics + + Computes the end-effector pose given joint angles. + + Args: + qpos: Joint angles array + + Returns: + tuple: (success, pose) + - success (bool): True if calculation succeeded + - pose (np.ndarray): 4x4 homogeneous transformation matrix + """ + try: + result = self.robot.compute_fk(name=self.control_part, qpos=qpos) + + # Default values + success = False + xpos = np.eye(4) + + # Handle different return types + if isinstance(result, tuple): + if len(result) >= 2: + success, xpos = result[:2] + else: + if result is None: + success = False + else: + success = True + xpos = result + + return success, xpos + + except Exception as e: + logger.log_warning(f"FK calculation failed: {str(e)}") + return False, np.eye(4) + + def _get_ik_result( + self, xpos: np.ndarray, qpos_seed: Optional[np.ndarray] = np.array([]) + ) -> Tuple[bool, np.ndarray]: + """Calculate inverse kinematics + + Computes joint angles that achieve the desired end-effector pose. + + Args: + xpos: Target 4x4 homogeneous transformation matrix + qpos_seed: Initial joint angles for IK solver (optional) + + Returns: + tuple: (success, joint_angles) + - success (bool): True if solution found + - joint_angles (np.ndarray): Solution joint angles + """ + # try: + # Call robot's IK solver + result = self.robot.get_ik( + uid=self.control_part, xpos=xpos, qpos_seed=qpos_seed + ) + + # Default values + success = False + q_sol = np.zeros(self.robot.get_dof(self.control_part)) + + # Process IK result + if isinstance(result, tuple): + if len(result) >= 2: + success, q_sol = result[:2] + else: + if result is None: + success = False + else: + success = True + q_sol = result + + return success, q_sol + + # except Exception as e: + # logger.log_warning(f"IK calculation failed: {str(e)}") + # return False, None + + def _init_sampling_configs(self) -> Dict[str, JointSamplingConfig]: + r"""Initialize joint space sampling configurations + + Returns: + Dictionary mapping config names to sampling configurations + """ + original_ranges = self.joint_ranges.copy() + + self.joint_ranges = np.clip(self.joint_ranges, -np.pi, np.pi) + + clipped_joints = [] + for i, (orig, clipped) in enumerate(zip(original_ranges, self.joint_ranges)): + if not np.allclose(orig, clipped): + clipped_joints.append(i) + + if clipped_joints: + logger.log_info("Some joint ranges were clipped to [-π, π]:") + for joint_idx in clipped_joints: + orig_range = original_ranges[joint_idx] + new_range = self.joint_ranges[joint_idx] + logger.log_info( + f"Joint {joint_idx}: [{orig_range[0]:.3f}, {orig_range[1]:.3f}] -> " + f"[{new_range[0]:.3f}, {new_range[1]:.3f}] rad" + ) + + # Calculate joint range sizes + joint_ranges_size = np.abs(self.joint_ranges[:, 1] - self.joint_ranges[:, 0]) + + # Calculate number of samples per joint + samples = [ + max(3, int(np.ceil(range_size / self.resolution))) + for range_size in joint_ranges_size + ] + + # Create default sampling configuration + sampling_config = JointSamplingConfig( + joints=[ + JointConfig(range=joint_range, samples=sample_num) + for joint_range, sample_num in zip(self.joint_ranges, samples) + ], + ) + + # Log sampling configuration info + logger.log_info(f"Analyze control part: [{self.control_part}]") + logger.log_info( + f"Angular Resolution: {self.resolution:.3f} rad ({np.degrees(self.resolution):.1f}°)" + ) + for i, (joint_range, num_samples) in enumerate(zip(self.joint_ranges, samples)): + range_size = abs(joint_range[1] - joint_range[0]) + actual_resolution = range_size / (num_samples - 1) if num_samples > 1 else 0 + logger.log_info( + f"- Joint {i+1}: Range={range_size:.2f}rad, Samples={num_samples}, " + f"Actual Resolution={actual_resolution:.3f}rad ({np.degrees(actual_resolution):.1f}°)" + ) + + return sampling_config + + def _generate_combinations(self, joint_values): + r"""Generator function to produce joint angle combinations one at a time + + This avoids generating all combinations at once to save memory + """ + if not joint_values: + yield [] + else: + for first in joint_values[0]: + for rest in self._generate_combinations(joint_values[1:]): + yield [first] + rest + + def _process_batch( + self, batch: List[np.ndarray], timeout: float = 10.0 + ) -> List[np.ndarray]: + r"""Process a batch of joint configurations + + Args: + batch: List of joint configurations to process + timeout: Batch processing timeout in seconds + + Returns: + List of end effector XYZ positions + """ + positions = [] + start_time = time.time() + + for qpos in batch: + if time.time() - start_time > timeout: + logger.log_warning(f"Batch processing timeout ({timeout}s)") + break + + try: + qpos = np.array(qpos) + res, xpos = self._get_fk_result(qpos=qpos) + if res: + # Only save XYZ position + positions.append(xpos[:3, 3]) + except Exception as e: + logger.log_warning(f"Error processing joint configuration: {str(e)}") + continue + + return positions + + def _validate_params(self, cache_mode: str, save_dir: str): + r"""Validate input parameters""" + if cache_mode not in ["memory", "disk"]: + raise ValueError("cache_mode must be 'memory' or 'disk'") + + if cache_mode == "disk" and save_dir is None: + raise ValueError("save_dir must be provided when cache_mode is 'disk'") + + def _init_joint_values(self, config: JointSamplingConfig) -> List[np.ndarray]: + r"""Initialize joint sampling values""" + return [ + np.linspace(joint.range[0], joint.range[1], joint.samples) + for joint in config.joints + ] + + def _save_batch_results( + self, positions: List[np.ndarray], save_dir: str, batch_id: int + ): + r"""Save results for a single batch + + Args: + positions: List of XYZ positions + save_dir: Directory to save results + batch_id: Batch identifier + """ + + batch_dir = os.path.join(save_dir, "batches") + # Ensure directory exists + os.makedirs(batch_dir, exist_ok=True) + # Save numpy array + batch_path = os.path.join(batch_dir, f"batch_{batch_id:04d}.npy") + np.save(batch_path, np.array(positions)) + logger.log_info( + f"Saved batch {batch_id}: {len(positions)} points -> {batch_path}" + ) + + def _process_point_cloud( + self, + positions: List[np.ndarray], + voxel_size: float = 0.05, + nb_neighbors: int = 20, + std_ratio: float = 2.0, + is_voxel_down: bool = True, + ) -> o3d.geometry.PointCloud: + r"""Process sampled point cloud data + + Args: + positions: List of XYZ positions + voxel_size: Voxel size (m) + nb_neighbors: Number of neighbors for statistical filter + std_ratio: Standard deviation ratio for statistical filter + + Returns: + o3d.geometry.PointCloud: Processed point cloud + """ + # Create point cloud object + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(np.array(positions)) + + logger.log_info(f"Point cloud processing:") + + if is_voxel_down: + # 1. Voxel downsampling + logger.log_info( + f"- Performing voxel downsampling (voxel_size={voxel_size}m)" + ) + pcd_down = pcd.voxel_down_sample(voxel_size=voxel_size) + + # 2. Statistical outlier removal + logger.log_info( + f"- Removing outliers (neighbors={nb_neighbors}, std_ratio={std_ratio})" + ) + cl, ind = pcd_down.remove_statistical_outlier( + nb_neighbors=nb_neighbors, std_ratio=std_ratio + ) + pcd_clean = pcd_down.select_by_index(ind) + else: + pcd_clean = pcd + + # 3. Estimate normals + logger.log_info("- Estimating point cloud normals") + pcd_clean.estimate_normals( + search_param=o3d.geometry.KDTreeSearchParamHybrid( + radius=voxel_size * 2, max_nn=30 + ) + ) + + # 4. Orient normals consistently + logger.log_info("- Orienting normals consistently") + + pcd_clean.orient_normals_to_align_with_direction() + + # 5. Add color based on distance to origin + points = np.asarray(pcd_clean.points) + + # Calculate distances to origin + distances = np.linalg.norm(points, axis=1) + + # Find the centroid + center = np.mean(points, axis=0) + + # Calculate distances to the centroid + distances_to_center = np.linalg.norm(points - center, axis=1) + + # Normalize distances + max_dist = np.max(distances_to_center) + normalized_distances = distances_to_center / max_dist + + # Create HSV color space (green to red gradient) + hsv_colors = np.zeros((len(points), 3)) + hsv_colors[:, 0] = 0.3333 * ( + 1 - normalized_distances + ) # Hue: green (0.3333) to red (0) + hsv_colors[:, 1] = 1.0 # Saturation: max saturation + hsv_colors[:, 2] = 0.8 # Value: medium brightness + + # Convert HSV to RGB + colors = np.zeros_like(points) + for i in range(len(points)): + h, s, v = hsv_colors[i] + + # HSV to RGB conversion + c = v * s + x = c * (1 - abs((h * 6) % 2 - 1)) + m = v - c + + if h < 1 / 6: + rgb = [c, x, 0] + elif h < 2 / 6: + rgb = [x, c, 0] + elif h < 3 / 6: + rgb = [0, c, x] + elif h < 4 / 6: + rgb = [0, x, c] + elif h < 5 / 6: + rgb = [x, 0, c] + else: + rgb = [c, 0, x] + + colors[i] = [r + m for r in rgb] + + pcd_clean.colors = o3d.utility.Vector3dVector(colors) + + logger.log_info(f"- Original points: {len(positions)}") + logger.log_info(f"- Processed points: {len(pcd_clean.points)}") + logger.log_info( + f"- Distance range: {np.min(distances):.3f}m ~ {np.max(distances):.3f}m" + ) + + return pcd_clean + + def _merge_batch_files(self, save_dir: str, total_batches: int) -> List[np.ndarray]: + r"""Merge all sampled points from batch files + + Args: + save_dir: Directory to save data + total_batches: Total number of batches + + Returns: + List[np.ndarray]: List of all sampled positions + """ + # Get current date for subdirectory name + # current_date = time.strftime("%Y%m%d") + batch_dir = os.path.join(save_dir, "batches") + + logger.log_info("Starting to merge batch files...") + all_xpos = [] + + # Load and process batches + for batch_id in tqdm(range(total_batches), desc="Merging progress"): + batch_path = os.path.join(batch_dir, f"batch_{batch_id:04d}.npy") + + try: + # Load batch data + batch_data = np.load(batch_path) + all_xpos.extend(batch_data) + # Delete processed batch file + # os.remove(batch_path) + except Exception as e: + logger.log_warning(f"Error processing batch {batch_id}: {str(e)}") + + # Remove empty batch directory + if os.path.exists(batch_dir) and not os.listdir(batch_dir): + os.rmdir(batch_dir) + + logger.log_info(f"Merging complete: {len(all_xpos)} sampled points") + return all_xpos + + def sample_qpos_workspace( + self, + resolution: float = None, + cache_mode: str = "memory", # Cache mode "memory" or "disk" + save_dir: str = None, # Save directory + batch_size: int = 100000, # Batch processing size + save_threshold: int = 10000000, # Save threshold + use_cached: bool = True, # Use cached results if available + ) -> List[np.ndarray]: + r"""Sample joint space and calculate corresponding workspace poses + + Args: + resolution: Sampling resolution + cache_mode: Cache mode ("memory" - in-memory list, "disk" - disk storage) + save_dir: Save directory path (must be provided when cache_mode="disk") + batch_size: Number of samples per batch + save_threshold: Number of samples to accumulate before saving in disk mode + use_cached: Whether to use cached results if available (only in disk mode) + + Returns: + List[np.ndarray]: List of valid end effector poses of poses (in memory mode) or empty list (in disk mode) + """ + if resolution is not None: + self.resolution = resolution + self._sampling_configs = self._init_sampling_configs() + + # Validate parameters + self._validate_params(cache_mode, save_dir) + + # Initialize sampling configuration + joint_values = self._init_joint_values(self._sampling_configs) + total_samples = np.prod([len(values) for values in joint_values]) + + logger.log_info( + f"Sampling joint space with resolution {np.degrees(self.resolution):.1f}°..." + ) + logger.log_info(f"Total sample points: {total_samples}") + logger.log_info(f"Cache mode: {cache_mode}") + logger.log_info(f"Save directory: {save_dir if save_dir else 'N/A'}") + logger.log_info(f"Sampling using: {self.device}") + + if cache_mode == "memory": + return self._sample_memory_mode(joint_values, total_samples, batch_size) + else: + return self._sample_disk_mode( + joint_values, + total_samples, + save_dir, + batch_size, + save_threshold, + use_cached, + ) + + def _sample_memory_mode( + self, joint_values: List[np.ndarray], total_samples: int, batch_size: int + ) -> List[np.ndarray]: + r"""Memory mode sampling""" + if not self.robot.pk_serial_chain: + all_xpos = [] + for qpos in tqdm( + product(*joint_values), + total=total_samples, + desc="Memory mode serial sampling", + ): + q = np.array(qpos, dtype=np.float32) + res, xpos = self._get_fk_result(qpos=q) + if res: + all_xpos.append(xpos) + if len(all_xpos) % 1000 == 0: + gc.collect() + return all_xpos + self.chain = self.robot.pk_serial_chain[self.control_part].to( + dtype=torch.float32, device=self.device + ) + sampled_xpos = [] + joint_combinations = product(*joint_values) + + T_tcp = torch.as_tensor(self.solver.get_tcp(), dtype=torch.float32).to( + self.device + ) + + with tqdm( + total=total_samples, + desc=f"Sampling {total_samples} points (batch={batch_size})", + ) as pbar: + for qpos_batch in batched(joint_combinations, batch_size): + # compute and collect + batch_mats = self._compute_batch_xpos(qpos_batch, T_tcp) + sampled_xpos.extend(batch_mats) + + # advance progress bar and cleanup + pbar.update(len(batch_mats)) + gc.collect() + + return sampled_xpos + + def _sample_disk_mode( + self, + joint_values: List[np.ndarray], + total_samples: int, + save_dir: str, + batch_size: int, + save_threshold: int, + use_cached: bool = True, + ) -> List[np.ndarray]: + r"""Disk mode sampling, with serial fallback if no pk_serial_chain.""" + # 1) If batches already exist, just merge & return + batches_dir = os.path.join(save_dir, "batches") + if os.path.exists(batches_dir) and use_cached: + npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] + if npy_files: + return self._merge_batch_files(save_dir, len(npy_files)) + + sampled_xpos = [] + current_batch = [] + total_processed = 0 + batch_count = 0 + + # 2) Choose serial vs. GPU path + if not self.robot.pk_serial_chain: + # serial, one qpos at a time + current_batch = [] + with tqdm(total=total_samples, desc="Disk mode serial sampling") as pbar: + for qpos in product(*joint_values): + q = np.array(qpos, dtype=np.float32) + res, xpos = self._get_fk_result(qpos=q) + if res: + current_batch.append(xpos) + # flush by batch_size + if len(current_batch) >= batch_size: + sampled_xpos.extend(current_batch) + total_processed += len(current_batch) + current_batch = [] + # flush to disk by save_threshold + if len(sampled_xpos) >= save_threshold: + self._save_batch_results( + sampled_xpos, save_dir, batch_count + ) + batch_count += 1 + sampled_xpos = [] + gc.collect() + pbar.update(1) + + else: + self.chain = self.robot.pk_serial_chain[self.control_part].to( + dtype=torch.float32, device=self.device + ) + # GPU‐batched path + T_tcp = torch.as_tensor( + self.robot.get_tcp(self.control_part), + dtype=torch.float32, + device=self.device, + ) + with tqdm( + total=total_samples, desc=f"Sampling in {batch_size}-sized batches" + ) as pbar: + for qpos_batch in batched(product(*joint_values), batch_size): + batch_mats = self._compute_batch_xpos(qpos_batch, T_tcp) + sampled_xpos.extend(batch_mats) + total_processed += len(batch_mats) + # flush to disk by save_threshold + if len(sampled_xpos) >= save_threshold: + self._save_batch_results(sampled_xpos, save_dir, batch_count) + batch_count += 1 + sampled_xpos = [] + gc.collect() + pbar.update(len(batch_mats)) + + # Process remaining samples + if sampled_xpos: + self._save_batch_results(sampled_xpos, save_dir, batch_count) + batch_count += 1 + + logger.log_info( + f"Sampling complete: {total_processed} samples, {batch_count} batches" + ) + + # If there are saved batches, read and merge them to process point cloud + if batch_count > 0: + all_xpos = self._merge_batch_files(save_dir, batch_count) + return all_xpos + + return None + + def sample_xpos_workspace( + self, + ref_xpos: np.ndarray, + xpos_resolution: float = 0.2, + qpos_resolution: float = np.radians(60), + cache_mode: str = "memory", + save_dir: str = None, + batch_size: int = 5000, + save_threshold: int = 10000000, + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1500, + num_samples: int = 5, + use_cached: bool = True, + ) -> List[np.ndarray]: + r"""Sample Cartesian space and calculate corresponding joints + + Args: + ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the + orientation for IK solutions. Translation components will + be overridden during sampling. + xpos_resolution (float, optional): Cartesian space sampling resolution in meters. + Smaller values provide finer sampling but increase + computation time. Defaults to 0.2 meters. + qpos_resolution (float, optional): Angular resolution for initial joint space + sampling in radians. Used to determine workspace + bounds. Defaults to 60 degrees. + cache_mode (str, optional): Caching strategy, either: + - "memory": Store samples in memory (faster but memory-intensive) + - "disk": Save samples to disk (slower but memory-efficient) + Defaults to "memory". + save_dir (str, optional): Directory path for saving results when using disk cache. + Must be provided if cache_mode is "disk". Defaults to None. + batch_size (int, optional): Number of samples to process in each batch. + Larger values may improve performance but increase + memory usage. Defaults to 5000. + save_threshold (int, optional): Number of samples to accumulate before saving + to disk in disk mode. Defaults to 10,000,000. + pos_eps (float, optional): Position tolerance for IK solutions in meters. + Defaults to 5e-4. + rot_eps (float, optional): Rotation tolerance for IK solutions in radians. + Defaults to 5e-4. + max_iterations (int, optional): Maximum iterations for IK solver. + Defaults to 1500. + num_samples (int, optional): Number of IK samples to generate for each position. + Defaults to 5. + use_cached (bool, optional): Whether to use cached results if available (only in disk mode) + + Returns: + List[np.ndarray]: List of valid end effector poses + """ + # logger.set_log_level(level="error") + + start_time = time.time() + try: + qpos_sampled_xpos = self.sample_qpos_workspace( + resolution=qpos_resolution, + cache_mode="memory", + batch_size=5000, + save_threshold=save_threshold, + ) + + qpos_all_positions = [xpos[:3, 3] for xpos in qpos_sampled_xpos] + qpos_pcd = self._process_point_cloud(positions=qpos_all_positions) + aabb = qpos_pcd.get_axis_aligned_bounding_box() + + sample_points = self._sample_in_aabb( + aabb.min_bound, aabb.max_bound, xpos_resolution + ) + + # Validate parameters + self._validate_params(cache_mode, save_dir) + + if cache_mode == "memory": + return self._sample_xpos_memory_mode( + positions=sample_points, + ref_xpos=ref_xpos, + batch_size=batch_size, + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + ) + else: + return self._sample_xpos_disk_mode( + positions=sample_points, + ref_xpos=ref_xpos, + save_dir=save_dir, + batch_size=batch_size, + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + save_threshold=save_threshold, + use_cached=use_cached, + ) + finally: + logger.set_log_level(level="info") + # Record the end time + end_time = time.time() + # Calculate the time cost + time_cost = end_time - start_time + logger.log_info(f"Time cost: {time_cost:.2f} seconds") + + def _compute_batch_xpos( + self, qpos_batch: Sequence[np.ndarray], T_tcp: torch.Tensor + ) -> List[np.ndarray]: + """Given a batch of q-poses, compute TCP-transformed FK matrices + and return them as numpy float16 arrays.""" + # 1) to NumPy (float32) → to torch.Tensor on correct device + np_qpos = np.array(qpos_batch, dtype=np.float32) + tensor_qpos = torch.as_tensor(np_qpos, dtype=torch.float32, device=self.device) + + # 2) batched forward kinematics → 4×4 matrices + ret_batch = self.chain.forward_kinematics( + tensor_qpos, end_only=True + ).get_matrix() + + # 3) apply TCP offset + T_final = torch.matmul(ret_batch, T_tcp) + + T_final = torch.bmm( + self.control_part_base_xpos.to(dtype=torch.float32).expand( + T_final.shape[0], -1, -1 + ), + T_final, + ) + + # 4) move to CPU, cast to float16 + T_cpu16 = T_final.cpu().to(dtype=torch.float16) + + # 5) return list of numpy arrays + return [mat.numpy() for mat in T_cpu16] + + def _sample_in_aabb( + self, min_bound: np.ndarray, max_bound: np.ndarray, resolution: float + ) -> np.ndarray: + r"""Uniformly sample within an axis-aligned bounding box (AABB) + + Args: + min_bound: AABB minimum bound [x_min, y_min, z_min] + max_bound: AABB maximum bound [x_max, y_max, z_max] + resolution: Sampling resolution (m) + + Returns: + np.ndarray: Array of sampled points with shape (N, 3) + """ + # Calculate number of samples per axis + num_samples = np.ceil((max_bound - min_bound) / resolution).astype(int) + + # Ensure at least 2 samples per dimension + num_samples = np.maximum(num_samples, 2) + + # Generate sample points for each axis + x = np.linspace(min_bound[0], max_bound[0], num_samples[0]) + y = np.linspace(min_bound[1], max_bound[1], num_samples[1]) + z = np.linspace(min_bound[2], max_bound[2], num_samples[2]) + + # Create a grid of points + X, Y, Z = np.meshgrid(x, y, z) + + # Convert grid to N×3 array + points = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T + + logger.log_info(f"Sampling space range:") + logger.log_info(f"- X: [{min_bound[0]:.3f}, {max_bound[0]:.3f}] m") + logger.log_info(f"- Y: [{min_bound[1]:.3f}, {max_bound[1]:.3f}] m") + logger.log_info(f"- Z: [{min_bound[2]:.3f}, {max_bound[2]:.3f}] m") + logger.log_info(f"Sampling resolution: {resolution:.3f} m") + logger.log_info(f"Number of samples: {len(points)}") + + return points + + def _sample_xpos_memory_mode( + self, + positions: List[np.ndarray], + ref_xpos: np.ndarray, + batch_size: int, + pos_eps: float, + rot_eps: float, + max_iterations: int, + num_samples: int, + ) -> List[np.ndarray]: + r"""Memory mode sampling with batch processing and progress bar + + Args: + positions: List of positions to validate. + ref_xpos: Reference end effector pose. + batch_size (int): Number of positions to process in each batch. + + Returns: + List[np.ndarray]: List of valid end effector poses. + """ + valid_xpos = [] + + # Get the degree of freedom (DOF) of the robot to create joint seed + dof_number = self.robot.get_dof(self.control_part) + + # Total number of positions to process + total_positions = len(positions) + + # TODO: Optimize efficiency by using batch IK if available. + # If self.robot implements get_batch_ik_solution, prefer batch processing for IK to significantly accelerate sampling. + # Otherwise, fall back to single-point IK calls (slower). + # This check ensures the most efficient computation path is used automatically. + # (Batch IK can greatly improve performance for large-scale workspace sampling.) + # Example: + # if hasattr(self.robot, "get_batch_ik_solution"): + if False: + # If the robot has get_batch_ik_solution, use it for batch processing + num_batches = (total_positions // batch_size) + ( + 1 if total_positions % batch_size != 0 else 0 + ) + + # Create progress bar with total samples and batch size + with tqdm( + total=total_positions, desc=f"Sampling in {batch_size}-sized batches" + ) as pbar: + # Iterate through positions in batches + for batch_idx in range(num_batches): + # Select the current batch of positions + batch_positions = positions[ + batch_idx * batch_size : (batch_idx + 1) * batch_size + ] + + # Create a batch of target poses (batch_size, 4, 4) + target_xpos_batch = [] + for point in batch_positions: + target_xpos = ref_xpos.copy() + target_xpos[:3, 3] = point + target_xpos_batch.append(target_xpos) + + # Convert to numpy array (batch_size, 4, 4) + target_xpos_batch = np.array(target_xpos_batch) + # Create joint seed batch of zeros (batch_size, dof) + joint_seed_batch = np.zeros((len(batch_positions), dof_number)) + # Use get_batch_ik_solution for batch processing + res, _ = self.robot.get_batch_ik_solution( + target_xpos_list=target_xpos_batch, # Batch of target poses + joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) + uid=self.control_part, + is_world_coordinates=False, # Set based on your use case + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + ) + + # Append valid target poses to valid_xpos + for j, is_valid in enumerate(res): + if is_valid: + valid_xpos.append(target_xpos_batch[j]) + + # Update the progress bar after processing the batch + pbar.update( + len(batch_positions) + ) # Update progress bar with batch size + + # Perform garbage collection after every batch + if len(valid_xpos) % 1000 == 0: + gc.collect() + + else: + # Fallback to the previous method if get_batch_ik_solution is not available + with tqdm( + total=total_positions, desc="Sampling in single IK calls" + ) as pbar: + for point in positions: + # Construct target pose + target_xpos = ref_xpos.copy() + target_xpos[:3, 3] = point + + # Calculate IK using the old method (get_ik) + res, _ = self.robot.get_ik(uid=self.control_part, xpos=target_xpos) + if res: + valid_xpos.append(target_xpos) + + # Update the progress bar after each point is processed + pbar.update(1) # Update progress bar with 1 point + + # Perform garbage collection after every 1000 valid points + if len(valid_xpos) % 1000 == 0: + gc.collect() + + return valid_xpos if valid_xpos else None + + def _sample_xpos_disk_mode( + self, + positions: List[np.ndarray], + ref_xpos: np.ndarray, + save_dir: str, + batch_size: int, + pos_eps: float, + rot_eps: float, + max_iterations: int, + num_samples: int, + save_threshold: int, + use_cached: bool = True, + ) -> List[np.ndarray]: + r"""Disk mode sampling with batch processing + + Args: + positions: List of positions to validate. + ref_xpos: Reference end effector pose. + save_dir: Directory to save results. + batch_size: Number of samples per batch. + save_threshold: Number of samples to accumulate before saving. + + Returns: + List[np.ndarray]: List of valid end effector poses. + """ + valid_positions = [] + current_batch = [] + total_processed = 0 + batch_count = 0 + # Record the start time + logger.log_info(f"Starting disk mode sampling...") + logger.log_info(f"Save directory: {save_dir}") + + # If there are saved batches, read and return without calculation + batches_dir = os.path.join(save_dir, "batches") + if os.path.exists(batches_dir) and use_cached: + npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] + batch_count = len(npy_files) + + if batch_count > 0: + all_xpos = self._merge_batch_files(save_dir, batch_count) + return all_xpos + + # Check if self.robot has the method get_batch_ik_solution + if hasattr(self.robot, "get_batch_ik_solution"): + # If get_batch_ik_solution is available, use batch processing + with tqdm(total=len(positions), desc="Disk mode sampling") as pbar: + for i in range(0, len(positions), batch_size): + # Select the current batch of positions + batch_positions = positions[i : i + batch_size] + + # Create a batch of target poses (batch_size, 4, 4) + target_xpos_batch = [] + for point in batch_positions: + target_xpos = ref_xpos.copy() + target_xpos[:3, 3] = point + target_xpos_batch.append(target_xpos) + + # Convert to numpy array (batch_size, 4, 4) + target_xpos_batch = np.array(target_xpos_batch) + + # Create the joint seed batch (batch_size, dof) + dof_number = self.robot.get_dof(self.control_part) + joint_seed_batch = np.zeros((len(batch_positions), dof_number)) + + # Use get_batch_ik_solution for batch processing + res, _ = self.robot.get_batch_ik_solution( + target_xpos_list=target_xpos_batch, # Batch of target poses + joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) + uid=self.control_part, + is_world_coordinates=False, # Set based on your use case + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + ) + + # Append valid target poses to valid_positions + for j, is_valid in enumerate(res): + if is_valid: + current_batch.append(target_xpos_batch[j]) + + # Process batch when it reaches batch_size + if len(current_batch) >= batch_size: + valid_positions.extend(current_batch) + total_processed += len(current_batch) + + current_batch = [] + + # Save when reaching the threshold + if len(valid_positions) >= save_threshold: + self._save_batch_results( + valid_positions, save_dir, batch_count + ) + batch_count += 1 + valid_positions = [] + gc.collect() + + # Update the progress bar + pbar.update(len(batch_positions)) # Update with batch size + + else: + # Fallback to the previous method if get_batch_ik_solution is not available + with tqdm(total=len(positions), desc="Disk mode sampling") as pbar: + for point in positions: + # Construct target pose + target_xpos = ref_xpos.copy() + target_xpos[:3, 3] = point + + # Calculate IK using the old method (get_ik) + res, _ = self.robot.compute_ik( + name=self.control_part, pose=target_xpos + ) + if res: + current_batch.append(target_xpos) + + # Process batch when it reaches batch_size + if len(current_batch) >= batch_size: + valid_positions.extend(current_batch) + total_processed += len(current_batch) + + current_batch = [] + + # Save when reaching the threshold + if len(valid_positions) >= save_threshold: + self._save_batch_results( + valid_positions, save_dir, batch_count + ) + batch_count += 1 + valid_positions = [] + gc.collect() + + # Update the progress bar + pbar.update(1) # Update with 1 point per iteration + + # Process remaining data + if current_batch: + valid_positions.extend(current_batch) + total_processed += len(current_batch) + + if valid_positions: + self._save_batch_results(valid_positions, save_dir, batch_count) + batch_count += 1 + + logger.log_info( + f"Sampling complete: {total_processed} samples, {batch_count} batches" + ) + + # If there are saved batches, read and merge them to process point cloud + if batch_count > 0: + all_xpos = self._merge_batch_files(save_dir, batch_count) + return all_xpos + + return None + + def sample_voxel_workspace( + self, + voxel_size: float = 0.04, + num_directions: int = 50, + num_yaws: int = 6, + pos_eps: float = 2e-4, + rot_eps: float = 2e-4, + max_iterations: int = 1500, + num_samples: int = 5, + cache_mode: str = "memory", + save_dir: str = None, + batch_size: int = 5000, + save_threshold: int = 10000000, + use_cached: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray]]: + r"""Sample Cartesian space using voxel‐based IK reachability. + + Divides the workspace into a grid of voxels around the arm base, then for + each voxel center sweeps through a set of directions and yaw rotations, + calling the IK solver to test reachability. + + Args: + voxel_size (float, optional): + Edge length of each cubic voxel in meters. + Smaller voxels give finer resolution but increase computation. + Defaults to 0.04. + num_directions (int, optional): + Number of unit‐vector directions to sample on the sphere for each + voxel. More directions improve angular coverage at the cost of + additional IK calls. Defaults to 50. + num_yaws (int, optional): + Number of discrete yaw rotations **around the local Z‐axis** to + attempt for each direction when solving IK. Higher values increase + rotational sampling but incur more IK calls. Defaults to 6. + pos_eps (float, optional): + Position tolerance for IK solutions in meters. + Defaults to 5e-4. + rot_eps (float, optional): + Rotation tolerance for IK solutions in radians. + Defaults to 5e-4. + max_iterations (int, optional): + Maximum iterations for IK solver. + Defaults to 1500. + num_samples (int, optional): + Number of IK samples to generate for each position. + Defaults to 5. + cache_mode (str, optional): + Caching strategy for IK results: + - `"memory"`: keep all samples in RAM (fast, memory‐intensive) + - `"disk"`: stream to disk in batches (slower, memory‐efficient) + Defaults to `"memory"`. + save_dir (str, optional): + Directory path for saving/loading cached batches when using + `cache_mode="disk"`. Required in disk mode. Defaults to None. + batch_size (int, optional): + Number of successful IK poses to accumulate before adding them to + the in‐memory pool. Larger values may improve throughput but + increase temporary memory usage. Defaults to 5000. + save_threshold (int, optional): + Number of poses in the in‐memory pool at which point they are + written out to disk as a batch file. Helps limit peak RAM use. + Defaults to 10,000,000. + use_cached: Whether to use cached results if available (only in disk mode) + + Returns: + Tuple[ + np.ndarray, # (M,3) array of voxel‐center coordinates + np.ndarray, # (M,) array of success counts per center + List[np.ndarray] # flat list of all valid 4×4 IK pose matrices + ] + """ + logger.set_log_level(level="error") + + try: + self._validate_params(cache_mode, save_dir) + + logger.log_info(f"Sampling robot workspace with voxel size {voxel_size}...") + logger.log_info(f"Cache mode: {cache_mode}") + logger.log_info(f"Sampling using: {self.device}") + + arm_base_pos = self.robot.get_base_xpos(name=self.control_part)[:3, 3] + arm_ee_pos = self.robot.get_current_xpos(name=self.control_part)[:3, 3] + arm_length = float(np.linalg.norm(arm_ee_pos - arm_base_pos)) + + if cache_mode == "memory": + return self._sample_voxels_memory_mode( + voxel_size, num_directions, num_yaws, arm_base_pos, arm_length + ) + else: + return self._sample_voxels_disk_mode( + voxel_size, + num_directions, + num_yaws, + arm_base_pos, + arm_length, + save_dir=save_dir, + save_threshold=save_threshold, + batch_size=batch_size, + use_cached=use_cached, + ) + finally: + logger.set_log_level(level="info") + + def _voxel_centers_in_sphere(self, arm_base, arm_length, voxel_size): + """ + Compute centers of all voxels of size `voxel_size` whose centers lie + within a sphere of radius `arm_length` around `arm_base`, using the + exact range definitions you provided for x, y, and z. + + Args: + arm_base (sequence of 3 floats): (x, y, z) origin. + arm_length (float): radius of the sphere. + voxel_size (float): edge length of each cubic voxel. + + Returns: + numpy.ndarray of shape (M, 3): each row is a valid (x, y, z) center. + """ + x, y, z = arm_base + r = float(arm_length) + half = voxel_size / 2.0 + + # follow your exact ranges + x_range = np.arange(x - half, x + r + half, voxel_size) + y_range = np.arange(y - half, y + r + half, voxel_size) + z_range = np.arange(z - r / 2 - half, z + r / 2 + half, voxel_size) + + # build full grid of candidate centers + xx, yy, zz = np.meshgrid(x_range, y_range, z_range, indexing="ij") + pts = np.stack((xx, yy, zz), axis=-1).reshape(-1, 3) + + # keep only those inside the sphere of radius r + d2 = np.sum((pts - np.array(arm_base)) ** 2, axis=1) + return pts[d2 <= r**2] + + def _generate_uniform_directions(self, num_directions: int = 50): + """ + Generate vectors in evenly distributed n directions + """ + phi = np.pi * (3.0 - np.sqrt(5.0)) + directions = [] + for i in range(num_directions): + z = 1 - 2 * i / float(num_directions - 1) + theta = phi * i + x = np.sqrt(1 - z * z) * np.cos(theta) + y = np.sqrt(1 - z * z) * np.sin(theta) + directions.append(np.array([x, y, z])) + + return directions + + # Helper function + def normalize(self, v: np.ndarray) -> np.ndarray: + """Normalize a vector to unit length.""" + norm = np.linalg.norm(v) + if norm == 0: + return v # Avoid division by zero + return v / norm + + def _compute_ik_solutions( + self, + centers: List[np.ndarray], + directions: List[np.ndarray], + voxel_size: float, + num_yaws: int, + pos_eps: float = 2e-4, + rot_eps: float = 2e-4, + max_iterations: int = 1500, + num_samples: int = 5, + ) -> List[np.ndarray]: + """ + Compute IK solutions for a set of centers and directions. + This function will process the centers and directions in batches if `get_batch_ik_solution` is available. + + Args: + centers: List of center positions to compute IK for. + directions: List of direction vectors to compute IK for. + voxel_size: Size of the voxel to offset the centers. + num_yaws: Number of yaw sweeps to attempt. + robot_base: Transformation matrix of the robot base. + yaw_rot: Rotation matrix for yaw rotation. + + Returns: + List[np.ndarray]: List of valid IK poses. + """ + valid_poses = [] + success_counts = [0] * len(centers) + + # Create progress bar + pbar = tqdm(total=len(centers), ncols=100, desc="Computing IK (per-center)") + + yaw_angle = 360.0 / num_yaws + yaw_rot = R.from_euler("z", yaw_angle, degrees=True).as_matrix() + robot_base = self.robot.get_base_xpos(name=self.control_part) + + # Check if self.robot has the method get_batch_ik_solution + if hasattr(self.robot, "get_batch_ik_solution"): + # If get_batch_ik_solution is available, we process in batches + for i, center in enumerate(centers): + batch_positions = [] + batch_xpos = [] + + for d in directions: + # Build local frame so that its Z-axis = -d + z_axis = -d + up = ( + np.array([0, 1, 0]) + if abs(z_axis[1]) < 0.9 + else np.array([1, 0, 0]) + ) + x_axis = self.normalize(np.cross(up, z_axis)) + y_axis = np.cross(z_axis, x_axis) + frame = np.stack([x_axis, y_axis, z_axis], axis=1) + + # Shift out to the surface of the voxel + pos = center + d * (voxel_size * 0.5) + + # Try yaw sweeps + for _ in range(num_yaws): + frame = frame @ yaw_rot + xpos = np.eye(4) + xpos[:3, :3] = frame + xpos[:3, 3] = pos + xpos_robot = np.linalg.inv(robot_base) @ xpos + + # Prepare batch for IK computation + batch_positions.append(pos) + batch_xpos.append(xpos_robot) + + # Convert lists to numpy arrays (batch_size, 4, 4) + batch_xpos_array = np.array(batch_xpos) + + # Create the joint seed batch (batch_size, dof) + dof_number = self.robot.get_dof(self.control_part) + joint_seed_batch = np.zeros((len(batch_xpos), dof_number)) + + # Use get_batch_ik_solution for batch processing + res, _ = self.robot.get_batch_ik_solution( + target_xpos_list=batch_xpos_array, # Batch of target poses + joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) + uid=self.control_part, + is_world_coordinates=False, # Set based on your use case + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + ) + + # Append valid target poses to valid_poses + for j, is_valid in enumerate(res): + if is_valid: + success_counts[i] += 1 + valid_poses.append(batch_xpos_array[j]) + + # Update the progress bar after processing the batch + pbar.update(1) + + else: + # Fallback to the previous method (get_ik) if get_batch_ik_solution is not available + for i, center in enumerate(centers): + for d in directions: + # Build local frame so that its Z-axis = -d + z_axis = -d + up = ( + np.array([0, 1, 0]) + if abs(z_axis[1]) < 0.9 + else np.array([1, 0, 0]) + ) + x_axis = self.normalize(np.cross(up, z_axis)) + y_axis = np.cross(z_axis, x_axis) + frame = np.stack([x_axis, y_axis, z_axis], axis=1) + + # Shift out to the surface of the voxel + pos = center + d * (voxel_size * 0.5) + + # Try yaw sweeps + for _ in range(num_yaws): + frame = frame @ yaw_rot + xpos = np.eye(4) + xpos[:3, :3] = frame + xpos[:3, 3] = pos + xpos_robot = np.linalg.inv(robot_base) @ xpos + + # Calculate IK using the old method (get_ik) + is_success, _ = self.robot.get_ik( + xpos=xpos_robot, uid=self.control_part + ) + if is_success: + success_counts[i] += 1 + valid_poses.append(xpos_robot.copy()) + break # stop yaw for this direction + + pbar.update(1) + + logger.log_info(f"Sampling complete: {sum(success_counts)} valid positions.") + + return success_counts, valid_poses + + def _sample_voxels_memory_mode( + self, + voxel_size: float, + num_directions: int, + num_yaws: int, + arm_base: np.ndarray, + arm_length: float, + pos_eps: float, + rot_eps: float, + max_iterations: int, + num_samples: int, + ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray]]: + + dirs = self._generate_uniform_directions(num_directions) + centers = self._voxel_centers_in_sphere(arm_base, arm_length, voxel_size) + + success_counts, ik_matrices = self._compute_ik_solutions( + centers, + dirs, + voxel_size, + num_yaws, + pos_eps, + rot_eps, + max_iterations, + num_samples, + ) + + return centers, success_counts, ik_matrices + + def _sample_voxels_disk_mode( + self, + voxel_size: float, + num_directions: int, + num_yaws: int, + arm_base: np.ndarray, + arm_length: float, + pos_eps: float, + rot_eps: float, + max_iterations: int, + num_samples: int, + save_dir: str, + batch_size: int, + save_threshold: int, + use_cached: bool = True, + ) -> tuple[np.ndarray, np.ndarray, list[np.ndarray]]: + """ + Returns: + centers: (M,3) np.ndarray of voxel centers + success_counts: (M,) np.ndarray of ints + valid_poses: list of 4x4 np.ndarrays + """ + counts_file = os.path.join(save_dir, "success_counts.npy") + batches_dir = os.path.join(save_dir, "batches") + + # 1) generate dirs & centers + dirs = self._generate_uniform_directions(num_directions) + centers = self._voxel_centers_in_sphere(arm_base, arm_length, voxel_size) + + # 2) if already computed, load & return + if os.path.isdir(batches_dir) and os.path.exists(counts_file) and use_cached: + npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] + if npy_files: + success_counts = np.load(counts_file) + valid_poses = self._merge_batch_files(save_dir, len(npy_files)) + return centers, success_counts, valid_poses + + os.makedirs(batches_dir, exist_ok=True) + + # 3) run IK sweep + success_counts, valid_poses = self._compute_ik_solutions( + centers, + dirs, + voxel_size, + num_yaws, + pos_eps, + rot_eps, + max_iterations, + num_samples, + ) + if success_counts.sum() == 0: + return centers, success_counts, [] + + # 4) save counts + np.save(counts_file, success_counts) + + # 5) batch & save using a local temp buffer + temp_valid = [] + valid_block = [] + batch_count = 0 + + for pose in valid_poses: + # collect into small blocks of batch_size + valid_block.append(pose) + if len(valid_block) >= batch_size: + # move into temp_valid + temp_valid.extend(valid_block) + valid_block = [] + + # once buffer reaches save_threshold, flush to disk + if len(temp_valid) >= save_threshold: + self._save_batch_results(temp_valid, save_dir, batch_count) + batch_count += 1 + temp_valid = [] + gc.collect() + + # move any remaining block into temp_valid + if valid_block: + temp_valid.extend(valid_block) + + # final flush of anything left in temp_valid + if temp_valid: + self._save_batch_results(temp_valid, save_dir, batch_count) + batch_count += 1 + + # 6) merge all batch files and return + all_poses = self._merge_batch_files(save_dir, batch_count) + return centers, success_counts, all_poses + + +def compute_xpos_reachability( + robot: Robot, + name: str, + ref_xpos: np.ndarray, + xpos_resolution: float = 0.2, + qpos_resolution: float = np.radians(60), + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1500, + num_samples: int = 5, + batch_size: int = 100000, + save_threshold: int = 10000000, + qpos_limits: np.ndarray = None, + cache_mode: str = "disk", + visualize: bool = True, + use_cached: bool = True, + **kwargs, +) -> Tuple[ + Optional[list[np.ndarray]], # First return: list of sampled 4x4 poses + Optional[ + dexsim.models.PointCloud + ], # Second return: point cloud handle if visualization is enabled +]: + """Compute the robot's reachable workspace by Cartesian space sampling. + + Samples points in Cartesian space and checks reachability using inverse kinematics. + If `visualize` is True, visualizes reachable positions as a colored point cloud; + Otherwise, only performs the sampling result as open3d PointCloud. + + + Args: + name (str): Identifier of the robot drive controller to analyze + ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the + orientation for IK solutions + xpos_resolution (float, optional): Cartesian space sampling resolution in meters. + Smaller values provide finer sampling but increase + computation time. Defaults to 0.2 meters. + qpos_resolution (float, optional): Angular resolution for initial joint space + sampling in radians. Used to determine workspace + bounds. Defaults to 60 degrees. + pos_eps (float, optional): Position tolerance for IK solutions in meters. + Defaults to 2e-4 meters. + rot_eps (float, optional): Rotation tolerance for IK solutions in radians. + Defaults to 2e-4 radians. + max_iterations (int, optional): Maximum number of IK iterations per sample. + Defaults to 2000. + num_samples (int, optional): Number of samples to generate in Cartesian space. + Defaults to 10. + qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). + If None, uses limits from drive controller or + articulation. Defaults to None + cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". + Defaults to "memory". + visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. + Defaults to True. + use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. + Ignored for "memory" mode. Defaults to True. + + Returns: + Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: + The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. + The second element is a point cloud handle if visualization is enabled and successful, otherwise None. + """ + from embodichain.lab.sim import REACHABLE_XPOS_DIR + from dexsim.utility.env_utils import create_point_cloud_from_o3d_pcd + from dexsim.utility import inv_transform + + if name not in robot.control_parts: + logger.log_warning(f"Drive controller '{name}' not found") + return None, None + + # try: + # Get robot configuration + # base_xpos = robot.get_control_part_base_pose(name=name, to_matrix=True).squeeze(0).cpu().numpy() + # ref_xpos_robot = inv_transform(base_xpos) @ ref_xpos + ref_xpos_robot = ref_xpos + + if qpos_limits is None: + joint_ranges = ( + robot.body_data.qpos_limits[0].cpu().numpy()[robot.get_joint_ids(name=name)] + ) + else: + joint_ranges = qpos_limits + + urdf_path = robot.cfg.fpath + robot_name = os.path.splitext(os.path.basename(urdf_path))[0] + + qpos_resolution_str = f"{qpos_resolution:.2f}".replace(".", "_") + xpos_resolution_str = f"{xpos_resolution:.2f}".replace(".", "_") + # Join into one directory name + save_dir = ( + REACHABLE_XPOS_DIR + / f"{robot_name}_{name}_{qpos_resolution_str}_{xpos_resolution_str}" + ) + + # Initialize workspace analyzer + analyzer = WorkspaceAnalyzer( + robot=robot, + name=name, + resolution=qpos_resolution, + joint_ranges=joint_ranges, + ) + # Sample workspace points + sampled_xpos = analyzer.sample_xpos_workspace( + ref_xpos=ref_xpos_robot, + xpos_resolution=xpos_resolution, + qpos_resolution=qpos_resolution, + cache_mode=cache_mode, + batch_size=batch_size, + save_dir=save_dir, + save_threshold=save_threshold, + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + use_cached=use_cached, + ) + + if visualize: + if sampled_xpos is None: + logger.log_warning("No reachable positions found.") + return None, None + all_positions = [xpos[:3, 3] for xpos in sampled_xpos] + pcd = analyzer._process_point_cloud( + positions=all_positions, is_voxel_down=False + ) + # Transfer to World Coordinate + # pcd.transform(base_xpos) + # Create and configure point cloud visualization + from embodichain.lab.sim.utility.sim_utils import get_dexsim_arenas + + pcd_handle = create_point_cloud_from_o3d_pcd( + pcd=pcd, env=get_dexsim_arenas()[0] + ) + else: + return sampled_xpos, None + + return sampled_xpos, pcd_handle From 8411a357a737f958dd51e40579acd5a41bac04a4 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 15:37:53 +0800 Subject: [PATCH 19/49] Migrate necessory files in toolkits --- embodichain/toolkits/code_generation.py | 80 +++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 embodichain/toolkits/code_generation.py diff --git a/embodichain/toolkits/code_generation.py b/embodichain/toolkits/code_generation.py new file mode 100644 index 0000000..c222b8a --- /dev/null +++ b/embodichain/toolkits/code_generation.py @@ -0,0 +1,80 @@ +from typing import List, Dict, Tuple, Any +from langchain_core.output_parsers import BaseOutputParser +from pygments import highlight +from pygments.lexers import PythonLexer +from pygments.formatters import TerminalFormatter + +import numpy as np + + +def merge_dicts(dicts: Dict): + return {k: v for d in dicts for k, v in d.items()} + + +def get_executable_code_str(input_string, language="python"): + start_marker = f"```{language}" + end_marker = f"```" + if input_string.find(start_marker) >= 0: + + start_index = input_string.find(start_marker) + len(start_marker) + end_index = input_string.rfind(end_marker) + + code_string = input_string[start_index:end_index].strip() + else: + code_string = input_string + + return code_string + + +class OutputFormatting: + @staticmethod + def flatten_dict(output: Dict[str, Dict]) -> Dict[str, np.ndarray]: + ret = {} + for _, val in output.items(): + ret.update(val) + return ret + + +class ExecutableOutputParser(BaseOutputParser): + # https://python.langchain.com/v0.1/docs/modules/model_io/output_parsers/custom/ + + _fixed_vars = {"np": np} + variable_vars = {} + + def update_vars(self, variable_vars: Dict): + self.variable_vars = variable_vars + + def parse(self, text: str) -> Tuple[str, Dict, Dict]: + code_str = get_executable_code_str(text) + # if self._cfg["include_context"] and context != "": + # to_exec = f"{context}\n{code_str}" + # to_log = f"{context}\n{use_query}\n{code_str}" + # else: + # to_exec = code_str + # to_log = f"{use_query}\n{to_exec}" + + # to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter()) + # print( + # f"\033[34m====================================================\nLMP {self._name} exec:\033[0m\n\n{to_log_pretty}\n\n\033[34m====================================================\n\033[0m" + # ) + + # generate new functions + # new_fs = self._lmp_fgen.create_new_fs_from_code(code_str) + # self._variable_vars.update(new_fs) + + gvars = merge_dicts([self._fixed_vars, self.variable_vars]) + lvars = None + + # + # banned_phrases = ["import", "__"] + # for phrase in banned_phrases: + # assert phrase not in code_str + + if gvars is None: + gvars = {} + if lvars is None: + lvars = {} + empty_fn = lambda *args, **kwargs: None + custom_gvars = merge_dicts([gvars, {"exec": empty_fn, "eval": empty_fn}]) + + return code_str, custom_gvars, lvars From 64642ca3bb3c00a5db4f07947adb6a0a8ac5d4ae Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 16:29:00 +0800 Subject: [PATCH 20/49] Config LLM --- embodichain/agents/hierarchy/llm.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py index a38463d..06f3c6f 100644 --- a/embodichain/agents/hierarchy/llm.py +++ b/embodichain/agents/hierarchy/llm.py @@ -1,5 +1,5 @@ import os -from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_openai import AzureChatOpenAI # ------------------------------------------------------------------------------ # Environment configuration @@ -7,10 +7,11 @@ os.environ["ALL_PROXY"] = "" os.environ["all_proxy"] = "" -os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" -os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" +#os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" +#os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" os.environ["OPENAI_API_VERSION"] = "2024-10-21" os.environ["AZURE_OPENAI_ENDPOINT"] = "YOUR_ENDPOINT_HERE" +os.environ["AZURE_OPENAI_API_KEY"] = "YOUR_API_KEY_HERE" # ------------------------------------------------------------------------------ # LLM factory @@ -18,11 +19,12 @@ def create_llm(*, temperature=0.0, model="gpt-4o"): - return ChatOpenAI( + return AzureChatOpenAI( temperature=temperature, model=model, - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_BASE_URL"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("OPENAI_API_VERSION", "2024-10-21"), ) @@ -31,6 +33,6 @@ def create_llm(*, temperature=0.0, model="gpt-4o"): # ------------------------------------------------------------------------------ task_llm = create_llm(temperature=0.0, model="gpt-4o") -code_llm = create_llm(temperature=0.0, model="gemini-2.5-flash-lite") -validation_llm = create_llm(temperature=0.0, model="gemini-3-flash-preview") -view_selection_llm = create_llm(temperature=0.0, model="gemini-2.5-flash-lite") +code_llm = create_llm(temperature=0.0, model="gpt-4o") +validation_llm = create_llm(temperature=0.0, model="gpt-4o") +view_selection_llm = create_llm(temperature=0.0, model="gpt-4o") From c3898b6a9e21de8dbbc29e745bd91d5960f645b2 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 16:29:11 +0800 Subject: [PATCH 21/49] Extend enum --- embodichain/data/enum.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index cf75819..ad199f0 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -33,6 +33,16 @@ class SemanticMask(IntEnum): FOREGROUND = 1 ROBOT = 2 +class Modality(Enum): + STATES = "states" + STATE_INDICATOR = "state_indicator" + ACTIONS = "actions" + ACTION_INDICATOR = "action_indicator" + IMAGES = "images" + LANG = "lang" + LANG_INDICATOR = "lang_indicator" + GEOMAP = "geomap" # e.g., depth, point cloud, etc. + VISION_LANGUAGE = "vision_language" # e.g., image + lang class EndEffector(Enum): GRIPPER = "gripper" @@ -74,3 +84,17 @@ class EefType(Enum): class ActionMode(Enum): ABSOLUTE = "" RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. + +class PrivilegeType(Enum): + EXTEROCEPTION = "exteroception" + MASK = "mask" + STATE = "state" + PROGRESS = "progress" + +class ArmEnum(IntEnum): + LEFT_ARM_ONLY = 1 + RIGHT_ARM_ONLY = 2 + DUAL_ARM = 3 + +def is_dual_arms(dofs: int) -> bool: + return dofs > 10 \ No newline at end of file From 283954837902f94da4793d6faf69af2aa1907016 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 16:29:34 +0800 Subject: [PATCH 22/49] Add direction of database --- embodichain/data/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/embodichain/data/__init__.py b/embodichain/data/__init__.py index 9e152ab..1e14543 100644 --- a/embodichain/data/__init__.py +++ b/embodichain/data/__init__.py @@ -14,5 +14,12 @@ # limitations under the License. # ---------------------------------------------------------------------------- +import os + + +database_dir = os.path.dirname(os.path.abspath(__file__)).replace("data", "database") +database_2d_dir = os.path.join(database_dir, "2dasset") +database_agent_prompt_dir = os.path.join(database_dir, "agent_prompt") + from . import assets from .dataset import * From 18cfc0327613b3db0d71dbfe8a1262b5ee068dbf Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:05:47 +0800 Subject: [PATCH 23/49] Update enum --- embodichain/data/enum.py | 105 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index ad199f0..920bb3a 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -62,6 +62,14 @@ class ControlParts(Enum): HEAD = "head" WAIST = "waist" +class ControlPartsMappingW1(Enum): + ANKLE_IN_TORSO = 0 + KNEE_IN_TORSO = 1 + BUTTOCK_IN_TORSO = 2 + WAIST_IN_TORSO = 3 + + NECK1_IN_HEAD = 0 + NECK2_IN_HEAD = 1 class Hints(Enum): EEF = ( @@ -91,10 +99,105 @@ class PrivilegeType(Enum): STATE = "state" PROGRESS = "progress" +SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + EefType.POSE.value, + ControlParts.RIGHT_ARM.value + EefType.POSE.value, + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.HEAD.value + JointType.QPOS.value, + ControlParts.WAIST.value + JointType.QPOS.value, +] +SUPPORTED_ACTION_TYPES = SUPPORTED_PROPRIO_TYPES + [ + ControlParts.LEFT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, +] +SUPPORTED_EXTRA_VISION_TYPES = [ + Modality.GEOMAP.value, + PrivilegeType.EXTEROCEPTION.value, + PrivilegeType.MASK.value, +] + class ArmEnum(IntEnum): LEFT_ARM_ONLY = 1 RIGHT_ARM_ONLY = 2 DUAL_ARM = 3 +class ArmName(Enum): + LEFT_ARM_ONLY = "left_arm" + RIGHT_ARM_ONLY = "right_arm" + def is_dual_arms(dofs: int) -> bool: - return dofs > 10 \ No newline at end of file + return dofs > 10 + +class HandQposNormalizer: + """ + A class for normalizing and denormalizing dexterous hand qpos data. + """ + + def __init__(self): + pass + + @staticmethod + def normalize_hand_qpos( + qpos_data: np.ndarray, + key: str, + agent=None, + robot=None, + ) -> np.ndarray: + """ + Clip and normalize dexterous hand qpos data. + + Args: + qpos_data: Raw qpos data + key: Control part key + agent: LearnableRobot instance (for V2 API) + robot: Robot instance (for V3 API) + + Returns: + Normalized qpos data in range [0, 1] + """ + if isinstance(qpos_data, torch.Tensor): + qpos_data = qpos_data.cpu().numpy() + + if agent is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + return qpos_data + indices = agent.get_data_index(key, warning=False) + full_limits = agent.get_joint_limits(agent.uid) + limits = full_limits[indices] # shape: [num_joints, 2] + elif robot is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + if key in [ControlParts.LEFT_EEF.value, ControlParts.RIGHT_EEF.value]: + # Note: In V3, robot does not distinguish between GRIPPER EEF and HAND EEF in uid, + # _data_key_to_control_part maps both to EEF. Under current conditions, normalization + # will not be performed. Please confirm if this is intended. + pass + return qpos_data + indices = robot.get_joint_ids(key, remove_mimic=True) + limits = robot.body_data.qpos_limits[0][indices] # shape: [num_joints, 2] + else: + raise ValueError("Either agent or robot must be provided") + + if isinstance(limits, torch.Tensor): + limits = limits.cpu().numpy() + + qpos_min = limits[:, 0] # Lower limits + qpos_max = limits[:, 1] # Upper limits + + # Step 1: Clip to valid range + qpos_clipped = np.clip(qpos_data, qpos_min, qpos_max) + + # Step 2: Normalize to [0, 1] + qpos_normalized = (qpos_clipped - qpos_min) / (qpos_max - qpos_min + 1e-8) + + return qpos_normalized \ No newline at end of file From 2ede9a078346234ec45937746683d4f664a51954 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:18:22 +0800 Subject: [PATCH 24/49] Migrate end effector --- embodichain/lab/sim/end_effector/__init__.py | 9 + .../lab/sim/end_effector/end_effector.py | 552 ++++++++++++++++++ embodichain/lab/sim/end_effector/utility.py | 148 +++++ 3 files changed, 709 insertions(+) create mode 100644 embodichain/lab/sim/end_effector/__init__.py create mode 100644 embodichain/lab/sim/end_effector/end_effector.py create mode 100644 embodichain/lab/sim/end_effector/utility.py diff --git a/embodichain/lab/sim/end_effector/__init__.py b/embodichain/lab/sim/end_effector/__init__.py new file mode 100644 index 0000000..69ee8c8 --- /dev/null +++ b/embodichain/lab/sim/end_effector/__init__.py @@ -0,0 +1,9 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +from .end_effector import EndEffector +from .utility import * + +del end_effector, utility diff --git a/embodichain/lab/sim/end_effector/end_effector.py b/embodichain/lab/sim/end_effector/end_effector.py new file mode 100644 index 0000000..708e6ce --- /dev/null +++ b/embodichain/lab/sim/end_effector/end_effector.py @@ -0,0 +1,552 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +import typing +import dexsim.engine +import numpy as np +import dexsim.environment +from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag, ActorType + +from embodichain.lab.sim.end_effector.utility import ( + load_model_from_file, + inv_transform, +) +from abc import ABC, abstractmethod +from embodichain.lab.sim.articulation_entity import ArticulationEntity +from embodichain.utils import logger +import dexsim +import time + + +class EndEffector(ArticulationEntity, ABC): + r""" + Abstract class for end effector in simulation. + """ + + def __init__( + self, + env: dexsim.environment.Arena, + file: str, + drive_type: DriveType = DriveType.FORCE, + **kwargs, + ) -> None: + """init end effector + + Args: + env (dexsim.environment.Arena): dexsim environment. + file (str): input file (urdf or mesh file) + drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. Defaults to DriveType.FORCE. + kwargs(optional): Accepts additional keyword arguments. + """ + urdf_path = load_model_from_file(file_path=file) + + super().__init__( + urdf_path=urdf_path, + init_qpos=None, + init_base_xpos=np.eye(4), + speed_ratio=0.5, + time_step=0.02, + drive_type=drive_type, + env=env, + **kwargs, + ) + + self._init_end_effector(**kwargs) + + self.articulation.set_physical_attr(self.default_physical_attrs) + self.articulation.set_drive( + drive_type=self.drive_type, **self.default_drive_param + ) + + @abstractmethod + def _init_end_effector(self, **kwargs) -> None: + r"""Initializes the robot using the URDF path with necessary parameters.""" + pass + + def _set_ee_control_data(self, **kwargs): + self._dof = self.articulation.get_dof() + self._actived_joint_names = self.articulation.get_actived_joint_names() + self._root_link_name = self.articulation.get_root_link_name() + self._attached_nodes = dict() # {node_name: [dexsim.engine.Node, ActorType]} + self._leaf_link_names = self.articulation.get_leaf_link_names() + + if self._dof > 0: + # ignore mimic information for 0-dof articulation + self._joint_ids[self.uid] = np.arange(self._dof) + self._joint_limit = self.articulation.get_joint_limits() + self._set_mimic() + + self.attach_robot_uid = None # if end-effector is attach to robot. + + # KWARGS. If true, set object to be dynamic when release object, otherwise do nothing. + self._is_release_dynamic = kwargs.get("is_release_dynamic", True) + + # open state sample num + self._open_state_sample_num = kwargs.get("open_state_sample_num", 30) + + # open state and close state + self.open_state = np.array( + [ + 1.0, + ] + ) + self.close_state = np.array( + [ + 0.0, + ] + ) + + @property + def actived_joint_names(self) -> typing.List[str]: + return self._actived_joint_names + + def _set_to_init_qpos(self): + self._init_qpos = np.array([]) + if self._dof > 0: + self._init_qpos = self._joint_limit[:, 0] + self.articulation.set_current_qpos( + self._init_qpos, self._joint_ids[self.uid] + ) + + def get_init_qpos(self) -> np.ndarray: + return self._init_qpos + + @property + def release_dynamic(self) -> bool: + """get is release dynamic + + Returns: + bool: If true, set object to be dynamic when release object, otherwise do nothing. + """ + return self._is_release_dynamic + + @release_dynamic.setter + def release_dynamic(self, is_release_dynamic: bool): + """set is release dynamic + + Args: + is_release_dynamic (bool): If true, set object to be dynamic when release object, otherwise do nothing. + """ + self._is_release_dynamic = is_release_dynamic + + def _set_mimic(self) -> None: + r"""Sets up the mimic configuration for the articulation. + + Attributes Updated: + - self._mimic_joint_ids: Array of joint IDs that are mimicked. + - self._mimic_master_ids: Array of master joint IDs that control the mimicked joints. + - self._mimic_multipliers: Array of multipliers for the mimicked joints. + - self._mimic_offsets: Array of offsets for the mimicked joints. + - self._control_joint_ids: Array of joint IDs that are not mimicked and can be controlled. + - self._control_limit: Joint limits for the controllable joints. + - self._control_num: Number of controllable joints. + """ + mimic_info = self.articulation.get_mimic_info() + + self._mimic_joint_ids = mimic_info.mimic_id + self._mimic_master_ids = mimic_info.mimic_parent + self._mimic_multipliers = mimic_info.mimic_multiplier + self._mimic_offsets = mimic_info.mimic_offset + + # Using set for faster membership testing + mimic_joint_set = set(self._mimic_joint_ids) + + # List comprehension for better readability and performance + self._control_joint_ids = np.array( + [i for i in range(self._dof) if i not in mimic_joint_set] + ) + + self._control_limit = self._joint_limit[self._control_joint_ids] + self._control_num = self._control_joint_ids.shape[0] + + def _qpos_to_control_state(self, qpos: np.ndarray) -> np.ndarray: + """full joint state to control joint state + + Args: + qpos (np.ndarray): [dof] of float. Full joint state. + + Returns: + np.ndarray: [control_joint_num] of float. control joint state + """ + return qpos[self._control_joint_ids] + + def _control_state_to_qpos(self, control_state: np.ndarray) -> np.ndarray: + """control joint state to full joint state + + Args: + control_state (np.ndarray): [control_joint_num] of float. control joint state + + Returns: + np.ndarray: [dof] of float. Full joint state. + """ + qpos = np.empty(shape=(self._dof,), dtype=float) + qpos[self._control_joint_ids] = control_state + qpos[self._mimic_joint_ids] = ( + qpos[self._mimic_master_ids] * self._mimic_multipliers + self._mimic_offsets + ) + return qpos + + def _qpos_to_control_state_path(self, qpos_path: np.ndarray): + return qpos_path[:, self._control_joint_ids] + + def _control_state_to_qpos_path(self, control_state_path: np.ndarray): + waypoint_num = control_state_path.shape[0] + qpos_path = np.empty(shape=(waypoint_num, self._dof), dtype=float) + qpos_path[:, self._control_joint_ids] = control_state_path + qpos_path[:, self._mimic_joint_ids] = ( + qpos_path[:, self._mimic_master_ids] * self._mimic_multipliers + + self._mimic_offsets + ) + return qpos_path + + def _to_arena_pose(self, pose: np.ndarray) -> np.ndarray: + return inv_transform(self._env.get_root_node().get_world_pose()) @ pose + + def get_xpos(self) -> np.ndarray: + """get gripper root link pose + + Returns: + np.ndarray: [4, 4] of float. root link 6d pose + """ + return self._to_arena_pose( + self.articulation.get_link_pose(self._root_link_name) + ) + + def set_xpos(self, pose: np.ndarray) -> None: + """directly set gripper world pose + + Args: + pose (np.ndarray): [4, 4] of float. root link 6d pose + """ + # TODO: When gripper attach to robot base, this function result can be wild. + assert pose.shape == (4, 4) + self.set_world_pose(self._to_arena_pose(pose)) + + def set_world_pose(self, pose: np.ndarray) -> None: + """Set the world pose of the end effector.""" + assert pose.shape == (4, 4), "Pose must be a 4x4 transformation matrix." + self.articulation.set_world_pose(pose) + + def get_qpos(self) -> np.ndarray: + """get robot joint state array + + Returns: + np.ndarray: (joint_num, ) of float. joint state array + """ + return np.array(self.articulation.get_current_qpos(self._joint_ids[self.uid])) + + def set_qpos(self, qpos: np.ndarray) -> None: + """set gripper joint state array + + Args: + qpos (np.ndarray): (joint_num, ) of float. joint state array + """ + assert qpos.shape == (self._dof,) + self.articulation.set_current_qpos(qpos, self._joint_ids[self.uid]) + + def get_control_qpos(self) -> np.ndarray: + """get control joint state + + Returns: + np.ndarray: (control_joint_num, ) of float. + """ + return self._qpos_to_control_state(self.get_qpos()) + + def set_control_qpos(self, control_state: np.ndarray) -> None: + """set control joint state + + Args: + control_state (np.ndarray): (control_joint_num, ) of float + """ + assert control_state.shape == self._control_joint_ids.shape + qpos = self._control_state_to_qpos(control_state) + self.articulation.set_current_qpos(qpos, self._joint_ids[self.uid]) + + def move_qpos(self, qpos_path: np.ndarray, is_wait=True, move_time: float = 1): + assert qpos_path.shape[1] == self._dof + self.move_joints( + qpos_path, + is_wait=is_wait, + joint_ids=self._joint_ids[self.uid], + move_time=move_time, + ) + + def get_leaf_link_pose(self) -> dict: + """get leaf link pose. + + Returns: + dict: {"link_name", np.ndarray [4, 4]} pose of each leaf link + """ + leaf_link_poses = dict() + for leaf_link_name in self._leaf_link_names: + leaf_link_pose = self.articulation.get_link_pose(leaf_link_name) + leaf_link_poses[leaf_link_name] = leaf_link_pose + return leaf_link_poses + + def get_leaf_contact(self, is_flatten: bool = False) -> dict: + """Get leaf link contacts. + Leaf link: 1. has physical body; 2. no child link; 3. parent link is not fixed. + + Args: + is_flatten (bool): get flatten + + Returns: + is_flatten == False: + dict: { + "link_name": { + "nodes": [dexsim.engine.Node, ...], + "contact_positions": [link_contact_num, 3] of float. np.ndarray, + "contact_normals": [link_contact_num, 3] of float. np.ndarray, + "contact_distances": [link_contact_num] of float. np.ndarray, + }, + ... + } + + is_flatten == True: + ContactInfo + + ContactInfo.nodes(List[dexsim.engine.Node]): List of Contact object node ptr + ContactInfo.link_name(List[str]): List of contact link name + ContactInfo.contact_positions(np.ndarray): [contact_num, 3] of float, matrix of contact_positions. + ContactInfo.contact_normals(np.ndarray): [contact_num, 3] of float, matrix of contact normal. + ContactInfo.contact_distances(np.ndarray): [contact_num] of float. Contact distance. Negetive for peneration and postive for surface distance. + """ + contact_info = self.articulation.get_leaf_contacts() + if is_flatten: + return contact_info + link_contact_all_id = np.arange(len(contact_info.nodes)) + + contact_info_dict = dict() + # Tricky implementation. save str ing np.ndarray, and select link name by mask + contact_link_names = np.array(contact_info.link_name) + contact_link_name_unique = np.unique(contact_link_names) + # unpack contact info + for link_name in contact_link_name_unique: + contact_info_dict[link_name] = dict() + link_contact_mask = contact_link_names == link_name + link_contact_ids = link_contact_all_id[link_contact_mask] + contact_info_dict[link_name]["nodes"] = [] + for link_contact_idx in link_contact_ids: + contact_info_dict[link_name]["nodes"].append( + contact_info.nodes[link_contact_idx] + ) + contact_info_dict[link_name][ + "contact_positions" + ] = contact_info.contact_positions[link_contact_ids] + contact_info_dict[link_name][ + "contact_normals" + ] = contact_info.contact_normals[link_contact_ids] + contact_info_dict[link_name][ + "contact_distances" + ] = contact_info.contact_distances[link_contact_ids] + return contact_info_dict + + def get_cpp_articulation(self): + return self.articulation + + def attach(self, node: dexsim.engine.Node) -> str: + """attach certain actor to current end-effector + (will attach to root link) + + Args: + node (dexsim.engine.Node): dexsim actor + + Returns: + str: Name of the attached actor, return none str if will attach wrong actor. + """ + node_name = node.get_name() + original_actor_type = node.get_actor_type() + + if original_actor_type == ActorType.STATIC: + logger.log_info( + "Skipping attachment to static object, its name: {}.".format(node_name) + ) + return "" + if original_actor_type == ActorType.DYNAMIC: + # TODO: tricky implemetation. Fix dynamic actor to kinematic + node.set_actor_type(ActorType.KINEMATIC) + # node.enable_collision(False) + + node_pose = node.get_local_pose() + self_pose = self.get_xpos() + relative_pose = inv_transform(self_pose) @ node_pose + + self.articulation.attach_node( + obj=node, link_name=self._root_link_name, relative_pose=relative_pose + ) + + self._attached_nodes[node_name] = [node, original_actor_type] + return node_name + + def detach(self, node_name: str) -> bool: + """detach certain actor to current suctor + + Args: + actor (dexsim.models.Entity): dexsim actor + + Returns: + bool: is_success + """ + if node_name in self._attached_nodes: + node = self._attached_nodes[node_name][0] + original_actor_type = self._attached_nodes[node_name][1] + arena_root_node = self._env.get_root_node() + node.attach_node(arena_root_node) + if original_actor_type != ActorType.STATIC and self._is_release_dynamic: + node.set_actor_type(ActorType.DYNAMIC) + # node.enable_collision(True) + self._attached_nodes.pop(node_name) + return True + else: + logger.log_warning(f"Actor {node_name} to be detach is not attached yet.") + return False + + @abstractmethod + def get_control_state(self, **kwargs) -> np.ndarray: + """get control state of end-effector + + Returns: + np.ndarray: [state_dof] of float. Control state array + """ + + @abstractmethod + def get_open_state(self, **kwargs) -> np.ndarray: + """get control state of end-effector + + Returns: + np.ndarray: [state_dof] of float. Open state array + """ + + @abstractmethod + def set_open_state(self, open_state: np.ndarray, **kwargs): + """set control state of end-effector + + Args: + open_state (np.ndarray): [state_dof] of float. Open state + """ + + def to_target_open_state_path( + self, + target_open_state: np.ndarray, + start_open_state: np.ndarray = None, + step_num: int = None, + step_size: float = None, + **kwargs, + ) -> np.ndarray: + """Generate a path from the start open state to the target open state for a gripper or a robotic hand. + + An "open state" refers to the configuration of the gripper or robotic hand at a given moment, + which can include the positions of fingers, joints, and any gripping mechanisms. + The "target state" is the desired configuration that the gripper or hand should achieve after + the motion, typically used for grasping or releasing an object. + + Args: + target_open_state (np.ndarray): Target open state, shape [state_dof]. + start_open_state (np.ndarray, optional): Starting open state, shape [state_dof]. Default is None, which uses the current open state. + step_num (int, optional): Number of interpolation points. Default is None. + step_size (float, optional): Step size for interpolation. Default is None. + + Returns: + np.ndarray: Path as an array of shape [waypoint_num, state_dof]. + """ + + if start_open_state is None: + start_open_state = self.get_open_state() + + if step_num is not None and step_size is not None: + logger.log_warning( + "Please provide either 'step_num' or 'step_size', not both." + ) + return [] + + if step_num is not None: + step_num = max(step_num, 1) + elif step_size is not None: + distance = np.linalg.norm(target_open_state - start_open_state) + step_num = int(np.ceil(distance / step_size)) + else: + state_range = np.abs(start_open_state - target_open_state).max() + step_num = int(np.round(self._open_state_sample_num * state_range)) + + open_state_path = np.linspace(start_open_state, target_open_state, step_num) + + return open_state_path + + def open(self, **kwargs): + """open end-effector. only for demo""" + if self._world is not None: + if self._world.is_physics_manually_update(): + logger.log_warning("Cannot call open in physics manually update mode.") + return + open_state_path = self.to_target_open_state_path(self.open_state) + for i in range(open_state_path.shape[0]): + self.set_open_state(open_state_path[i]) + time.sleep(0.02) + + def close(self, **kwargs): + """close end-effector. only for demo""" + if self._world is not None: + if self._world.is_physics_manually_update(): + logger.log_warning("Cannot call close in physics manually update mode.") + return + open_state_path = self.to_target_open_state_path(self.close_state) + for i in range(open_state_path.shape[0]): + self.set_open_state(open_state_path[i]) + time.sleep(0.02) + + @property + def default_physical_attrs(self) -> PhysicalAttr: + physical_attr = PhysicalAttr() + if self.drive_type == DriveType.FORCE: + physical_attr.mass = 0.01 # TODO: mass setting is not activated currently + physical_attr.static_friction = 2.0 + physical_attr.dynamic_friction = 1.5 + physical_attr.linear_damping = 0.7 + physical_attr.angular_damping = 0.7 + physical_attr.contact_offset = 0.005 + physical_attr.rest_offset = 0.001 + physical_attr.restitution = 0.05 + physical_attr.has_gravity = True + physical_attr.max_linear_velocity = 4000 + physical_attr.max_angular_velocity = 25 + physical_attr.max_depenetration_velocity = 1e1 + else: # DriveType.FORCE and so on + physical_attr.mass = 0.01 # TODO: mass setting is not activated currently + physical_attr.static_friction = 2.0 + physical_attr.dynamic_friction = 1.5 + physical_attr.linear_damping = 0.7 + physical_attr.angular_damping = 0.7 + physical_attr.contact_offset = 0.005 + physical_attr.rest_offset = 0.001 + physical_attr.restitution = 0.05 + physical_attr.has_gravity = False + physical_attr.max_linear_velocity = 1e6 + physical_attr.max_angular_velocity = 1e6 + physical_attr.max_depenetration_velocity = 1e1 + return physical_attr + + @property + def default_drive_param(self) -> typing.Dict: + # Stiffness: + # Recommended range: 2000 N/m to 10000 N/m + # Note: Higher stiffness is suitable for tasks that require precise position control, + # such as gripping and assembly. You can start with 5000 N/m and fine-tune based on feedback from the actual application. + # Damping: + # Recommended range: 200 Ns/m to 1000 Ns/m + # Note: Damping values ​​should be high enough to dampen oscillations, + # but not too high to excessively hinder motion. You can start with 500 Ns/m and adjust based on dynamic performance. + # Max force: + # Recommended range: 10000 N to 100000 N + # Note: The maximum force should be set according to the load capacity of the robot arm + # to ensure that it does not exceed its load capacity when working. You can start with 50000 N, depending on the specific task load. + if self.drive_type == DriveType.FORCE: + if hasattr(self, "max_force"): + max_force = self.max_force + else: + max_force = 1e3 + param = {"stiffness": 1e2, "damping": 1e1, "max_force": max_force} + elif self.drive_type == DriveType.FORCE: + param = {"stiffness": 1e8, "damping": 1e6, "max_force": 1e10} + return param diff --git a/embodichain/lab/sim/end_effector/utility.py b/embodichain/lab/sim/end_effector/utility.py new file mode 100644 index 0000000..d1a61e6 --- /dev/null +++ b/embodichain/lab/sim/end_effector/utility.py @@ -0,0 +1,148 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import typing +import pathlib +import hashlib +import numpy as np +import open3d as o3d +from dexsim.kit.meshproc import convex_decomposition_coacd +from dexsim.kit.meshproc.utility import mesh_list_to_file +from embodichain.utils import logger + + +def load_model_from_file(**kwargs) -> typing.Optional[str]: + """Loads a model from the specified file path. + + This function checks the provided file path to determine if it is a URDF file + or a mesh file (STL, OBJ, PLY). If it is a URDF file, it is loaded directly. + If it is a mesh file, a URDF file is generated from the mesh. + + Args: + file_path (str): The path to the input file (URDF or mesh file). + + Returns: + Optional[str]: The path to the loaded URDF file, or None if the file path is not provided or unsupported. + """ + file_path = kwargs.get("file_path", None) + + if file_path is None: + logger.log_warning("No file path provided for the model.") + return None + + file_suffix = pathlib.Path(file_path).suffix + mesh_suffix_list = [".stl", ".obj", ".ply"] + + if file_suffix == ".urdf": + # Load the URDF file directly + urdf_path = file_path + elif file_suffix in mesh_suffix_list: + # Generate URDF from the mesh file + urdf_path = generate_gripper_urdf_from_meshpath(file_path) + else: + logger.log_warning( + f"Unsupported file extension {file_suffix} for model file {file_path}." + ) + return None # Return None for unsupported file types + + return urdf_path + + +def generate_gripper_urdf_from_meshpath( + mesh_file: str, cache_dir: str = None, max_convex_hull_num: int = 16 +) -> str: + r"""Generate URDF for a gripper given a mesh file path. + + Args: + mesh_file (str): The path of mesh file. + cache_dir (str, optional): Cache directory. Defaults to None. + max_convex_hull_num (int, optional): The maximum convex hull number. Defaults to 16. + + Returns: + str: Urdf file path. + """ + mesh_md5_key = hashlib.md5(open(mesh_file, "rb").read()).hexdigest() + + # Set cache directory + save_dir = ( + pathlib.Path(cache_dir) + if cache_dir + else pathlib.Path.home() / "urdf_generate_cache" + ) + # Create the directory if it doesn't exist + save_dir.mkdir(parents=True, exist_ok=True) + + # Define cache file names + acd_file = f"{mesh_md5_key}_acd_{max_convex_hull_num}.obj" + visual_file = f"{mesh_md5_key}_visual.obj" + acd_cache_path = save_dir / acd_file + visual_cache_path = save_dir / visual_file + + # Generate convex decomposition cache if not exists + if not acd_cache_path.is_file() or not visual_cache_path.is_file(): + try: + in_mesh = o3d.t.io.read_triangle_mesh(mesh_file) + _, out_mesh_list = convex_decomposition_coacd( + in_mesh, max_convex_hull_num=max_convex_hull_num + ) + + # Write approximate convex decomposition result + mesh_list_to_file(str(acd_cache_path), out_mesh_list) + # Write visual mesh + o3d.t.io.write_triangle_mesh(str(visual_cache_path), in_mesh) + except Exception as e: + raise RuntimeError(f"Error during mesh processing: {e}") + + # Create URDF string + urdf_str = f""" + + + + + + + + + + + + + + + +""" + + urdf_cache_path = save_dir / f"{mesh_md5_key}.urdf" + + try: + with open(urdf_cache_path, "w") as writer: + writer.write(urdf_str) + except IOError as e: + raise RuntimeError(f"Failed to write URDF file: {e}") + + return str(urdf_cache_path) + + +def inv_transform(transform: np.ndarray) -> np.ndarray: + r"""Compute the inverse transformation. + + Args: + transform (np.ndarray): A [4 x 4] transformation matrix. + + Returns: + np.ndarray: The inverse transformation matrix. + """ + r = transform[:3, :3] + t = transform[:3, 3].T + inv_r = r.T + inv_t = -inv_r @ t + + inv_pose = np.eye(4, dtype=np.float32) + inv_pose[:3, :3] = inv_r + inv_pose[:3, 3] = inv_t + + return inv_pose From 983fd9f3029bf56e1e3fec55a735c01a235757d9 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:20:00 +0800 Subject: [PATCH 25/49] Migrate robot --- embodichain/lab/sim/robots/__init__.py | 1 + embodichain/lab/sim/robots/robot.py | 1177 ++++++++++++++++++++++++ 2 files changed, 1178 insertions(+) create mode 100644 embodichain/lab/sim/robots/robot.py diff --git a/embodichain/lab/sim/robots/__init__.py b/embodichain/lab/sim/robots/__init__.py index de4c08a..95db284 100644 --- a/embodichain/lab/sim/robots/__init__.py +++ b/embodichain/lab/sim/robots/__init__.py @@ -15,4 +15,5 @@ # ---------------------------------------------------------------------------- from .dexforce_w1 import * +from .robot import Robot from .cobotmagic import CobotMagicCfg diff --git a/embodichain/lab/sim/robots/robot.py b/embodichain/lab/sim/robots/robot.py new file mode 100644 index 0000000..5d0bc7e --- /dev/null +++ b/embodichain/lab/sim/robots/robot.py @@ -0,0 +1,1177 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +from typing import List, Tuple, Union, Dict, Any, Optional +from abc import ABC, abstractmethod +from copy import deepcopy +import open3d as o3d +import os +from pathlib import Path +import pytorch_kinematics as pk +from matplotlib import colormaps + +import dexsim +from dexsim.models import Entity, MeshObject + +# from dexsim.engine import Articulation +from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag, PrimitiveType + + +from embodichain.utils import logger +from dexsim.utility.env_utils import create_point_cloud_from_o3d_pcd + +# Try to import DriveController, but make it optional +try: + from rlia.kit.drive_controllers import DriveController +except ImportError: + # If rlia is not available, create a dummy type for type checking + DriveController = None + +from embodichain.lab.sim.end_effector import EndEffector + +# from dexsim.utility import inv_transform +from dexsim.sensor import Sensor, MonocularCam, BinocularCam +from embodichain.lab.sim.articulation_entity import ArticulationEntity + +__all__ = ["Robot"] + + +class Robot(ArticulationEntity, ABC): + r""" + Abstract class for robot in simulation. + """ + + def __init__( + self, + urdf_path: Union[str, List[str]] = dict(), + init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = dict(), + init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None, + speed_ratio: float = 0.5, + time_step: float = 0.02, + drive_type: DriveType = DriveType.FORCE, + env: dexsim.environment.Arena = None, + **kwargs, + ): + r"""Initialize the robot. + + Args: + urdf_path (str): urdf file path of robot + init_qpos (np.ndarray, optional): [dof] of double. Init robot joint state(home joint state). + init_base_xpos (np.ndarray, optional): [4, 4] of double. Robot base pose in arena coordinate system. + speed_ratio (float, optional): 0 ~ 1. Robot speed ratio. + time_step (float, optional): wait time between two update. Defaults to 1/50. + drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. + env (Arena, optional): dexsim.environment.Arena. Load the first world(None defaults). + kwargs(optional): Accepts additional keyword arguments. + """ + # unique name of the robot. + self.uid = kwargs.get("uid", "Robot") + + super().__init__( + urdf_path=urdf_path, + init_qpos=init_qpos, + init_base_xpos=init_base_xpos, + speed_ratio=speed_ratio, + time_step=time_step, + drive_type=drive_type, + env=env, + **kwargs, + ) + + # Initialize the robot + self._init_robot(**kwargs) + + # Disable self-collision avoidance for the articulation + self.set_enable_self_collision_flag(self.uid, False) + + # Additional parameters + self.attach_end_effectors = {} + + # Build pk_serial_chain + self.pk_serial_chain = self.build_pk_serial_chain() + + def set_enable_self_collision_flag(self, name: str = None, is_enable: bool = False): + r"""Set the self-collision flag for the specified articulation + or all articulations. + + Args: + name (str, optional): Name of the articulation. + If None, apply to all articulations. Defaults to None. + is_enable (bool, optional): Flag to enable + or disable self-collision. Defaults to False. + """ + if name is None or name == self.uid: + self.articulation.set_articulation_flag( + ArticulationFlag.DISABLE_SELF_COLLISION, not is_enable + ) + else: + if name in self.child_articulations: + self.child_articulations[name].set_articulation_flag( + ArticulationFlag.DISABLE_SELF_COLLISION, not is_enable + ) + else: + logger.log_warning(f"Articulation '{name}' not found.") + + @abstractmethod + def _init_robot(self, **kwargs) -> None: + r"""Initializes the robot using the URDF path with necessary parameters.""" + pass + + def get_end_effector(self, uid: str = None): + r"""Get the end effector by its unique identifier. + + Args: + uid (str): Unique identifier for the end effector to be attached. + If None, returns a dictionary of all end effectors. + + Returns: + EndEffector: The end effector associated with the given uid, or None if not found. + """ + if uid is None: + return self.attach_end_effectors + + end_effector = self.attach_end_effectors.get(uid) + return end_effector + + def attach_end_effector( + self, + uid: str, + end_effector: EndEffector, + robot_uid: str = None, + attach_xpos: np.ndarray = np.eye(4), + ee_link_name: str = "ee_link", + **kwargs, + ): + r"""Attach an end effector to the robotic system. + + Args: + uid (str): Unique identifier for the end effector to be attached. + end_effector (EndEffector): An instance of the EndEffector class representing the end effector to be attached. + robot_uid (str, optional): Unique identifier for the robot to which the end effector is to be attached. Defaults to None. + attach_xpos (np.ndarray, optional): 4x4 transformation matrix (homogeneous transformation matrix) representing the pose + at which the end effector should be attached. Defaults to identity matrix. + ee_link_name (str, optional): The link string that represents the end effector link in the robot. Defaults to "ee_link". + **kwargs: Additional keyword arguments for extended functionality (if applicable). + Returns: + tuple: A tuple containing a boolean and a value: + - (bool) False if the end effector is already attached, True otherwise. + - (None) Always returns None as the second element. + """ + # If robot_uid is not provided, use the current object's uid + robot_uid = robot_uid or self.uid + + # Check if the end effector is already attached to the robot + if robot_uid == self.uid or robot_uid in self.child_articulations: + target_articulation = ( + self.articulation + if robot_uid == self.uid + else self.child_articulations[robot_uid] + ) + + # Get degrees of freedom for the target articulation and the end effector + arm_dof = target_articulation.get_dof() + ef_dof = end_effector.get_dof() + + # Get the root link name of the end effector + ef_root_link_name = end_effector.articulation.get_root_link_name() + ef_link_names = end_effector.articulation.get_link_names() + end_effector.drive_type = self.drive_type + + end_effector_joint_names = ( + end_effector.articulation.get_actived_joint_names() + ) + + # Load the end effector's URDF into the target articulation at the specified position + target_articulation.load_urdf( + end_effector.get_urdf_path(), ee_link_name, attach_xpos + ) + + # Remove the previous articulation of the end effector + ef_articulation = end_effector.get_articulation(end_effector.uid) + self._env.remove_articulation(ef_articulation) + + # Assign the target articulation to the end effector + end_effector.articulation = target_articulation + + target_articulation_joint_names = ( + target_articulation.get_actived_joint_names() + ) + + # Update joint indices for the end effector + ef_joint_ids = arm_dof + np.arange(ef_dof) + end_effector.set_joint_ids(ef_joint_ids) + + # Combine initial positions + ef_init_qpos = end_effector._init_qpos + joint_name_to_idx = { + name: idx for idx, name in enumerate(target_articulation_joint_names) + } + ef_ids = np.array( + [joint_name_to_idx[name] for name in end_effector_joint_names] + ) + + robot_ids = np.arange(arm_dof) + target_articulation.set_current_qpos( + self.get_init_qpos(robot_uid), joint_ids=robot_ids + ) + target_articulation.set_current_qpos(ef_init_qpos, joint_ids=ef_ids) + + # Set physical attributes for the target articulation + target_articulation.set_physical_attr(self.default_physical_attrs) + target_articulation.set_drive( + drive_type=self.drive_type, **self.default_drive_param + ) + + # Store end effector details in the class attributes + self.child_articulations[uid] = end_effector.articulation + self._dof[uid] = ef_dof + self._joint_ids[uid] = ef_ids + self.init_qpos[uid] = ef_init_qpos + self.root_link_names[uid] = ef_root_link_name + end_effector.attach_robot_uid = robot_uid + + end_effector._joint_ids[end_effector.uid] = ef_ids + + # TODO: update robot, etc. + # Update the joint ids for other end effector + for ee_name, ee in self.attach_end_effectors.items(): + ee_idx_list = np.array( + [joint_name_to_idx[name] for name in ee.actived_joint_names] + ) + + self._joint_ids[ee_name] = ee_idx_list + ee._joint_ids[ee.uid] = ee_idx_list + + # ee_init_qpos = self.init_qpos[ee_name] + # Update the initial positions in the class attributes + # self.init_qpos[ee.uid] = ee_init_qpos[ee_idx_list] + + # Keep a reference of the attached end effector + self.attach_end_effectors[uid] = end_effector + + # set end-effector physical param and drive param + for link_name in ef_link_names: + target_articulation.set_physical_attr( + attrib=end_effector.default_physical_attrs, + link_name=link_name, + is_replace_inertial=True, + ) + target_articulation.set_drive( + drive_type=self.drive_type, + joint_ids=ef_joint_ids, + **end_effector.default_drive_param, + ) + # end_effector.set_drive(end_effector.drive_type) + return True, end_effector + else: + logger.log_warning(f"Articulation '{uid}' not found.") + return False, None + + def attach_sensor( + self, + sensor: Sensor, + robot_uid: str = None, + attach_xpos: np.ndarray = np.eye(4), + link_name: str = "ee_link", + ): + r"""Attach a sensor to a robot. + + Note: + Currently, this function is only available for Monocular and Binocular sensors. + + Args: + sensor (Sensor): The sensor object to be attached. It can be a MonocularCam or BinocularCam. + robot_uid (str, optional): Unique identifier for the robot to which the sensor will be attached. Defaults to None, which refers to the current robot. + attach_xpos (np.ndarray, optional): 4x4 transformation matrix (homogeneous transformation matrix) representing the pose + at which the sensor should be attached. Defaults to the identity matrix. + link_name (str, optional): The link string that represents the attachment point on the robot. Defaults to "ee_link". + + Returns: + None: This function does not return a value but logs warnings for unsupported sensor types or invalid robot identifiers. + """ + robot_uid = robot_uid or self.uid + + # Check if the robot_uid matches the current robot or a child articulation + if robot_uid == self.uid or robot_uid in self.child_articulations: + target_articulation = ( + self.articulation + if robot_uid == self.uid + else self.child_articulations[robot_uid] + ) + + # Attach the sensor based on its type + if isinstance(sensor, MonocularCam): + target_articulation.attach_node( + obj=sensor.get_node(), + link_name=link_name, + relative_pose=attach_xpos, + ) + elif isinstance(sensor, BinocularCam): + # Attach the left camera node + if sensor._coordinate_system == "center": + relative_pose = sensor._relativate_T_l + else: + relative_pose = sensor.get_relative_transform() + relative_pose[:3, 3] = relative_pose[:3, 3] * -0.5 + target_articulation.attach_node( + obj=sensor.get_node(is_left=True), + link_name=link_name, + relative_pose=attach_xpos @ relative_pose, + ) + # Attach the right camera node + target_articulation.attach_node( + obj=sensor.get_node(is_left=False), + link_name=link_name, + relative_pose=attach_xpos @ np.linalg.inv(relative_pose), + ) + else: + logger.log_warning("Unsupported sensor type: %s", type(sensor).__name__) + else: + logger.log_warning(f"Articulation '{robot_uid}' not found.") + + # @deprecated(reason="Currently unable to detach this component.") + def detach_end_effector( + self, + uid: str, + robot_uid: str = None, + ): + r"""Detach an end effector from the robotic system. + + Args: + uid (str): Unique identifier for the end effector to be detached. + robot_uid (str, optional): Unique identifier for the robot from which the end effector is to be detached. + + Returns: + bool: True if the end effector was successfully detached, False otherwise. + """ + if uid not in self.child_articulations: + logger.log_warning(f"End effector {uid} already detached.") + return False + + robot_uid = robot_uid or self.uid + if robot_uid is not self.uid: + logger.log_warning(f"Articulation with UID '{robot_uid}' not found.") + return False + + if uid in self.init_qpos: + del self.init_qpos[uid] + if uid in self.init_base_xpos: + del self.init_base_xpos[uid] + self.child_articulations[uid].detach_parent() + self.child_articulations.pop(uid) + return True + + def close(self, uid: str = None, target: float = 1.0) -> bool: + r"""Closes the attached end effector, if this manipulator has one. If no UID is provided, + it will close all end effectors associated with the manipulator. + + Args: + uid (str, optional): + A unique identifier for the specific end effector to be closed. + If None, the method will attempt to close all end effectors. + Defaults to None. + target (float, optional): + The target position for the close operation, typically representing + the closure position of the end effector. + Defaults to 1.0 (fully closed). + + Returns: + bool: + Returns True if the end effector(s) were closed successfully, + and False otherwise. If no end effector is found with the given UID, + a warning is logged. + """ + is_success = False + if uid is None or uid == self.uid: + for key, value in self.attach_end_effectors.items(): + if isinstance(value, EndEffector): + value.close(target=target) + is_success = True # Mark success if any end effector is closed + else: + if uid in self.attach_end_effectors: + self.attach_end_effectors[uid].close(target=target) + is_success = True + else: + logger.log_warning(f"End effector with UID '{uid}' not found.") + + return is_success + + def open(self, uid: str = None, target: float = 0.0) -> bool: + r""" + Opens the attached end effector, if this manipulator has one. If no UID is provided, + it will open all end effectors associated with the manipulator. + + Args: + uid (str, optional): + A unique identifier for the specific end effector to be opened. + If None, the method will attempt to open all end effectors. + Defaults to None. + target (float, optional): + The target position for the open operation, typically representing + the opening position of the end effector. + Defaults to 0.0 (fully opened). + + Returns: + bool: + Returns True if the end effector(s) were opened successfully, + and False otherwise. If no end effector is found with the given UID, + a warning is logged. + """ + is_success = False + if uid is None or uid == self.uid: + for key, value in self.attach_end_effectors.items(): + if isinstance(value, EndEffector): + value.open(target=target) + is_success = True # Mark success if any end effector is opened + else: + if uid in self.attach_end_effectors: + self.attach_end_effectors[uid].open(target=target) + is_success = True + else: + logger.log_warning(f"End effector with UID '{uid}' not found.") + + return is_success + + def set_controller(self, controller=None, uid: str = None, **kwargs): + r"""Set a drive or task controller to the robot. + + Args: + controller (DriveController, optional): + The controller instance to be added to the robot. Can be either: + - DriveController: For low-level joint control + uid (str, optional): + Unique identifier for the articulation to be controlled. + If None, uses the robot's main articulation ID. + + Returns: + bool: True if controller was successfully set, False otherwise. + """ + uid = uid or self.uid + + # Check if the robot_uid matches the current robot or a child articulation + if uid == self.uid or uid in self.child_articulations: + target_articulation = ( + self.articulation if uid == self.uid else self.child_articulations[uid] + ) + + if DriveController is not None and isinstance(controller, DriveController) and any( + isinstance(controller, ctl_type) + for ctl_type in self.supported_drive_controller_types.values() + ): + if hasattr(controller, "set_init_qpos"): + controller.set_init_qpos(self.init_qpos[uid]) + controller.set_articulation(target_articulation) + controller.set_control_q_ids(self._joint_ids[uid]) + self.drive_controllers[uid] = controller + else: + logger.log_warning(f"Controller type '{type(controller)}' not support.") + return False + else: + logger.log_warning(f"Articulation '{uid}' not found.") + return False + + return True + + def set_speed_ratio(self, speed_ratio: float, uid: str = None): + r"""Set speed ratio of the robot. + + Args: + speed_ratio (float): 0.0~1.0. robot speed ratio. + uid (str): Uid of the articulation. + """ + uid = uid or self.uid + + if uid == self.uid or uid in self.child_articulations: + self.speed_ratio = speed_ratio + return True + else: + logger.log_warning( + f"Drive controller with UID '{uid}' not found. Please add the drive controller before set speed ratio." + ) + return False + + def get_speed_ratio(self, uid: str = None): + r"""Get speed ratio of the robot. + + Args: + uid (str): Uid of the articulation. + """ + uid = uid or self.uid + + if uid == self.uid or uid in self.child_articulations: + return self.speed_ratio + else: + logger.log_warning( + f"Drive controller with UID '{uid}' not found. Please add the drive controller before set speed ratio." + ) + return None + + @abstractmethod + def get_fk(self, qpos: np.ndarray, uid: str = None) -> np.ndarray: + r"""Get forward kinematic of given joints + + Args: + qpos (np.ndarray): [dof] of float. + uid (str, optional): uid of the articulation. Defaults to None. + + Returns: + np.ndarray: Pose of the end-effector. + """ + pass + + @abstractmethod + def get_ik(self, xpos: np.ndarray, uid: str = None, **kwargs) -> np.ndarray: + r"""Get inverse kinematic of given end-effector pose. + + Args: + xpos (np.ndarray): [4, 4] of matrix. + uid (str, optional): uid of the articulation. Defaults to None. + **kwargs: Other parameters. which can be used to specify the IK method. + + Returns: + np.ndarray: [dof] of float. + """ + pass + + @abstractmethod + def move( + self, + path: Union[np.ndarray, List[np.ndarray]], + is_joint: bool = False, + is_wait: bool = True, + **kwargs, + ) -> bool: + r"""Move the robot to the given path. + + Args: + path (np.ndarray): [4, 4] | [waypoint_num, 4, 4] | [dof] of float or + [waypoint_num, dof] of float. Path in cartesian space or joint space. + is_joint (bool, optional): Whether the path is in joint space. Defaults to False. + is_wait (bool, optional): Whether to synchronize the robot movement. Defaults to True. + **kwargs: Other parameters. + + Returns: + bool: is_move_success + """ + pass + + def get_dof(self, name: str = None) -> Union[int, Dict[str, int]]: + r"""Get degree of freedom (DoF) of the robot. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + Union[int, Dict[str, int]]: + - If `name` is None, returns the total DoF of the robot as an integer. + - If `name` is provided and found, returns the DoF of the specified articulation as an integer. + - If `name` is provided but not found, logs a warning and returns 0. + """ + # TODO: Need to clarify behavior. + if name is None: + if isinstance(self._dof, dict): + return sum(self._dof.values()) + else: + return ( + self._dof + ) # Assuming _dof is an integer representing the total DoF + elif name in self._dof: + return self._dof[ + name + ] # Assuming _dof[name] is an integer representing the DoF of the specified articulation + + logger.log_warning(f"Articulation '{name}' not found.") + return 0 + + def get_proprioception(self, remove_index: bool = True) -> Dict[str, Any]: + r"""Gets robot proprioception information, primarily for agent state representation in robot learning scenarios. + + The default proprioception information includes: + - xpos: End-effector pose in the robot base coordinate system. + - qpos: Joint positions. + - qvel: Joint velocities. + - qf (effort): Joint forces. + + Args: + remove_index (bool, optional): + If True, the suffix index of the UID will be removed. + Defaults to True. + + Returns: + Dict[str, Any]: + A dictionary containing the robot's proprioception information, + where keys are the UID or modified UID and values are dictionaries + containing the proprioception data. + """ + obs = {} + + # Helper function to populate proprioception data for a given name + def populate_proprioception(name: str): + return { + "xpos": self.get_current_xpos(name=name, is_world_coordinates=False), + "qpos": self.get_current_qpos(name=name), + "qvel": self.get_current_qvel(name=name), + "qf": self.get_current_qf(name=name), + } + + # Process the main UID + base_name = self.uid.split("_")[0] if remove_index else self.uid + obs[base_name] = populate_proprioception(self.uid) + + # Process child articulations + for child_name in self.child_articulations: + if remove_index: + import re + + modified_name = re.sub(r"(_\d+)$", "", child_name) + else: + modified_name = child_name + + if modified_name in obs: + if isinstance(obs[modified_name], list): + obs[modified_name].append(populate_proprioception(child_name)) + else: + obs[modified_name] = [ + obs[modified_name], + populate_proprioception(child_name), + ] + else: + obs[modified_name] = populate_proprioception(child_name) + + return obs + + def attach_actor( + self, actor: Entity, relative_xpos: np.ndarray, uid: str = None, **kwargs + ) -> Entity: + r"""Attach an actor to the robot. + + Args: + actor (Entity): + The actor to be attached to the robot. + relative_xpos (np.ndarray): + A [4, 4] matrix representing the relative pose of the actor to the robot. + uid (str, optional): + Unique identifier of the articulation. If None, defaults to the robot's UID. + **kwargs: + Additional parameters for future extension. + + Returns: + Entity: + The attached actor, or None if the attachment failed. + """ + uid = uid or self.uid + + # Define a function to attach the actor to the specified articulation + def attach_to_articulation(articulation): + actor_name = actor.get_name() + self.attached_actors[actor_name] = actor + articulation.attach_node(actor.node, "ee_link", relative_xpos) + return actor + + # Check if UID matches the robot's UID + if uid == self.uid: + return attach_to_articulation(self.articulation) + + # Check if UID matches any child articulation + elif uid in self.child_articulations: + return attach_to_articulation(self.child_articulations[uid]) + + # Log a warning if the articulation is not found + logger.log_warning(f"Articulation with UID '{uid}' not found.") + return None + + def remove_actor(self, actor_name: str, delete: bool = False) -> None: + r"""Remove the attached actor from the robot. + + Args: + actor_name (str): Name of the actor to be removed. + delete (bool, optional): Whether to delete the actor from the simulation. Defaults to False. + """ + if actor_name in self.attached_actors: + for key, value in self.child_articulations.items(): + if isinstance(value, EndEffector): + value.detach(actor_name) + self.attached_actors.pop(actor_name) + if delete: + self._env.remove_actor(actor_name) + + def get_attached_actor_names(self) -> List[str]: + r"""Get names of all attached actors. + + Returns: + List[str]: Names of all attached actors. + """ + return list(self.attached_actors.keys()) + + def compute_qpos_reachability( + self, + name: str, + resolution: float = np.radians(50), + qpos_limits: np.ndarray = None, + cache_mode: str = "memory", + visualize: bool = False, + batch_size: int = 100000, + use_cached: bool = True, + **kwargs, + ) -> Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: + """Compute the robot's reachable workspace by joint space sampling. + + Samples points in joint space and optionally visualizes the resulting end-effector positions + as a colored point cloud. If `visualize` is True, points closer to the robot base are colored green, + transitioning to red for points further away. If `visualize` is False, only the sampling is performed + without any visualization. + + + Args: + name (str): Identifier of the robot drive controller to analyze + resolution (float, optional): Angular resolution for joint space sampling in radians. + Lower values provide finer sampling but increase computation time. + Defaults to 50 degrees (≈0.873 radians) + qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). + If None, uses limits from drive controller or articulation. + Defaults to None + cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". + Defaults to "memory". + visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. + Defaults to False. + batch_size (int, optional): Number of samples to process in each batch. + Defaults to 100000. + use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. + Ignored for "memory" mode. Defaults to True. + + + Returns: + Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: + The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. + The second element is a point cloud handle if visualization is enabled and successful, otherwise None. + """ + from embodichain.lab.sim.utility.workspace_analyzer import ( + WorkspaceAnalyzer, + ) + from embodichain.lab.sim import REACHABLE_XPOS_DIR + + if name not in self.drive_controllers: + logger.log_warning(f"Drive controller '{name}' not found") + return None, None + + # try: + # Get robot configuration + base_xpos = self.get_base_xpos(name=name) + drive_controller = self.drive_controllers[name] + + if qpos_limits is None: + if hasattr(drive_controller, "get_joint_limits"): + res, upper_limits, lower_limits = self.drive_controllers[ + name + ].get_joint_limits() + if not res: + logger.log_warning("Failed to get joint limits") + return None, None + joint_ranges = np.column_stack((lower_limits, upper_limits)) + else: + joint_limits = self.articulation.get_joint_limits() + joint_ranges = joint_limits[self._joint_ids[name]] + else: + joint_ranges = qpos_limits + paths = self.get_urdf_path() + urdf_path = paths if isinstance(paths, str) else paths[self.uid] + robot_name = os.path.splitext(os.path.basename(urdf_path))[0] + # Initialize workspace analyzer + analyzer = WorkspaceAnalyzer( + robot=self, name=name, resolution=resolution, joint_ranges=joint_ranges + ) + # Format resolution to avoid issues with decimal points in paths + resolution_str = f"{resolution:.2f}".replace(".", "_") + # Join into one directory name + save_dir = REACHABLE_XPOS_DIR / f"{robot_name}_{name}_{resolution_str}" + # Sample workspace points + sampled_xpos = analyzer.sample_qpos_workspace( + cache_mode=cache_mode, + save_dir=save_dir, + batch_size=batch_size, + use_cached=use_cached, + ) + if visualize == True: + # Create and configure point cloud visualization + # all_positions = [xpos[:3, 3] for xpos in sampled_xpos] + N = len(sampled_xpos) + all_pos = np.empty((N, 3), dtype=np.float16) + for i, mat in enumerate(sampled_xpos): + all_pos[i] = mat[:3, 3].astype(np.float16) + pcd = analyzer._process_point_cloud(positions=all_pos) + # Transfer to World Coordinate + pcd.transform(base_xpos) + pcd_handle = create_point_cloud_from_o3d_pcd(pcd=pcd, env=self._env) + else: + return sampled_xpos, None + + return sampled_xpos, pcd_handle + + # except Exception as e: + # logger.log_warning(f"Failed to visualize qpos workspace: {str(e)}") + # return None, None + + def compute_xpos_reachability( + self, + name: str, + ref_xpos: np.ndarray, + xpos_resolution: float = 0.2, + qpos_resolution: float = np.radians(60), + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1500, + num_samples: int = 5, + batch_size: int = 100000, + save_threshold: int = 10000000, + qpos_limits: np.ndarray = None, + cache_mode: str = "memory", + visualize: bool = True, + use_cached: bool = True, + **kwargs, + ) -> Tuple[ + Optional[list[np.ndarray]], # First return: list of sampled 4x4 poses + Optional[ + dexsim.models.PointCloud + ], # Second return: point cloud handle if visualization is enabled + ]: + """Compute the robot's reachable workspace by Cartesian space sampling. + + Samples points in Cartesian space and checks reachability using inverse kinematics. + If `visualize` is True, visualizes reachable positions as a colored point cloud; + Otherwise, only performs the sampling result as open3d PointCloud. + + + Args: + name (str): Identifier of the robot drive controller to analyze + ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the + orientation for IK solutions + xpos_resolution (float, optional): Cartesian space sampling resolution in meters. + Smaller values provide finer sampling but increase + computation time. Defaults to 0.2 meters. + qpos_resolution (float, optional): Angular resolution for initial joint space + sampling in radians. Used to determine workspace + bounds. Defaults to 60 degrees. + pos_eps (float, optional): Position tolerance for IK solutions in meters. + Defaults to 2e-4 meters. + rot_eps (float, optional): Rotation tolerance for IK solutions in radians. + Defaults to 2e-4 radians. + max_iterations (int, optional): Maximum number of IK iterations per sample. + Defaults to 2000. + num_samples (int, optional): Number of samples to generate in Cartesian space. + Defaults to 10. + qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). + If None, uses limits from drive controller or + articulation. Defaults to None + cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". + Defaults to "memory". + visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. + Defaults to True. + use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. + Ignored for "memory" mode. Defaults to True. + + Returns: + Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: + The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. + The second element is a point cloud handle if visualization is enabled and successful, otherwise None. + """ + from embodichain.lab.sim.utility.workspace_analyzer import ( + WorkspaceAnalyzer, + ) + from embodichain.lab.sim import REACHABLE_XPOS_DIR + + if name not in self.drive_controllers: + logger.log_warning(f"Drive controller '{name}' not found") + return None, None + + # try: + # Get robot configuration + base_xpos = self.get_base_xpos(name=name) + ref_xpos_robot = dexsim.utility.inv_transform(base_xpos) @ ref_xpos + drive_controller = self.drive_controllers[name] + + if qpos_limits is None: + if hasattr(drive_controller, "get_joint_limits"): + res, upper_limits, lower_limits = self.drive_controllers[ + name + ].get_joint_limits() + if not res: + logger.log_warning("Failed to get joint limits") + return None, None + joint_ranges = np.column_stack((lower_limits, upper_limits)) + else: + joint_limits = self.articulation.get_joint_limits() + joint_ranges = joint_limits[self._joint_ids[name]] + else: + joint_ranges = qpos_limits + + paths = self.get_urdf_path() + urdf_path = paths if isinstance(paths, str) else paths[self.uid] + robot_name = os.path.splitext(os.path.basename(urdf_path))[0] + + qpos_resolution_str = f"{qpos_resolution:.2f}".replace(".", "_") + xpos_resolution_str = f"{xpos_resolution:.2f}".replace(".", "_") + # Join into one directory name + save_dir = ( + REACHABLE_XPOS_DIR + / f"{robot_name}_{name}_{qpos_resolution_str}_{xpos_resolution_str}" + ) + + # Initialize workspace analyzer + analyzer = WorkspaceAnalyzer( + robot=self, + name=name, + resolution=qpos_resolution, + joint_ranges=joint_ranges, + ) + # Sample workspace points + sampled_xpos = analyzer.sample_xpos_workspace( + ref_xpos=ref_xpos_robot, + xpos_resolution=xpos_resolution, + qpos_resolution=qpos_resolution, + cache_mode=cache_mode, + batch_size=batch_size, + save_dir=save_dir, + save_threshold=save_threshold, + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + use_cached=use_cached, + ) + + if visualize == visualize: + if sampled_xpos is None: + logger.log_warning("No reachable positions found.") + return None, None + all_positions = [xpos[:3, 3] for xpos in sampled_xpos] + pcd = analyzer._process_point_cloud( + positions=all_positions, is_voxel_down=False + ) + # Transfer to World Coordinate + pcd.transform(base_xpos) + # Create and configure point cloud visualization + pcd_handle = create_point_cloud_from_o3d_pcd(pcd=pcd, env=self._env) + else: + return sampled_xpos, None + + return sampled_xpos, pcd_handle + + def compute_voxel_reachability( + self, + name: str, + voxel_size: float = 0.04, + num_directions: int = 50, + num_yaws=6, + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1500, + num_samples: int = 5, + qpos_limits: np.ndarray = None, + cache_mode: str = "memory", + visualize: bool = False, + use_cached: bool = True, + **kwargs, + ) -> Tuple[Optional[List[np.ndarray]], Optional[List[MeshObject]]]: + """ + Compute the robot's reachable workspace by voxel-based sampling. + + Samples voxel centers within a sphere around the robot’s end-effector base + and checks reachability via inverse kinematics. + If `visualize` is True, spawns a colored sphere actor at each voxel center + to indicate success rate; otherwise returns only the sampled poses. + + Args: + name (str): Identifier of the drive controller to analyze. + voxel_size (float, optional): Edge length of each cubic voxel (m). + Smaller values give finer resolution but increase computation time. + Defaults to 0.04. + num_directions (int, optional): Number of sample directions per voxel. + Defaults to 50. + num_yaws (int, optional): Number of discrete yaw rotations **around the local Z-axis** + to try for each sample direction when solving IK. A higher value can + increase rotational coverage but incurs more IK calls. Defaults to 6. + qpos_limits (np.ndarray, optional): Custom joint limits array of shape + (n_joints, 2). If None, retrieves limits from the controller or + articulation. Defaults to None. + cache_mode (str, optional): “memory” or “disk” mode for caching IK + results. Defaults to "memory". + visualize (bool, optional): If True, returns a list of DexSim actor + handles for visualization; otherwise returns None for actors. + Defaults to False. + use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. + Ignored for "memory" mode. Defaults to True. + + Returns: + Tuple[Optional[List[np.ndarray]], Optional[List[MeshObject]]]: + - List of sampled end-effector poses (4×4 matrices), or None on failure. + - List of sphere actor handles if visualize=True, else None. + """ + from embodichain.lab.sim.utility.workspace_analyzer import ( + WorkspaceAnalyzer, + ) + from embodichain.lab.sim import REACHABLE_XPOS_DIR + + # 1) Validate drive controller + if name not in self.drive_controllers: + logger.log_warning(f"Drive controller '{name}' not found") + return None, None + + try: + drive_controller = self.drive_controllers[name] + + # 2) Determine joint limits + if qpos_limits is None: + if hasattr(drive_controller, "get_joint_limits"): + res, upper, lower = drive_controller.get_joint_limits() + if not res: + logger.log_warning("Failed to get joint limits") + return None, None + joint_ranges = np.column_stack((lower, upper)) + else: + all_limits = self.articulation.get_joint_limits() + joint_ranges = all_limits[self._joint_ids[name]] + else: + joint_ranges = qpos_limits + + # 3) Prepare save directory + urdf_paths = self.get_urdf_path() + urdf_path = ( + urdf_paths if isinstance(urdf_paths, str) else urdf_paths[self.uid] + ) + robot_name = os.path.splitext(os.path.basename(urdf_path))[0] + + vs_str = f"{voxel_size:.2f}".replace(".", "_") + nd_str = str(num_directions) + save_dir = ( + REACHABLE_XPOS_DIR / f"Voxel_{robot_name}_{name}_{vs_str}_{nd_str}" + ) + + # 4) Set up workspace analyzer + analyzer = WorkspaceAnalyzer( + robot=self, name=name, joint_ranges=joint_ranges + ) + + # 5) Sample voxels and IK + ( + voxel_centers, + voxel_success_counts, + sampled_xpos, + ) = analyzer.sample_voxel_workspace( + voxel_size=voxel_size, + num_directions=num_directions, + num_yaws=num_yaws, + pos_eps=pos_eps, + rot_eps=rot_eps, + max_iterations=max_iterations, + num_samples=num_samples, + cache_mode=cache_mode, + batch_size=5000, + save_dir=save_dir, + save_threshold=10_000_000, + use_cached=use_cached, + ) + + # 6) Visualization (optional) + if visualize: + colormap = colormaps.get_cmap("jet") + actor_handles: List[MeshObject] = [] + + for idx, (center, count) in enumerate( + zip(voxel_centers, voxel_success_counts), start=1 + ): + # map success rate to color + frac = count / num_directions + color = colormap(1.0 - frac)[:3] + + # build and color sphere mesh + sphere = o3d.geometry.TriangleMesh.create_sphere(voxel_size / 2) + sphere.paint_uniform_color(color) + + verts = np.asarray(sphere.vertices) + inds = np.asarray(sphere.triangles) + cols = np.asarray(sphere.vertex_colors) + cols4 = np.ones((cols.shape[0], 4), dtype=float) + cols4[:, :3] = cols + + # create uniquely named actor e.g. "sphere1", "sphere2", … + actor_name = f"sphere{idx}" + actor = self._env.create_actor(actor_name, True, True) + actor.set_mesh( + vertices=verts, + indices=inds, + shape=PrimitiveType.TRIANGLES, + smooth_angle=-1, + colors=cols4, + ) + actor.set_location(*center) + + actor_handles.append(actor) + + return sampled_xpos, actor_handles + + # 7) Return only sampled poses + return sampled_xpos, None + + except Exception as e: + print(f"Failed to visualize voxel workspace: {e}") + return None, None + + def destroy(self) -> None: + r"""Release the resources of the robot.""" + # Safely handle drive_controllers + if hasattr(self, "drive_controllers") and isinstance( + self.drive_controllers, dict + ): + for key in self.drive_controllers.keys(): + self.drive_controllers[key] = None + + # Safely handle task_controllers + if hasattr(self, "task_controllers") and isinstance( + self.task_controllers, dict + ): + for key in self.task_controllers.keys(): + self.task_controllers[key] = None + + # Safely handle articulation + if hasattr(self, "articulation"): + self.articulation = None + + # Safely handle child_articulations + if hasattr(self, "child_articulations") and isinstance( + self.child_articulations, dict + ): + for key in self.child_articulations.keys(): + if self.child_articulations[key] is not None: + if hasattr(self.child_articulations[key], "get_articulation"): + self._env.remove_articulation( + self.child_articulations[key].get_articulation() + ) + else: + self._env.remove_articulation(self.child_articulations[key]) + + self.child_articulations[key] = None + + @staticmethod + def build_pk_serial_chain(**kwargs) -> Dict[str, pk.SerialChain]: + """Build the serial chain from the URDF file. + + Args: + **kwargs: Additional arguments for building the serial chain. + + Returns: + Dict[str, pk.SerialChain]: The serial chain of the robot. + """ + # paths = self.get_urdf_path() + # urdf_path = paths if isinstance(paths, str) else paths[self.uid] + # chain = pk.build_chain_from_urdf(open(urdf_path, mode="rb").read()) + + # articulation = robot.get_articulation(self.uid) + # link_names = articulation.get_link_names() + # serial_chain = pk.SerialChain(chain, link_names[-1], link_names[0]) + + # return {self.uid: serial_chain} + return {} From 974ee5f68375eab60366fc017b0d4de80eaa0fab Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:20:31 +0800 Subject: [PATCH 26/49] Update direction of database --- embodichain/data/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/embodichain/data/__init__.py b/embodichain/data/__init__.py index 1e14543..fccbbc2 100644 --- a/embodichain/data/__init__.py +++ b/embodichain/data/__init__.py @@ -20,6 +20,7 @@ database_dir = os.path.dirname(os.path.abspath(__file__)).replace("data", "database") database_2d_dir = os.path.join(database_dir, "2dasset") database_agent_prompt_dir = os.path.join(database_dir, "agent_prompt") +database_demo_dir = os.path.join(database_dir, "demostration") from . import assets from .dataset import * From 40957f1f780cfb2b5c676455516bbc5e41e44f69 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:21:25 +0800 Subject: [PATCH 27/49] Migrate robot interface --- embodichain/lab/gym/robots/interface.py | 243 ++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 embodichain/lab/gym/robots/interface.py diff --git a/embodichain/lab/gym/robots/interface.py b/embodichain/lab/gym/robots/interface.py new file mode 100644 index 0000000..5f0a289 --- /dev/null +++ b/embodichain/lab/gym/robots/interface.py @@ -0,0 +1,243 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import pytorch_kinematics as pk + +from abc import abstractmethod +from typing import List, Dict, Tuple, Union + +from gymnasium import spaces + +from embodichain.data.enum import ControlParts, EndEffector, JointType +from embodichain.lab.sim.robots import Robot +from embodichain.utils import logger +from embodichain.data.enum import JointType, EefType, ActionMode + + +class LearnableRobot(Robot): + """The interface class for the learnable robot agent. + + There are three types of actions should be explained: + - Real robot actions: The actions that the real robot can execute. + - Control actions: The actions that the robot interacts with the policy. + - Environment actions: The actions that the robot executes in the simulation environment. + """ + + def get_single_action_space(self) -> spaces.Dict: + limits = self.get_joint_limits(self.uid) + low, high = limits[:, 0], limits[:, 1] + single_action_space = spaces.Dict( + { + JointType.QPOS.value: spaces.Box(low=low, high=high, dtype=np.float32), + } + ) + + return single_action_space + + def step_env_action(self, action: Dict): + qpos = action[JointType.QPOS.value] + + self.set_current_qpos(self.uid, qpos) + return action + + def get_debug_xpos_dict( + self, + ) -> Dict[str, np.ndarray]: + """Get the debug xpos list.""" + return {} + + def get_data_index(self, name: str, warning: bool = True) -> List[int]: + """ + Get the data index for the control part. Subclasses must implement the index_map attribute. + + Args: + name (str): The name of the control part. + warning (bool, optional): Whether to log a warning if the control part is not supported. Defaults to True. + + Returns: + List[int]: The list of indices for the control part. Returns an empty list if not found. + + Raises: + NotImplementedError: If the subclass does not define the index_map attribute. + """ + if not hasattr(self, "index_map"): + raise NotImplementedError("Subclasses must define the index_map attribute.") + if name in self.index_map: + return self.index_map[name] + else: + if warning: + logger.log_warning(f"Control part {name} is not supported.") + return [] + + def map_ee_state_to_env_actions( + self, ee_state: np.ndarray, env_actions: np.ndarray + ) -> np.ndarray: + """Map the end-effector state to the environment actions of robot agent. + + Args: + ee_state (np.ndarray): The end-effector state. + env_actions (np.ndarray): The environment actions of the robot agent. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + return env_actions + + def map_real_actions_to_control_actions(self, actions: np.ndarray) -> np.ndarray: + """Map the real robot actions to the control actions of robot agent, which + should has the same dimension. + + The control actions should be the actions that match the articulation joint limits. + + Note: + Real robot may have gap in the action compared to the simulation robot agent. The + method provides a place the process the gap. + + Args: + actions (np.ndarray): The real robot actions collected from the robot. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + return actions + + def map_control_actions_to_env_actions( + self, + actions: np.ndarray, + env_action_dim: int, + action_type: str = JointType.QPOS.value, + ) -> np.ndarray: + """Map the control actions to the environment actions of robot agent. + + Args: + actions (np.ndarray): The control actions. + env_action_dim (int): The dimension of the environment action space. + action_type (str, optional): The type of action. Defaults to JointType.QPOS.value. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + control_index = self.get_data_index(self.uid) + if action_type != EefType.POSE.value and actions.shape[1] != len(control_index): + logger.log_error( + f"The policy action dimension {actions.shape[1]} does not match the control index dimension {len(control_index)}." + ) + + length = len(actions) + env_actions = np.zeros((length, env_action_dim)) + if action_type == JointType.QPOS.value: + env_actions[:, control_index] = actions + elif action_type == EefType.POSE.value: + # TODO: the eef state is also mapped in this function, which should be separated. + env_actions = self.map_eef_pose_to_env_qpos(actions, env_actions) + else: + logger.log_error(f"Invalid action type: {action_type}") + + return env_actions + + def map_env_qpos_to_eef_pose( + self, env_qpos: np.ndarray, to_dict: bool = False, ret_mat: bool = False + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Map environment joint positions to end-effector pose representation. + + Args: + env_qpos (np.ndarray): Joint positions from the environment, shape [batch, total_dof]. + + Returns: + np.ndarray: End-effector pose, shape [batch, 18]. + [left_pos(3), left_x(3), left_y(3), right_pos(3), right_x(3), right_y(3)] + """ + num_pose = env_qpos.shape[0] + left_indices = self.get_joint_ids( + ControlParts.LEFT_ARM.value + JointType.QPOS.value + ) + right_indices = self.get_joint_ids( + ControlParts.RIGHT_ARM.value + JointType.QPOS.value + ) + + left_env_qpos = torch.as_tensor(env_qpos[:, left_indices], dtype=torch.float32) + right_env_qpos = torch.as_tensor( + env_qpos[:, right_indices], dtype=torch.float32 + ) + + left_ret = ( + self.pk_serial_chain[ControlParts.LEFT_ARM.value + JointType.QPOS.value] + .forward_kinematics(left_env_qpos, end_only=True) + .get_matrix() + ) + right_ret = ( + self.pk_serial_chain[ControlParts.RIGHT_ARM.value + JointType.QPOS.value] + .forward_kinematics(right_env_qpos, end_only=True) + .get_matrix() + ) + + eef_pose = np.zeros((num_pose, 18)) + eef_pose[..., :3] = left_ret[..., :3, 3] + eef_pose[..., 3:6] = left_ret[..., :3, 0] + eef_pose[..., 6:9] = left_ret[..., :3, 1] + eef_pose[..., 9:12] = right_ret[..., :3, 3] + eef_pose[..., 12:15] = right_ret[..., :3, 0] + eef_pose[..., 15:18] = right_ret[..., :3, 1] + + if to_dict: + from embodichain.data.enum import EefType + + if not ret_mat: + return { + ControlParts.LEFT_ARM.value + EefType.POSE.value: eef_pose[..., :9], + ControlParts.RIGHT_ARM.value + + EefType.POSE.value: eef_pose[..., 9:], + } + else: + return { + ControlParts.LEFT_ARM.value + EefType.POSE.value: left_ret, + ControlParts.RIGHT_ARM.value + EefType.POSE.value: right_ret, + } + else: + return eef_pose + + def map_eef_pose_to_env_qpos( + self, eef_pose: np.ndarray, env_qpos: np.ndarray + ) -> np.ndarray: + """Map the end-effector pose to the environment actions. + + Args: + eef_pose (np.ndarray): The end-effector pose. + env_qpos (np.ndarray): The env qpos to be mapped. + + Returns: + np.ndarray: The environment actions. + """ + return env_qpos + + def clip_env_qpos(self, env_qpos: np.ndarray) -> np.ndarray: + """Clip the environment qpos based on the robot joint limits. + + Args: + env_qpos (np.ndarray): The environment qpos to be clipped. + + Returns: + np.ndarray: The clipped environment qpos. + """ + limits = self.get_joint_limits(self.uid) + low, high = limits[:, 0], limits[:, 1] + env_qpos = np.clip(env_qpos, low, high) + return env_qpos + + @staticmethod + def build_pk_serial_chain(**kwargs) -> Dict[str, pk.SerialChain]: + """Build the serial chain from the URDF file. + + Args: + **kwargs: Additional arguments for building the serial chain. + + Returns: + Dict[str, pk.SerialChain]: The serial chain of the robot. + """ + return {} From d4a3ae5cdf9272312599ee94c8f15f307b8b26d1 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:22:45 +0800 Subject: [PATCH 28/49] Migrate two get control information function in action_bank utils --- embodichain/lab/gym/envs/action_bank/utils.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/embodichain/lab/gym/envs/action_bank/utils.py b/embodichain/lab/gym/envs/action_bank/utils.py index 255e5b8..8347dd0 100644 --- a/embodichain/lab/gym/envs/action_bank/utils.py +++ b/embodichain/lab/gym/envs/action_bank/utils.py @@ -28,6 +28,39 @@ def get_init_affordance(scope: str, tag: str = "init") -> str: return "{}_{}_qpos".format(scope, tag) +def get_control_part(env, agent_uid): + + from embodichain.lab.gym.utils.misc import _data_key_to_control_part + + control_parts = env.metadata["dataset"]["robot_meta"].get("control_parts", []) + + if agent_uid in control_parts: + return agent_uid + else: + return _data_key_to_control_part( + robot=env.robot, + control_parts=control_parts, + data_key=agent_uid, + ) + +def get_control_part_joint_ids(env, key: str) -> List[int]: + from embodichain.data.enum import ( + ControlParts, + ControlPartsMappingW1, + ) + + control_part = get_control_part(env, key) + if control_part == ControlParts.WAIST.value: + waist_joint_id = env.robot.get_joint_ids(name=ControlParts.TORSO.value)[ + ControlPartsMappingW1.WAIST_IN_TORSO.value + ] + if not isinstance(waist_joint_id, (list)): + return [waist_joint_id] + return waist_joint_id + + else: + return env.robot.get_joint_ids(name=control_part, remove_mimic=True) + def generate_affordance_from_src( env, From c4eb89fa1ac5dcf4d92fffc13d77ec5a05ac4c0b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 18:30:06 +0800 Subject: [PATCH 29/49] Migrate misc --- embodichain/lab/gym/utils/misc.py | 46 ++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py index b75b70a..8972b6b 100644 --- a/embodichain/lab/gym/utils/misc.py +++ b/embodichain/lab/gym/utils/misc.py @@ -28,7 +28,7 @@ from collections import OrderedDict from importlib import import_module from scipy.spatial.transform import Rotation as R -from typing import Any, Dict, List, Tuple, Union, Sequence, Callable, Mapping +from typing import Any, Dict, List, Tuple, Union, Sequence, Callable, Mapping, Optional import numpy as np @@ -756,6 +756,11 @@ def resolve_env_attr(obj: Any, env: Any) -> Any: _EXPR = re.compile(r"\$\{([^}]+)\}") # For searching ${...} marker +def is_binocularcam(sensor): + from dexsim.sensor import BinocularCam + from embodichain.lab.sim.sensors import StereoCamera + + return isinstance(sensor, BinocularCam) or isinstance(sensor, StereoCamera) def resolve_formatted_string(obj, local_vars=None, global_vars=None): """Given a dict carrys "${...}"-like strings , `eval` the "${...}$" values while keep the dict structure. @@ -1382,3 +1387,42 @@ def is_stereocam(sensor) -> bool: from embodichain.lab.sim.sensors import StereoCamera return isinstance(sensor, StereoCamera) + +def _data_key_to_control_part(robot, control_parts, data_key: str) -> Optional[str]: + # TODO: Temporary workaround, should be removed after refactoring data dict extractor. + # @lru_cache(max_size=None) # NOTE: no way to pass a hashable parameter + def is_eef_hand_func(robot, control_parts) -> bool: + # TODO: This is a temporary workaround, should be used a more general method to check + # whether the end-effector is a hand. + for part in control_parts: + if "eef" in part: + joint_ids = robot.get_joint_ids(part, remove_mimic=True) + return len(joint_ids) >= 2 + return False + + from embodichain.data.enum import ( + ControlParts, + EndEffector, + ActionMode, + JointType, + EefType, + ) + + is_eef_hand = is_eef_hand_func(robot, control_parts) + + for control_part in ControlParts: + if EndEffector.DEXTROUSHAND.value in data_key and is_eef_hand: + return data_key.replace(EndEffector.DEXTROUSHAND.value, "") + elif EndEffector.DEXTROUSHAND.value in data_key and not is_eef_hand: + continue + elif EndEffector.GRIPPER.value in data_key and not is_eef_hand: + return data_key.replace(EndEffector.GRIPPER.value, "") + elif EndEffector.GRIPPER.value in data_key and is_eef_hand: + continue + elif ActionMode.RELATIVE.value + JointType.QPOS.value in data_key: + continue + elif EefType.POSE.value in data_key: + continue + elif control_part.value in data_key: + return control_part.value + return None From 189965d438022bda722ee57fc1486c48643e78f4 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 19:16:59 +0800 Subject: [PATCH 30/49] Fix: change data configs to functor format --- .../pour_water_agent_v3/fast_gym_config.json | 30 +++++++++++-------- .../fast_gym_config.json | 30 +++++++++++-------- .../lab/gym/envs/managers/dataset_manager.py | 17 +++++++---- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json index b06bbb7..10a3d9c 100644 --- a/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json +++ b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json @@ -205,18 +205,24 @@ } }, "dataset": { - "instruction": { - "lang": "Pour water from the bottle into the mug." - }, - "robot_meta": { - "arm_dofs": 12, - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], - "observation": { - "vision": {}, - "states": ["qpos"] - }, - "min_len_steps": 5 + "lerobot": { + "func": "embodichain.lab.gym.envs.managers.datasets:LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from the bottle into the mug." + } + } } }, "success_params": { diff --git a/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json index 117c4e5..5cc6daa 100644 --- a/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json +++ b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json @@ -190,18 +190,24 @@ } }, "dataset": { - "instruction": { - "lang": "Place the spoon and fork neatly into the plate on the table." - }, - "robot_meta": { - "arm_dofs": 12, - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], - "observation": { - "vision": {}, - "states": ["qpos"] - }, - "min_len_steps": 125 + "lerobot": { + "func": "embodichain.lab.gym.envs.managers.datasets:LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 125 + }, + "instruction": { + "lang": "Place the spoon and fork neatly into the plate on the table." + } + } } }, "success_params": { diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py index a0ca168..e950297 100644 --- a/embodichain/lab/gym/envs/managers/dataset_manager.py +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -87,20 +87,25 @@ def __init__(self, cfg: object, env: EmbodiedEnv): super().__init__(cfg, env) ## TODO: fix configurable_action.py to avoid getting env.metadata['dataset'] - # Extract robot_meta from first functor and add to env.metadata for backward compatibility + # Extract robot_meta and instruction from first functor and add to env.metadata for backward compatibility # This allows legacy code (like action_bank) to access robot_meta via env.metadata["dataset"]["robot_meta"] for mode_cfgs in self._mode_functor_cfgs.values(): for functor_cfg in mode_cfgs: - if "robot_meta" in functor_cfg.params: + if "robot_meta" in functor_cfg.params or "instruction" in functor_cfg.params: if not hasattr(env, "metadata"): env.metadata = {} if "dataset" not in env.metadata: env.metadata["dataset"] = {} - env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ - "robot_meta" - ] + if "robot_meta" in functor_cfg.params: + env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ + "robot_meta" + ] + if "instruction" in functor_cfg.params: + env.metadata["dataset"]["instruction"] = functor_cfg.params[ + "instruction" + ] logger.log_info( - "Added robot_meta to env.metadata for backward compatibility" + "Added robot_meta and instruction to env.metadata for backward compatibility" ) break else: From 4a2d0a9c86d01c61de8b21ca5d88fa8ac38f480c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 19:21:00 +0800 Subject: [PATCH 31/49] Fix: use get_wrapper_attr for wrapped env methods in run_agent.py --- embodichain/lab/scripts/run_agent.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py index a10f27c..a32117f 100644 --- a/embodichain/lab/scripts/run_agent.py +++ b/embodichain/lab/scripts/run_agent.py @@ -70,12 +70,13 @@ def wait_for_threads(threads): ret = [] trajectory_idx = 0 - env.create_demo_action_list(regenerate=regenerate) + # Access the wrapped environment's method + env.get_wrapper_attr("create_demo_action_list")(regenerate=regenerate) # --------------------------------------------------------- # SUCCESS CASE # --------------------------------------------------------- - if not debug_mode and env.is_task_success().item(): + if not debug_mode and env.get_wrapper_attr("is_task_success")().item(): dataset_id = f"time_{time_id}_trajectory_{trajectory_idx}" @@ -85,13 +86,16 @@ def wait_for_threads(threads): num_samples = kwargs.get("num_samples", 0) is_save_dataset = time_id < num_samples - data_dict = env.to_dataset(id=dataset_id if is_save_dataset else None) + data_dict = env.get_wrapper_attr("to_dataset")(id=dataset_id if is_save_dataset else None) ret.append(data_dict) else: - data_dict = env.to_dataset(id=dataset_id) + data_dict = env.get_wrapper_attr("to_dataset")(id=dataset_id) # episode id - episode = getattr(env, "get_current_episode", lambda: time_id)() + try: + episode = env.get_wrapper_attr("get_current_episode")() + except AttributeError: + episode = time_id # video saving if save_video: From c2d0e00dee2ee1b934b29b867cfd9b7795f1972e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:02:34 +0800 Subject: [PATCH 32/49] Fix: missing task plan --- .../envs/tasks/tableware/base_agent_env.py | 273 ++++++------------ 1 file changed, 81 insertions(+), 192 deletions(-) diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index a0e106c..1aa3d62 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -3,6 +3,9 @@ import traceback from embodichain.data import database_agent_prompt_dir from pathlib import Path +import tempfile +import numpy as np +import random import os from embodichain.toolkits.interfaces import extract_drive_calls, draw_axis from embodichain.agents.hierarchy.code_agent import format_execution_history @@ -179,16 +182,44 @@ def generate_code_for_actions(self, regenerate=False, **kwargs): color="green", ) - # # Task planning (not used currently) - # print(f"\033[92m\nStart task planning.\n\033[0m") - # task_agent_input = self.task_agent.get_composed_observations(env=self) - # query = self.task_agent.generate(**task_agent_input, regenerate=regenerate, **kwargs) + # Task planning + print(f"\033[92m\nStart task planning.\n\033[0m") + + # Handle one_stage_prompt_for_correction which needs obs_image_path + if self.task_agent.prompt_name == 'one_stage_prompt_for_correction': + kwargs.setdefault("last_task_plan", "None.") + kwargs.setdefault("last_executed_failure", "None.") + kwargs.setdefault("last_executed_history", "None.") + + temp_img_dir = Path(tempfile.mkdtemp()) / "obs_images" + temp_img_dir.mkdir(parents=True, exist_ok=True) + + # Convert torch tensor to numpy array if needed + obs_image = self.get_obs_for_agent()["valid_rgb_1"] + if isinstance(obs_image, torch.Tensor): + obs_image = obs_image.cpu().numpy() + if obs_image.dtype in [np.float32, np.float64]: + obs_image = (obs_image * 255).astype(np.uint8) + + obs_image_path = save_obs_image( + obs_image=obs_image, + save_dir=temp_img_dir, + step_id=0 + ) + kwargs['obs_image_path'] = str(obs_image_path) + + task_agent_input = self.task_agent.get_composed_observations( + env=self, regenerate=regenerate, **kwargs + ) + task_plan = self.task_agent.generate(**task_agent_input) # Code generation print(f"\033[94m\nStart code generation.\n\033[0m") code_agent_input = self.code_agent.get_composed_observations( env=self, regenerate=regenerate, **kwargs ) + code_agent_input['task_plan'] = task_plan + code_file_path, kwargs, code = self.code_agent.generate(**code_agent_input) return code_file_path, kwargs, code @@ -199,197 +230,55 @@ def create_demo_action_list(self, regenerate=False): ) action_list = self.code_agent.act(code_file_path, **kwargs) return action_list - - def create_demo_action_list_with_self_correction(self, **kwargs): - logger.log_info( - f"Generate code for creating action list for {self.code_agent.task_name} with self correction.", - color="green", + + def to_dataset( + self, + id: str = None, + obs_list: list = None, + action_list: list = None, + ): + from embodichain.data.data_engine.data_dict_extractor import ( + fetch_imitation_dataset, ) - # Create log file name with timestamp - import datetime + from embodichain.lab.gym.robots.interface import LearnableRobot + + # Initialize curr_episode if not exists + if not hasattr(self, "curr_episode"): + self.curr_episode = 0 + + # Get episode data from env if not provided + if obs_list is None: + obs_list = getattr(self, "_episode_obs_list", []) + if action_list is None: + action_list = getattr(self, "_episode_action_list", []) + + if len(obs_list) == 0 or len(action_list) == 0: + logger.log_warning("No episode data found. Returning empty dataset.") + return { + "data_path": None, + "id": id, + "current_episode": self.curr_episode, + "data": None, + "save_path": None, + } - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - log_dir = ( - Path(database_agent_prompt_dir) - / self.code_agent.task_name - / "self_correction_logs" - / timestamp - ) - os.makedirs(log_dir, exist_ok=True) - img_dir = log_dir / "observation_images" - - kwargs.setdefault("env", self) - kwargs.setdefault("log_dir", log_dir) - kwargs.setdefault("file_path", log_dir / "agent_generated_code.py") - kwargs.setdefault("md_path", log_dir / "agent_llm_responses.md") - kwargs.setdefault("last_task_plan", "None.") - kwargs.setdefault("last_executed_failure", "None.") - kwargs.setdefault("last_executed_history", "None.") - - # TODO: rethink which part should be divided to task / code agents. Important! - # TODO: use the task agent to select which needs the validation (mainly interaction with the objects), not all steps. - # TODO: add logs - # TODO: maybe use a sequence of images for task planning - - step_id = 0 - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_1"], - save_dir=img_dir / "cam_1", - step_id=step_id, - ) - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_2"], - save_dir=img_dir / "cam_2", - step_id=step_id, - ) - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_3"], - save_dir=img_dir / "cam_3", - step_id=step_id, - ) + dataset_path = self.metadata["dataset"].get("save_path", None) + if dataset_path is None: + from embodichain.data import database_demo_dir - task_agent_input = self.task_agent.get_composed_observations(**kwargs) - code_agent_input = self.code_agent.get_composed_observations(**kwargs) - while True: - exec_code = [] - print(f"\033[94m\nStart task planning.\n\033[0m") - task_plan, plan_list, validation_list = ( - self.task_agent.generate_for_correction( - img_dir=img_dir / "cam_1", **task_agent_input - ) - ) + dataset_path = database_demo_dir - # TODO: maybe here I need to insert an error-occurred agent, calling some error-occurred apis, maybe with correction action too. - # TODO:maybe the validation agent can provide correction action, and no need to generate the subsequent full task by the task agent. + # TODO: create imitation dataset folder with name "{task_name}_{robot_type}_{num_episodes}" + from embodichain.lab.gym.utils.misc import camel_to_snake + + if not hasattr(self, "folder_name") or self.curr_episode == 0: + robot_class_name = self.robot.__class__.__name__ if hasattr(self, "robot") and self.robot is not None else "Robot" + self.folder_name = f"{camel_to_snake(self.__class__.__name__)}_{camel_to_snake(robot_class_name)}" + if os.path.exists(os.path.join(dataset_path, self.folder_name)): + self.folder_name = f"{self.folder_name}_{random.randint(0, 1000)}" + + return fetch_imitation_dataset( + self, obs_list, action_list, id, self.folder_name + ) - print(f"\033[92m\nStart code generation.\n\033[0m") - code_agent_input, code = self.code_agent.generate_according_to_task_plan( - task_plan=task_plan, **code_agent_input - ) - drive_list = extract_drive_calls(code) - for action_id, single_action in enumerate(drive_list): - try: - # ---------- execute ---------- - self.code_agent.act_single_action(single_action, **code_agent_input) - exec_success = True - exec_trace = None - - # # # # TODO: manually adjust the bottle pose for testing - # if step_id == 2: - # - # # pose = torch.tensor( - # # [[[0.99989, -0.00457, -0.01415, 0.72850], - # # [0.00457, 0.99999, -0.00041, -0.20441], - # # [0.01415, 0.00034, 0.99990, 0.92571], - # # [0.00000, 0.00000, 0.00000, 1.00000]]], - # # dtype=torch.float32 - # # ) - # # self.sim.get_rigid_object('bottle').set_local_pose(pose) - # - # pose = torch.tensor( - # [[[0.99989, -0.00457, -0.01415, 0.722850], - # [0.00457, 0.99999, -0.00041, 0.20441], - # [0.01415, 0.00034, 0.99990, 0.92571], - # [0.00000, 0.00000, 0.00000, 1.00000]]], - # dtype=torch.float32 - # ) - # self.sim.get_rigid_object('cup').set_local_pose(pose) - # - # # pose = self.sim.get_rigid_object('spoon').get_local_pose(to_matrix=True).squeeze(0) - # # pose[0, 3] = 0.6 - # # pose[1, 3] = -0.35 - # # pose[2, 3] = 0.8 - # # self.sim.get_rigid_object('spoon').set_local_pose(pose.unsqueeze(0)) - # - # for i in range(5): - # _ = self.step(action=self.robot.get_qpos()) - - except Exception: - exec_success = False - exec_trace = traceback.format_exc() - print(f"Execution failed:\n{exec_trace}") - - # ---------- step transition ---------- - step_id += 1 - - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_1"], - save_dir=img_dir / "cam_1", - step_id=step_id, - ) - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_2"], - save_dir=img_dir / "cam_2", - step_id=step_id, - ) - save_obs_image( - obs_image=self.get_obs_for_agent()["valid_rgb_3"], - save_dir=img_dir / "cam_3", - step_id=step_id, - ) - - # ---------- post-execution handling ---------- - if exec_success: - if code_agent_input.get("validation_agent"): - print( - f"\033[33mStarting validation with condition '{validation_list[action_id]}'!\033[0m" - ) - validation_info = self.validation_agent.validate_single_action( - single_action, - plan_list[action_id], - validation_list[action_id], - img_dir, - get_obj_position_info(self), - ) - - if "SUCCESS" in validation_info: - print(f"\033[33mValid info:\n{validation_info}\033[0m") - is_success = True - exec_code.append(plan_list[action_id]) - continue - else: - print(f"\033[31mValid info:\n{validation_info}\033[0m") - info = ( - "Validation Result: FAILED\n\n" - "Failed Step (currently executing step):\n" - f"{plan_list[action_id]}\n\n" - "Failure Analysis (why this step failed):\n" - f"{validation_info}" - ) - history = ( - "Executed History (previous steps):\n" - f"{format_execution_history(exec_code)}\n\n" - ) - is_success = False - else: - is_success = True - exec_code.append(plan_list[action_id]) - continue - else: - info = ( - "Action Execution: FAILED\n\n" - "Failed Step (currently executing step):\n" - f"{plan_list[action_id]}\n\n" - "Execution Error Trace:\n" - f"{exec_trace}\n\n" - "Note: You may try `force_valid=True` for the current action to find the nearest valid pose." - ) - history = ( - "Executed History (previous steps):\n" - f"{format_execution_history(exec_code)}\n\n" - ) - - is_success = False - - task_agent_input["last_task_plan"] = task_plan - task_agent_input["last_executed_failure"] = info - task_agent_input["last_executed_history"] = history - break - - if single_action == drive_list[-1] and is_success: - # ---------- termination ---------- - print( - "\033[91mExecuted all the plans. The task is considered complete.\033[0m" - ) - break From 7bbcda33d18d620beecdf7918ecba04305d6cbc6 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:08:44 +0800 Subject: [PATCH 33/49] Fix: module-level drive() calls --- embodichain/agents/hierarchy/code_agent.py | 103 +++++++++++++++++---- 1 file changed, 87 insertions(+), 16 deletions(-) diff --git a/embodichain/agents/hierarchy/code_agent.py b/embodichain/agents/hierarchy/code_agent.py index a3eac99..1aa5a1c 100644 --- a/embodichain/agents/hierarchy/code_agent.py +++ b/embodichain/agents/hierarchy/code_agent.py @@ -174,24 +174,95 @@ def generate(self, **kwargs): return file_path, kwargs, code_to_save def act(self, code_file_path, **kwargs): - # Dynamically import the generated function from the .py file - spec = importlib.util.spec_from_file_location( - "generated_function", code_file_path + """Execute generated code with proper execution environment. + + Supports two modes: + 1. If code defines 'create_agent_action_list' function, call it + 2. If code contains module-level drive() calls, execute them directly + """ + import ast + + # Read the generated code file + with open(code_file_path, "r") as f: + code_content = f.read() + + # Build execution namespace with necessary imports + ns = { + "__builtins__": __builtins__, + "__name__": "__main__", + "__file__": str(code_file_path), + "kwargs": kwargs, # Make kwargs available for injection + } + + # Import toolkit functions into namespace + try: + exec( + "from embodichain.toolkits.interfaces import *", + ns, + ns, + ) + except Exception as e: + raise RuntimeError( + "Failed to import embodichain.toolkits.interfaces" + ) from e + + # Parse code to check if it defines a function or contains module-level calls + tree = ast.parse(code_content) + + # Check if code defines create_agent_action_list function + has_function = any( + isinstance(node, ast.FunctionDef) and node.name == "create_agent_action_list" + for node in tree.body ) - generated_function_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(generated_function_module) - - # Ensure that the function exists and call it with kwargs - if hasattr(generated_function_module, "create_agent_action_list"): - result = generated_function_module.create_agent_action_list( - **kwargs - ) # Call the function with kwargs - print("Function executed successfully.") - return result + + if has_function: + # Execute code (function will be defined in namespace) + exec(code_content, ns, ns) + + # Call the function if it exists + if "create_agent_action_list" in ns: + result = ns["create_agent_action_list"](**kwargs) + print("Function executed successfully.") + return result + else: + raise AttributeError( + "The function 'create_agent_action_list' was not found after execution." + ) else: - raise AttributeError( - "The function 'create_agent_action_list' was not found in the generated code." - ) + # Code contains module-level drive() calls + # AST transformer to inject **kwargs into function calls + class InjectKwargs(ast.NodeTransformer): + def visit_Call(self, node): + self.generic_visit(node) + # Inject **kwargs if not present + has_kwargs = any( + kw.arg is None and isinstance(kw.value, ast.Name) and kw.value.id == "kwargs" + for kw in node.keywords + ) + if not has_kwargs: + node.keywords.append( + ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load())) + ) + return node + + # Transform AST to inject kwargs + tree = InjectKwargs().visit(tree) + ast.fix_missing_locations(tree) + + # Compile and execute transformed code + compiled_code = compile(tree, filename=str(code_file_path), mode="exec") + exec(compiled_code, ns, ns) + + # Collect actions from drive() calls if they were executed + # drive() function stores actions in env._episode_action_list + if "env" in kwargs: + env = kwargs["env"] + if hasattr(env, "_episode_action_list") and env._episode_action_list: + print(f"Collected {len(env._episode_action_list)} actions from module-level drive() calls.") + return env._episode_action_list + + print("Code executed successfully, but no actions were collected.") + return [] def build_feedback_message( self, last_code: str, last_error: str, last_observation: str = None From d95188d0f014234cc037c859f859c87e7d98d8b8 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:27:19 +0800 Subject: [PATCH 34/49] Fix: import direction of data_dict_extractor --- embodichain/data/data_engine/data_dict_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embodichain/data/data_engine/data_dict_extractor.py b/embodichain/data/data_engine/data_dict_extractor.py index e82f964..5c5f3c4 100644 --- a/embodichain/data/data_engine/data_dict_extractor.py +++ b/embodichain/data/data_engine/data_dict_extractor.py @@ -52,7 +52,7 @@ SUPPORTED_EXTRA_VISION_TYPES, ) from copy import deepcopy -from embodichain.lab.gym.envs.action_bank.configurable_action import ( +from embodichain.lab.gym.envs.action_bank.utils import ( get_control_part_joint_ids, ) From a2392bf982367904e63128108798863f5c5ff95e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:27:52 +0800 Subject: [PATCH 35/49] Fix: to_dataset function of base_agent_env --- embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index 1aa3d62..7aa1887 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -243,10 +243,6 @@ def to_dataset( from embodichain.lab.gym.robots.interface import LearnableRobot - # Initialize curr_episode if not exists - if not hasattr(self, "curr_episode"): - self.curr_episode = 0 - # Get episode data from env if not provided if obs_list is None: obs_list = getattr(self, "_episode_obs_list", []) @@ -258,7 +254,7 @@ def to_dataset( return { "data_path": None, "id": id, - "current_episode": self.curr_episode, + "current_episode": getattr(self, "curr_episode", 0), "data": None, "save_path": None, } From 038e35b0f22c67c99ad7b9d401111bb1ae72bf5f Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:28:23 +0800 Subject: [PATCH 36/49] Fix: argument of mul_linear_expand should be 2 --- embodichain/toolkits/interfaces.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/embodichain/toolkits/interfaces.py b/embodichain/toolkits/interfaces.py index 565d7aa..cba8e55 100644 --- a/embodichain/toolkits/interfaces.py +++ b/embodichain/toolkits/interfaces.py @@ -154,8 +154,8 @@ def plan_gripper_trajectory(env, is_left, sample_num, execute_open, ee_state_expand_select = np.array([open_state, close_state]) env.set_current_gripper_state_agent(close_state, is_left=is_left) - ee_state_expand_select = mul_linear_expand(ee_state_expand_select, [sample_num], include_endpoint=True) - + ee_state_expand_select = mul_linear_expand(ee_state_expand_select, [sample_num]) + select_qpos_traj.extend([select_arm_current_qpos] * sample_num) ee_state_list_select.extend(ee_state_expand_select) From 91ec67d8826149300e1b7188eeeef8f5796b4440 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:28:40 +0800 Subject: [PATCH 37/49] Migrate articulation entity --- embodichain/lab/sim/articulation_entity.py | 1391 ++++++++++++++++++++ 1 file changed, 1391 insertions(+) create mode 100644 embodichain/lab/sim/articulation_entity.py diff --git a/embodichain/lab/sim/articulation_entity.py b/embodichain/lab/sim/articulation_entity.py new file mode 100644 index 0000000..9ce21f5 --- /dev/null +++ b/embodichain/lab/sim/articulation_entity.py @@ -0,0 +1,1391 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +import typing +from typing import List, Tuple, Union, Dict, Any +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field + +import dexsim +from dexsim.models import Entity +from dexsim.engine import Articulation +from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag +from embodichain.utils import logger + +# Try to import DriveController, but make it optional +try: + from rlia.kit.drive_controllers import DriveController +except ImportError: + # If rlia is not available, use Any as a fallback type + DriveController = Any + +from dexsim.utility import inv_transform +from dexsim.utility.env_utils import load_first_environment + +__all__ = ["ArticulationEntity"] + + +@dataclass +class ArticulationPosition: + r"""Represents the position of an articulation in a robotic system. + + Attributes: + init_qpos (Union[np.ndarray, Dict[str, np.ndarray]]): + The initial joint positions of the articulation, which can be a + NumPy array or a dictionary mapping joint names to their initial + positions. + + init_base_xpos (Union[np.ndarray, Dict[str, np.ndarray]], optional): + The initial base position of the articulation, which can also be a + NumPy array or a dictionary mapping base names to their initial + positions. Defaults to None. + """ + + init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = field(default_factory=dict) + init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None + + +@dataclass +class ArticulationControl: + r"""Controls the behavior of an articulation in a robotic system. + + Attributes: + speed_ratio (float): + The ratio of speed for the articulation control. Default is 0.5. + + time_step (float): + The time step for control updates in seconds. Default is 0.02. + + drive_type (DriveType): + The type of drive used for the articulation control. Default is 'TARGET'. + """ + + speed_ratio: float = 0.5 + time_step: float = 0.02 + drive_type: "DriveType" = "TARGET" + + +@dataclass +class ArticulationJointConfiguration: + link_names: List[str] = field(default_factory=list) + joint_names: List[str] = field(default_factory=list) + + root_link_name: str = field(default_factory=dict) + end_link_name: str = field(default_factory=dict) + + +class ArticulationEntity(metaclass=ABCMeta): + r""" + Abstract class for articulation entity in simulation. + """ + + def __init__( + self, + urdf_path: Union[str, List[str]] = dict(), + init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = dict(), + init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None, + speed_ratio: float = 0.5, + time_step: float = 0.02, + drive_type: DriveType = DriveType.FORCE, + env: dexsim.environment.Arena = None, + **kwargs, + ): + r"""Initialize the articulation entity. + + Args: + urdf_path (str): urdf file path of robot + init_qpos (np.ndarray, optional): [dof] of double. Init robot joint state(home joint state). + init_base_xpos (np.ndarray, optional): [4, 4] of double. Robot base pose in arena coordinate system. + speed_ratio (float, optional): 0 ~ 1. Robot speed ratio. + time_step (float, optional): wait time between two update. Defaults to 1/50. + drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. Defaults to DriveType.FORCE. + env (Arena, optional): dexsim.environment.Arena. Load the first world(None defaults). + kwargs(optional): Accepts additional keyword arguments. + """ + # placeholder for articulations to be created to the robot. + # a robot can have multiple articulations, for example, + # 1. a arm with a gripper (manipulator) + # 2. two arms + # 3. mobile manipulator + self.articulation = None + + ## Additional variable for DualManipulator, Humanoids and DexterousHands: + # Dictionary to map child to its parent articulation "self.articulation" + self.child_articulations: Dict[str, Articulation] = dict() + + # URDF file path(s) for the robot + self.urdf_path = urdf_path + + # initial joint positions of the robot. + self.init_qpos = init_qpos + + # initial base pose of the robot in arena coordinate system. + self.init_base_xpos = init_base_xpos + + # Dictionary to store degrees of freedom for each articulation + self._dof: Dict[str, int] = dict() + + # Dictionary for actual control joint indices of articulations + self._joint_ids: Dict[str, np.ndarray] = dict() + + # self._actived_joint_names = dict() + + # TODO: Maybe turn to dict stored joint pos, vel, acc limits. + # List to store the limits for each joint's motion. + self._joint_limit = [] + + # placeholder for actors to attach to the robot. + self.attached_actors: Dict[str, Entity] = dict() + + # Dictionary to map control group names to their corresponding root link names, + # used for accessing the base position of each control group. + self.root_link_names: Dict[str] = kwargs.get("root_link_names", {}) + + # Dictionary to map control group names to their corresponding end link names, + # used for accessing the terminal position of each control group. + self.end_link_names: Dict[str] = kwargs.get("end_link_names", {}) + + # Speed ratio for the robot's movement + self.speed_ratio = speed_ratio + + # Time step for control updates + self.time_step = time_step + + # Validate and set the drive type + if drive_type not in [DriveType.FORCE, DriveType.FORCE]: + logger.log_error(f"Invalid drive type: {drive_type}.") + self.drive_type = drive_type + + # Dictionary to map child to its parent init_base_xpos "self.init_base_xpos" + self.child_init_base_xpos = dict() + + # Dictionaries for drive and task controllers + self.drive_controllers: Dict[str, DriveController] = dict() + + # Load the first environment if not provided + self._env, self._world = load_first_environment(env) + + def get_articulation(self, uid: str = None) -> dexsim.engine.Articulation: + r"""Get articulation based on its unique identifier (uid). + + This method returns the articulation associated with the provided uid. + If uid is not specified (None), it returns all articulations. If the + uid is invalid, a warning is logged, and None is returned. + + Args: + uid (str, optional): The unique identifier for the articulation. If None, all articulations will be returned. + + Returns: + dexsim.engine.Articulation or Dict: The articulation corresponding to the provided uid, or a dictionary of all articulations if uid is None. Returns None if the uid is invalid. + """ + + if uid is None or uid == self.uid: + return self.articulation + + if uid in self.child_articulations: + return self.child_articulations[uid] + else: + logger.log_warning( + f"Current uid {self.uid} cannot find the corresponding Articulation." + ) + return None + + def _setup_child_articulations(self, uid: str, control_parts: Dict): + r"""Initialize child articulations and establish a mapping between parent and child articulations. + + This method sets up child articulations associated with a parent articulation identified by its UID. + It verifies the existence of the parent articulation before proceeding to initialize the child articulations. + + Args: + uid (str): The unique identifier (UID) of the parent articulation. + control_parts (Dict): A dictionary of control parts to initialize as child articulations. + + Returns: + bool: True if the child articulations were successfully set up; False otherwise. + """ + # Use a list comprehension to filter valid control parts and log warnings for the invalid ones + control_parts_dict = {} + + # Check if the articulation is valid and if the provided UID matches the instance's UID + if self.articulation is None or uid != self.uid: + logger.log_warning(f"Articulation with UID '{uid}' not found.") + return False + + # Iterate over control parts to set up child articulations + for control_part in control_parts: + # Add to child articulations + control_parts_dict[control_part] = self.articulation + + # Establish the relationship between the child articulations and their parent + self.child_articulations = control_parts_dict + + return True + + @property + def default_physical_attrs(self) -> PhysicalAttr: + physical_attr = PhysicalAttr() + if self.drive_type == DriveType.FORCE: + physical_attr.static_friction = 1.0 + physical_attr.dynamic_friction = 0.9 + physical_attr.linear_damping = 0.7 + physical_attr.angular_damping = 0.7 + physical_attr.contact_offset = 0.005 + physical_attr.rest_offset = 0.001 + physical_attr.restitution = 0.05 + physical_attr.has_gravity = True + physical_attr.max_linear_velocity = 4000 + physical_attr.max_angular_velocity = 25 + physical_attr.max_depenetration_velocity = 1e1 + else: # DriveType.FORCE and so on + physical_attr.static_friction = 1.0 + physical_attr.dynamic_friction = 0.9 + physical_attr.linear_damping = 0.7 + physical_attr.angular_damping = 0.7 + physical_attr.contact_offset = 0.005 + physical_attr.rest_offset = 0.001 + physical_attr.restitution = 0.05 + physical_attr.has_gravity = False + physical_attr.max_linear_velocity = 1e6 + physical_attr.max_angular_velocity = 1e6 + physical_attr.max_depenetration_velocity = 1e1 + return physical_attr + + @property + def default_drive_param(self) -> Dict: + # Stiffness: + # Recommended range: 2000 N/m to 10000 N/m + # Note: Higher stiffness is suitable for tasks that require precise position control, + # such as gripping and assembly. You can start with 5000 N/m and fine-tune based on feedback from the actual application. + # Damping: + # Recommended range: 200 Ns/m to 1000 Ns/m + # Note: Damping values ​​should be high enough to dampen oscillations, + # but not too high to excessively hinder motion. You can start with 500 Ns/m and adjust based on dynamic performance. + # Max force: + # Recommended range: 10000 N to 100000 N + # Note: The maximum force should be set according to the load capacity of the robot arm + # to ensure that it does not exceed its load capacity when working. You can start with 50000 N, depending on the specific task load. + if self.drive_type == DriveType.FORCE: + param = {"stiffness": 2e3, "damping": 2e2, "max_force": 2e4} + elif self.drive_type == DriveType.FORCE: + param = {"stiffness": 1e8, "damping": 1e6, "max_force": 1e10} + return param + + def set_uid(self, uid: str) -> None: + r"""Set unique id of the robot. + + Args: + uid (str): Unique id of the robot. + """ + if uid == self.uid: + logger.log_warning( + f"The uid: {uid} is the same as the current: {self.uid}." + ) + else: + self.uid = uid + + def get_urdf_path(self) -> str: + r"""Provides the file path to the Unified Robot Description Format (URDF) file. + + Returns: + str: A string representing the file path to the robot's URDF file. + """ + return self.urdf_path + + def get_dof(self, name: str = None) -> Union[int, Dict[str, int]]: + r"""Get degree of freedom (DoF) of the robot. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + Union[int, Dict[str, int]]: + - If `name` is None, returns the total DoF of the robot as an integer. + - If `name` is provided and found, returns the DoF of the specified articulation as an integer. + - If `name` is provided but not found, logs a warning and returns 0. + """ + # TODO: Need to clarify behavior. + if name is None: + if isinstance(self._dof, dict): + return sum(self._dof.values()) + else: + return ( + self._dof + ) # Assuming _dof is an integer representing the total DoF + elif name in self._dof: + return self._dof[ + name + ] # Assuming _dof[name] is an integer representing the DoF of the specified articulation + + logger.log_warning(f"Articulation '{name}' not found.") + return 0 + + def _convert_pose(self, pose: np.ndarray, is_to_arena: bool) -> np.ndarray: + r"""Convert a given pose to the specified coordinate system. + + Args: + pose (np.ndarray): A [4, 4] transformation matrix representing the pose to be converted. + is_to_arena (bool): If True, convert to arena coordinate system; otherwise, convert to world coordinate system. + + Returns: + np.ndarray: A [4, 4] transformation matrix representing the pose in the specified coordinate system. + """ + if pose is None: + return np.eye(4) + + pose_array = np.array(pose) + + if pose_array.shape == (4, 4): + poses_to_convert = [pose_array] + elif pose_array.ndim == 3 and pose_array.shape[1:] == (4, 4): + poses_to_convert = pose_array + else: + logger.log_warning(f"Invalid shape for pose: {pose.shape}") + return np.eye(4) + + # Retrieve the world pose of the arena's root node + arena_root_pose = self._env.get_root_node().get_world_pose() + + # Determine the transformation logic based on the value of is_to_arena + if is_to_arena: + # Apply the inverse transformation to convert to the arena coordinate system + inv_arena_root_pose = np.linalg.inv(arena_root_pose) + converted_poses = [inv_arena_root_pose @ p for p in poses_to_convert] + else: + # Directly apply the transformation to convert to the world coordinate system + converted_poses = [arena_root_pose @ p for p in poses_to_convert] + + # Return the result in the same format as the input + if pose_array.shape == (4, 4): + return converted_poses[0] # Return single pose + else: + return np.array(converted_poses) # Return list/array of poses + + def set_joint_ids(self, joint_ids: np.ndarray, uid: str = None): + r"""Set joint IDs for the given UID. + + Args: + joint_ids (np.ndarray): Joint IDs to set. + uid (str, optional): The unique identifier for the joint. Defaults to None. + """ + uid = uid or self.uid + self._joint_ids[uid] = joint_ids + + def get_joint_ids(self, name: str = None) -> List: + r"""Gets joint IDs from the internal storage. + + Args: + name (str, optional): The name of the joint to look up. + If None, all joint IDs are returned. + + Returns: + List: A list of joint IDs associated with the specified name, + or a dictionary of all joint IDs if no name is given. + Returns an empty list if the name is not found. + """ + if name is None: + return {key: value for key, value in self._joint_ids.items()} + if name in self._joint_ids: + return self._joint_ids[name] + else: + logger.log_warning( + f"Joint ids with name '{name}' not found in self._joint_ids." + ) + return [] + + def get_joint_limits( + self, name: str = None + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get joint limits for the specified articulation. + + Args: + name (str): Name of the articulation. Defaults to None. + + Returns: + np.ndarray: [dof, 2] of float. Lower and upper joint limits. + Dict[str, np.ndarray]: [dof, 2] of float. Lower and upper joint limits for all articulations. + """ + limits = self.articulation.get_joint_limits() + + if name is None: + return limits + else: + if self.uid == name: + return limits[self._joint_ids[name]] + + if name not in self.child_articulations: + logger.log_warning(f"Articulation '{name}' not found.") + return None + return limits[self._joint_ids[name]] + + def get_link_names( + self, name: str = None + ) -> Union[List[str], Dict[str, List[str]]]: + r"""Gets the list of link names for a given articulation. + + Args: + name (str, optional): The name of the articulation. If None, returns link names for all articulations. + + Returns: + List[str]: A list of link names for the specified articulation if `name` is provided. + Dict[str, List[str]]: A dictionary mapping articulation names to their respective link name lists if `name` is None. + None: Returns None if the specified articulation name is not found. + """ + # todo: Articulation needs to distinguish between some parents and children. + link_names = self.articulation.get_link_names() + + if name is None or name == self.uid: + # Return a dictionary of link names for all articulations + return link_names + else: + if name in self.child_articulations: + return link_names[self._joint_ids[name]] + + def _get_link_velocity( + self, name: str = None, is_linear: bool = True, is_root: bool = False + ) -> Union[np.ndarray, None]: + r"""Get the link velocity of the specified articulation. + + Args: + name (str, optional): Name of the articulation. If None, retrieves velocities for all articulations. + is_linear (bool, optional): If True, retrieves linear velocity; otherwise, retrieves angular velocity. + is_root (bool, optional): If True, returns the root link velocity as a flattened array. + + Returns: + Union[np.ndarray, None]: Returns the velocity of the specified joint as a numpy array, or None if not found. + """ + + def _get_link_velocity_helper( + name: str, is_linear: bool = True, is_root: bool = False + ) -> typing.Optional[np.ndarray]: + """Helper function to get the link velocity for a specific articulation.""" + if name == self.uid: + link_general_vel = self.articulation.get_link_general_velocities() + link_velocity = ( + link_general_vel[:, :3] if is_linear else link_general_vel[:, 3:] + ) + return link_velocity[0].reshape(-1) if is_root else link_velocity + elif name in self.child_articulations: + link_general_vel = self.child_articulations[ + name + ].get_link_general_velocities() + link_velocity = ( + link_general_vel[:, :3] if is_linear else link_general_vel[:, 3:] + ) + return link_velocity[0].reshape(-1) if is_root else link_velocity + else: + return None + + if name is None: + link_velocity = _get_link_velocity_helper( + name=self.uid, is_linear=is_linear, is_root=is_root + ) + else: + link_velocity = _get_link_velocity_helper( + name=name, is_linear=is_linear, is_root=is_root + ) + + return link_velocity + + def get_body_link_linear_velocity( + self, + name: str = None, + ) -> Union[np.ndarray, None]: + r"""Get body link linear velocity in coordinate frame. + + Args: + name (str, optional): The name of the articulation. + If None, retrieves the velocity of all articulations. + + Returns: + Union[np.ndarray, None]: + If a name is provided, returns an array of shape [link_num, 3] + representing the linear velocity of the specified articulation. + If name is None, returns a dictionary mapping articulation names + to their corresponding linear velocities. + """ + return self._get_link_velocity(name=name, is_linear=True, is_root=False) + + def get_body_link_angular_velocity( + self, + name: str = None, + ) -> Union[np.ndarray, None]: + r"""Get body link angular velocity in coordinate frame. + + Args: + name (str, optional): The name of the articulation. + If None, retrieves the velocity of all articulations. + + Returns: + Union[np.ndarray, None]: + If a name is provided, returns an array of shape [link_num, 3] + representing the angular velocity of the specified articulation. + If name is None, returns a dictionary mapping articulation names + to their corresponding angular velocities. + """ + return self._get_link_velocity(name=name, is_linear=False, is_root=False) + + def get_root_link_linear_velocity( + self, + name: str = None, + ) -> Union[np.ndarray, None]: + r"""Get root link linear velocity in coordinate frame. + + Args: + name (str, optional): The name of the articulation. + If None, retrieves the velocity of all articulations. + + Returns: + Union[np.ndarray, None]: + If a name is provided, returns an array of shape [3] + representing the linear velocity of the root link. + If name is None, returns a dictionary mapping articulation names + to their corresponding linear velocities. + """ + return self._get_link_velocity(name=name, is_linear=True, is_root=True) + + def get_root_link_angular_velocity( + self, + name: str = None, + ) -> Union[np.ndarray, None]: + r"""Get root link angular velocity in coordinate frame. + + Args: + name (str, optional): The name of the articulation. + If None, retrieves the velocity of all articulations. + + Returns: + Union[np.ndarray, None]: + If a name is provided, returns an array of shape [3] + representing the angular velocity of the root link. + If name is None, returns a dictionary mapping articulation names + to their corresponding angular velocities. + """ + return self._get_link_velocity(name=name, is_linear=False, is_root=True) + + def _set_articulation_property( + self, + name: str, + property_name: str, + value: Union[np.ndarray, Dict[str, np.ndarray]], + use_params: bool = True, + **params, + ) -> bool: + r"""Helper function to set a property for a specific articulation. + + This function attempts to set a specified property (e.g., position, velocity) + for the articulation identified by 'name'. It first checks if the articulation + is a child articulation and then checks the main articulations. If the + articulation is found and the property exists, the function sets the property + with the provided value. + + Args: + name (str): The name of the articulation to set the property for. + property_name (str): The name of the property to set. + value (Union[np.ndarray, Dict[str, np.ndarray]]): The value to set the property to. + use_params (bool): Whether to use params when calling the property method. + + Returns: + bool: True if the property was successfully set, False otherwise. + """ + # Use self._joint_ids[name] if params is empty + if use_params and not params: + params = {"joint_ids": self._joint_ids[name]} + + # Check in child articulations first + if name in self.child_articulations: + child_articulation = self.child_articulations[name] + if hasattr(child_articulation, property_name): + # Call the property method with or without params + if use_params: + getattr(child_articulation, property_name)(value, **params) + else: + getattr(child_articulation, property_name)(value) + return True + + # Check the main articulation + if name == self.uid: + if hasattr(self.articulation, property_name): + # Call the property method with or without params + if use_params: + getattr(self.articulation, property_name)(value, **params) + else: + getattr(self.articulation, property_name)(value) + return True + + logger.log_warning(f"Articulation '{name}' not found.") + return False + + def get_current_xpos( + self, name: str = None, is_world_coordinates: bool = True + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get the current pose of the articulations. + + This method retrieves the current pose of specified articulation(s) in either world + or base coordinates. It handles both single articulations and hierarchical structures + with parent-child relationships. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + is_world_coordinates (bool, optional): + Whether to use the arena(world) coordinate system(WCS) or the Base + coordinate system(BCS). Defaults to True. + + Returns: + Union[np.ndarray, Dict[str, np.ndarray]]: + Returns the xpos for the specified articulation if `name` is provided and found. + If `name` is None, returns xpos for all articulations. + Returns None if `name` is provided but not found. + """ + + # Function to calculate the current position based on qpos + def calculate_xpos( + key: str, qpos: np.ndarray, parent_key: str = None + ) -> np.ndarray: + if key == self.uid: + articulation = self.articulation + else: + if key in self.child_articulations: + articulation = self.child_articulations.get(key, None) + if articulation is None: + return None # Articulation not found + + # Case 1: Use parent's drive controller for forward kinematics + if ( + parent_key + and (parent_key in self.drive_controllers) + and hasattr(self.drive_controllers[parent_key], "get_fk") + and (self.drive_controllers.get(key, None) is None) + ): + end_link_name = self.end_link_names.get(key, None) + if end_link_name is None: + end_link_index = -1 + else: + end_link_index = self.drive_controllers[parent_key].get_link_orders( + end_link_name + ) + + _, xpos = self.drive_controllers[parent_key].get_fk( + qpos, index=end_link_index + ) + # Case 2: Use articulation's own drive controller + elif (key in self.drive_controllers) and hasattr( + self.drive_controllers[key], "get_fk" + ): + if len(qpos) != self.drive_controllers[key]: + qpos = qpos[self._joint_ids[key]] + end_link_name = self.end_link_names.get(key) + if end_link_name is None: + end_link_index = -1 + else: + end_link_index = self.drive_controllers[key].get_link_orders( + end_link_name + ) + + _, xpos = self.drive_controllers[key].get_fk(qpos, index=end_link_index) + # Case 3: Fallback to direct world pose + else: + xpos = self._convert_pose( + articulation.get_world_pose(), is_to_arena=True + ) + return xpos + + # Get the base xpos for the articulation + # If parent_key exists, use it; otherwise use the current key + base_xpos = self.get_base_xpos(parent_key if parent_key else key) + + # Get initial transformation matrix, default to identity if not found + initial_xpos = self.init_base_xpos.get(key, np.eye(4)) + + if is_world_coordinates: + # Special handling for root links which require different transformation logic + if self.root_link_names.get(key, None) is not None: + if key not in self.drive_controllers: + # For articulations without drive controllers, + # transform using base transformation matrix + return base_xpos @ xpos + else: + # For articulations with drive controllers, + # get an up-to-date base transformation and apply it + root_base_xpos = self.get_base_xpos(key) + return root_base_xpos @ xpos + else: + # Handle non-root links + # TODO: judge by num of drive_controllers + return ( + (initial_xpos @ xpos) + if parent_key is not None + else (base_xpos @ xpos) + ) + + return xpos + + # If name is None, calculate for all articulations + if name is None: + current_xpos = {} + qpos = self.get_current_qpos(self.uid) # Get qpos once for all + + # Calculate for all main articulations + xpos = calculate_xpos(self.uid, qpos) + if xpos is not None: + current_xpos[self.uid] = xpos + + # Calculate for child articulations using parent drive controller + for child_key in self.child_articulations: + xpos = calculate_xpos(child_key, qpos, self.uid) + if xpos is not None: + current_xpos[child_key] = xpos + + return current_xpos + + # Check for articulation in child articulations + if name in self.child_articulations: + if self.uid in self._actived_joint_names: + xpos = calculate_xpos(name, self.get_current_qpos()[self.uid], self.uid) + else: + xpos = calculate_xpos(name, self.get_current_qpos(self.uid), self.uid) + if xpos is not None: + return xpos + + # Check for articulation in main articulation + xpos = calculate_xpos(name, self.get_current_qpos(name)) + if xpos is not None: + return xpos + + logger.log_warning(f"Articulation '{name}' not found.") + return None + + def get_base_xpos( + self, name: str = None, is_init: bool = False + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get current robot base pose in arena coordinate system. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + is_init (bool, optional): Init base xpos or current base xpos. Current base xpos defaults. + + Returns: + np.ndarray: Joint positions for the specified articulation if `name` is provided and found. + Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + + if is_init: + # Return initial base positions + return self.init_base_xpos.get(name) if name else self.init_base_xpos + + # Initialize a dictionary for current base positions + current_base_xpos_dict = {} + + # Get the current base xpos for the main articulation + base_xpos = self.articulation.get_link_pose(self.root_link_names[self.uid]) + current_base_xpos_dict[self.uid] = self._convert_pose( + base_xpos, is_to_arena=True + ) + + # Populate the dictionary with joint positions for all child articulations + for key in self.child_articulations: + if ( + self.root_link_names.get(key, None) + in self.child_articulations[key].get_link_names() + ): + child_base_xpos = self._get_articulation_property( + key, "get_link_pose", link_name=self.root_link_names[key] + ) + current_base_xpos_dict[key] = self._convert_pose( + child_base_xpos, is_to_arena=True + ) + + if name is None: + return current_base_xpos_dict + + # If a specific articulation name is provided + if name == self.uid: + return self._convert_pose(base_xpos, is_to_arena=True) + + # Get the base xpos for the specified articulation + current_base_xpos = self._get_articulation_property( + name, "get_link_pose", link_name=self.root_link_names[name] + ) + return self._convert_pose(current_base_xpos, is_to_arena=True) + + def set_base_xpos( + self, name: str = None, base_xpos: np.ndarray = np.eye(4) + ) -> None: + r"""Set the robot's base pose. + + Args: + name (str, optional): + Name of the articulation. If specified, the function will + apply the base pose to the articulation with this name. + Defaults to None, which means the base pose will be set for + the entire robot. + + base_xpos (np.ndarray, optional): + A [4, 4] matrix representing the transformation matrix that + defines the base pose of the robot. The matrix should + contain rotation and translation information. Defaults to + the identity matrix (np.eye(4)), indicating no change in pose. + """ + if base_xpos is None: + logger.log_warning("base_xpos is None, no action taken.") + return False + + if name is None or name == self.uid: + if isinstance(base_xpos, dict): + failed_cases = [] + for articulation_name, pos in base_xpos.items(): + if not self._set_articulation_property( + articulation_name, + "set_world_pose", + self._convert_pose(pos, is_to_arena=False), + False, + ): + failed_cases.append(articulation_name) + if failed_cases: + logger.log_warning( + f"Failed to set base xpos for articulations: {failed_cases}" + ) + return False + return True + elif isinstance(base_xpos, (list, np.ndarray)): + self._set_articulation_property( + name, + "set_world_pose", + self._convert_pose(base_xpos, is_to_arena=False), + False, + ) + return True + else: + logger.log_warning( + f"Expected base xpos to be dict for articulations, got {type(base_xpos)}." + ) + return False + else: + if isinstance(base_xpos, (list, np.ndarray)): + return self._set_articulation_property( + name, + "set_world_pose", + self._convert_pose(base_xpos, is_to_arena=False), + False, + ) + else: + logger.log_warning( + f"Expected base xpos to be np.ndarray for articulation '{name}', got {type(base_xpos)}." + ) + return False + + def get_current_joint_poses( + self, name: str = None + ) -> Union[List[np.ndarray], Dict[str, List[np.ndarray]]]: + r"""Get current robot joint poses. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + List[np.ndarray]: List of [4, 4]. Joint poses for the specified articulation if `name` is provided and found. + Dict[str, List[np.ndarray]]: Joint poses for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + name = name or self.uid + + if name == self.uid: + current_joint_poses = dict() + if hasattr(self.articulation, "get_joint_poses"): + current_joint_poses = self._convert_pose( + self.articulation.get_joint_poses(self._joint_ids[self.uid]), + is_to_arena=True, + ) + + return current_joint_poses + else: + if name in self.child_articulations: + logger.log_warning(f"Articulation {name} not found.") + return None + + return self._convert_pose( + self.child_articulations[name].get_joint_poses(self._joint_ids[name]), + is_to_arena=True, + ) + + def get_init_qpos(self, name: str = None) -> None: + r"""Get robot initial joint positions. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + np.ndarray: initial joint positions for the specified articulation if `name` is provided and found. + Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + if name is None: + return self.init_qpos + + if name in self.child_articulations or name == self.uid: + return self.init_qpos[name] + + logger.log_warning(f"Articulation {name} not found.") + return None + + def set_init_qpos( + self, name: str = None, qpos: Union[np.ndarray, Dict[str, np.ndarray]] = [] + ) -> None: + r"""Set initial joint positions. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + qpos (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot initial joint positions. + """ + if qpos is None: + logger.log_warning("qpos is None, no action taken.") + return + + if name is None or name == self.uid: + if isinstance(qpos, dict): + for articulation_name, pos in qpos.items(): + if articulation_name in self.init_qpos: + self.init_qpos[articulation_name] = pos + else: + logger.log_warning( + f"Articulation '{articulation_name}' not found in init_qpos." + ) + elif isinstance(qpos, (list, np.ndarray)): + self.init_qpos[self.uid] = qpos + else: + logger.log_warning( + f"Unsupported qpos type: {type(qpos)}, expected np.ndarray or dict." + ) + else: + if not isinstance(qpos, (list, np.ndarray)): + logger.log_warning( + f"Expected qpos to be np.ndarray for articulation '{name}', got {type(qpos)}." + ) + return + + if name in self.init_qpos: + self.init_qpos[name] = qpos + else: + logger.log_warning(f"Articulation '{name}' not found in init_qpos.") + + def _get_articulation_property( + self, name: str, property_name: str, **params + ) -> Union[np.ndarray, None]: + r"""Helper function to get a property for a specific articulation. + + This function retrieves the value of a specified property (e.g., position, + velocity) for the articulation identified by 'name'. It first checks if the + articulation is a main articulation and then checks child articulations. If + the articulation is found and the property exists, the function returns the + property's value. + + Args: + name (str): The name of the articulation to get the property from. + property_name (str): The name of the property to retrieve. + + Returns: + Union[np.ndarray, None]: The value of the property if found, None otherwise. + """ + # Use self._joint_ids[name] if params is empty + if not params: + if name in self._joint_ids: + params = {"joint_ids": self._joint_ids[name]} + else: + logger.log_warning(f"Joint_id '{name}' not found.") + has_similar_name = False + for key, val in self._joint_ids.items(): + if name in key: + params = {"joint_ids": val} + logger.log_warning(f"Joint_id '{key}' is used for {name}.") + name = key + has_similar_name = True + break + if not has_similar_name: + return None + + if name == self.uid: + return getattr(self.articulation, property_name)(**params) + + if len(self._joint_ids[name]): + if name in self.child_articulations: + child_articulation = self.child_articulations[name] + return getattr(child_articulation, property_name)(**params) + else: + return None + + logger.log_warning(f"Articulation '{name}' not found.") + return None + + def set_current_qpos( + self, name: str = None, qpos: Union[np.ndarray, Dict[str, np.ndarray]] = None + ): + r"""Set current robot joint positions. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + qpos (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot current joint positions. + + Returns: + bool: True if the positions were successfully set, False otherwise. + """ + if qpos is None: + logger.log_warning("qpos is None, no action taken.") + return False + + if name is None or name == self.uid: + if isinstance(qpos, dict): + failed_cases = [] + for articulation_name, pos in qpos.items(): + if not self._set_articulation_property( + articulation_name, "set_current_qpos", pos + ): + failed_cases.append(articulation_name) + if failed_cases: + logger.log_warning( + f"Failed to set qpos for articulations: {failed_cases}" + ) + return False + return True + elif isinstance(qpos, (list, np.ndarray)): + return self._set_articulation_property(name, "set_current_qpos", qpos) + else: + logger.log_warning( + f"Expected qpos to be dict for articulations, got {type(qpos)}." + ) + return False + else: + if isinstance(qpos, (list, np.ndarray)): + return self._set_articulation_property(name, "set_current_qpos", qpos) + else: + logger.log_warning( + f"Expected qpos to be np.ndarray for articulation '{name}', got {type(qpos)}." + ) + return False + + def get_current_qpos( + self, name: str = None + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get current robot joint positions. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + np.ndarray: Joint positions for the specified articulation if `name` is provided and found. + Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + # Validate the name parameter + if name is not None and not isinstance(name, str): + logger.log_warning( + f"The 'name' parameter must be a string or None, got {type(name)}." + ) + return None + + if name is None: + # Initialize a dictionary to hold joint positions for all articulations + current_qpos_dict = {} + + # Get the current joint positions for the main articulation + qpos = self.articulation.get_current_qpos() + current_qpos_dict[self.uid] = qpos + + # Populate the dictionary with joint positions for all child articulations + for key in self.child_articulations: + current_qpos_dict[key] = self._get_articulation_property( + key, "get_current_qpos" + ) + + return current_qpos_dict + else: + + return self._get_articulation_property(name, "get_current_qpos") + + def set_current_qvel( + self, name: str = None, qvel: Union[np.ndarray, Dict[str, np.ndarray]] = None + ): + r"""Set the current joint velocities of the robot. + + Args: + name (str, optional): + Name of the articulation. If None, the velocities will be set + for all articulations. + + qvel (Union[np.ndarray, Dict[str, np.ndarray]], optional): + Joint velocities. This can be a NumPy array for a single + articulation or a dictionary mapping articulation names to + their respective velocities. + + Returns: + bool: Returns True if the joint velocities were successfully set, + otherwise returns False if no action was taken or if there + were errors in the input. + """ + if qvel is None: + logger.log_warning("qvel is None, no action taken.") + return False + + if name is None or name == self.uid: + if isinstance(qvel, dict): + failed_cases = [] + for articulation_name, vel in qvel.items(): + if not self._set_articulation_property( + articulation_name, "set_current_qvel", vel + ): + failed_cases.append(articulation_name) + if failed_cases: + logger.log_warning( + f"Failed to set qvel for articulations: {failed_cases}" + ) + return False + return True + else: + logger.log_warning( + f"Expected qvel to be dict for articulations, got {type(qvel)}." + ) + return False + else: + if isinstance(qvel, (list, np.ndarray)): + return self._set_articulation_property(name, "set_current_qvel", qvel) + else: + logger.log_warning( + f"Expected qvel to be np.ndarray for articulation '{name}', got {type(qvel)}." + ) + return False + + def get_current_qvel( + self, name: str = None + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get current robot joint velocities. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + np.ndarray: Joint velocities for the specified articulation if `name` is provided and found. + Dict[str, np.ndarray]: Joint velocities for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + if name is None: + # Initialize a dictionary to hold joint velocities for all articulations + current_qvel_dict = {} + + # Get the current joint velocities for the main articulation + qvel = self.articulation.get_current_qvel() + # Store the velocity of the main articulation in the dictionary using its unique ID + current_qvel_dict[self.uid] = qvel + + # Iterate over child articulations to get their velocities + for key in self.child_articulations: + # Retrieve and store the joint velocity for the child articulation in the dictionary + current_qvel_dict[key] = self._get_articulation_property( + key, "get_current_qvel" + ) + + # Return the dictionary containing velocities for all articulations + return current_qvel_dict + else: + return self._get_articulation_property(name, "get_current_qvel") + + def set_current_qf( + self, + name: str = None, + qf: Union[np.ndarray, Dict[str, np.ndarray]] = None, + ): + r"""Set current robot joint force. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + qf (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot current joint force. + + """ + if qf is None: + logger.log_warning("joint_force is None, no action taken.") + return False + + if name is None: + if isinstance(qf, dict): + failed_cases = [] + for articulation_name, force in qf.items(): + if not self._set_articulation_property( + articulation_name, "set_current_qf", force + ): + failed_cases.append(articulation_name) + if failed_cases: + logger.log_warning( + f"Failed to set joint force for articulations: {failed_cases}" + ) + return False + return True + else: + logger.log_warning( + f"Expected joint_force to be dict for articulations, got {type(qf)}." + ) + return False + else: + if isinstance(qf, (list, np.ndarray)): + return self._set_articulation_property(name, "set_current_qf", qf) + else: + logger.log_warning( + f"Expected joint_force to be np.ndarray for articulation '{name}', got {type(qf)}." + ) + return False + + def get_current_qf( + self, name: str = None + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + r"""Get current robot joint force. + + Args: + name (str, optional): Name of the articulation. Defaults to None. + + Returns: + np.ndarray: Joint force for the specified articulation if `name` is provided and found. + Dict[str, np.ndarray]: Joint force for all articulations if `name` is None. + None: If `name` is provided but not found. + """ + if name is None: + # Initialize a dictionary to hold joint forces for all articulations + current_qf_dict = {} + + # Get the current joint forces for the main articulation + qvel = self.articulation.get_current_qvel() + # Store the velocity of the main articulation in the dictionary using its unique ID + current_qf_dict[self.uid] = qvel + + # Iterate over child articulations to get their forces + for key in self.child_articulations: + # Retrieve and store the joint velocity for the child articulation in the dictionary + current_qf_dict[key] = self._get_articulation_property( + key, "get_current_qvel" + ) + + # Return the dictionary containing forces for all articulations + return current_qf_dict + else: + return self._get_articulation_property(name, "get_current_qf") + + @staticmethod + def is_approx(qpos1: np.ndarray, qpos2: np.ndarray, eps: float = 1e-5): + r"""Evaluate whether qpos1 and qpos2 are 'close'. + + Args: + qpos1 (np.ndarray): a object of joint + qpos2 (np.ndarray): a object of other joint + + Returns: + bool: is close + """ + qpos1 = np.array(qpos1) + qpos2 = np.array(qpos2) + if qpos1.shape != qpos2.shape: + logger.log_warning( + "qpos1 shape {} does not match qpos2 shape {}, qpos1: {}, qpos2: {}.".format( + qpos1.shape, qpos2.shape, qpos1, qpos2 + ) + ) + return False + + dis = np.linalg.norm(qpos1 - qpos2, ord=1) + return dis < eps + + def create_physical_visible_node( + self, name: str, rgba: np.array = None, link_name: str = None + ) -> bool: + r"""Create a physical visible node for the articulation. + + Args: + name (str): + The name/identifier of the articulation to create the visible node for. + Must match either the main articulation's UID or a child articulation's name. + + rgba (np.ndarray, optional): + An array of 4 float values representing the RGBA color values: + - Red component (0.0 to 1.0) + - Green component (0.0 to 1.0) + - Blue component (0.0 to 1.0) + - Alpha/transparency (0.0 to 1.0) + Defaults to [0.0, 1.0, 0.0, 0.6] (semi-transparent green). + + link_name (str, optional): + The specific link name of the articulation to create the visible node for. + If None, visible nodes will be created for all links of the articulation. + Defaults to None. + + Returns: + bool: + True if the visible node was successfully created. + False if: + - The articulation name was not found + - The link name was invalid + - The creation process failed + """ + if rgba is None: + rgba = np.array([0.0, 1.0, 0.0, 0.6]) + else: + rgba = np.array(rgba) + + assert rgba.shape == (4,), "RGBA array must have 4 elements." + + # Prepare parameters for the node creation + params = {"rgba": rgba} + + # Add link_name to parameters if provided + if link_name is not None: + params["link_name"] = link_name + + # Check if the name matches the uid and create the node + if name == self.uid: + return self.articulation.create_physical_visible_node(**params) + elif name in self.child_articulations: + # Otherwise, create the node for the specified child articulation + return self.child_articulations[name].create_physical_visible_node(**params) + + logger.log_warning(f"Articulation '{name}' not found.") + return False + + def set_physical_visible( + self, + name: str, + is_physic_visible: bool, + is_render_body_visible: bool = True, + link_name: str = None, + ) -> bool: + r"""Set whether the current physical collision is visible. + + Args: + name (str): The name of the articulation. + is_physic_visible (bool): Whether the current physical node is visible. + is_render_body_visible (bool, optional): Whether the render body is visible. Defaults to True. + link_name (str, optional): The link name of the articulation. If None, set all articulation visible. Defaults to None. + + Returns: + bool: Returns True if the setting is successful, False otherwise. + """ + # Prepare parameters for setting visibility + params = { + "is_physic_visible": is_physic_visible, + "is_render_body_visible": is_render_body_visible, + } + + # Add link_name to parameters if provided + if link_name is not None: + params["link_name"] = link_name + + # Check if the name matches the uid and set visibility + if name == self.uid: + self.articulation.set_physical_visible(**params) + return True + + # Check if the name is in child articulations and set visibility for it + elif name in self.child_articulations: + self.child_articulations[name].set_physical_visible(**params) + return True + + # Log a warning if the articulation name is not found + logger.log_warning(f"Articulation '{name}' not found.") + return False From 50ef839dae7b708d74c7e48933526069d8f773ce Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:33:21 +0800 Subject: [PATCH 38/49] Remove correction functions of agent --- embodichain/agents/hierarchy/code_agent.py | 131 --------------------- embodichain/agents/hierarchy/task_agent.py | 38 ------ 2 files changed, 169 deletions(-) diff --git a/embodichain/agents/hierarchy/code_agent.py b/embodichain/agents/hierarchy/code_agent.py index 1aa5a1c..b528ba1 100644 --- a/embodichain/agents/hierarchy/code_agent.py +++ b/embodichain/agents/hierarchy/code_agent.py @@ -263,134 +263,3 @@ def visit_Call(self, node): print("Code executed successfully, but no actions were collected.") return [] - - def build_feedback_message( - self, last_code: str, last_error: str, last_observation: str = None - ) -> HumanMessage: - - useful_info = ( - "The error may be caused by:\n" - "1. You did not follow the basic background information, especially the world coordinate system with its xyz directions.\n" - "2. You did not take into account the NOTE given in the atomic actions or in the example functions.\n" - "3. You did not follow the steps of the task descriptions.\n" - ) - - # Optional observation section - observation_text = "" - if last_observation is not None: - observation_text = ( - "\nThe visual observation feedback of the execution process was:\n" - "```\n" + str(last_observation) + "\n```\n" - ) - - return HumanMessage( - content=( - "Your previously generated code was:\n" - "```\n" + last_code + "\n```\n\n" - "When this code was executed in the test environment, it failed with the following error:\n" - "```\n" - + last_error - + "```\n" - + observation_text - + "\n" - + useful_info - + "\nAnalyze the cause of the failure and produce a corrected version of the code. " - "Modify only what is necessary to fix the issue. The corrected code must:\n" - " - strictly use only the allowed atomic API functions,\n" - " - be executable and unambiguous,\n" - " - directly resolve the error shown above.\n\n" - "Your entire response must be EXACTLY one Python code block:\n" - "```python\n" - "# corrected solution code\n" - "```\n" - ) - ) - - def generate_according_to_task_plan(self, task_plan, **kwargs): - # Generate code via LLM - prompt = getattr(CodePrompt, self.prompt_name)(task_plan=task_plan, **kwargs) - - llm_code = self.llm.invoke(prompt) - llm_code = getattr(llm_code, "content", str(llm_code)) - - match = re.search(r"```python\n(.*?)\n```", llm_code, re.DOTALL) - if match: - llm_code = match.group(1).strip() - else: - llm_code = llm_code.strip() - - print(f"\033[92m\nCode agent output:\n{llm_code}\n\033[0m") - - return kwargs, llm_code - - def act_single_action(self, code: str, **kwargs): - import ast - - # ---- 0. Build execution namespace ---- - ns = { - "__builtins__": __builtins__, - "kwargs": kwargs, # visible for **kwargs injection - } - - # ---- 1. Executor-controlled import ---- - try: - exec( - "from embodichain.toolkits.interfaces import *", - ns, - ns, - ) - except Exception as e: - raise RuntimeError( - "Failed to import embodichain.toolkits.interfaces in act_single_action" - ) from e - - # ---- 2. Parse generated code ---- - tree = ast.parse(code) - body = tree.body - - # ---------- AST transformer: inject **kwargs everywhere ---------- - class InjectKwargs(ast.NodeTransformer): - def visit_Call(self, node): - self.generic_visit(node) - - # Check if **kwargs already exists - has_kwargs = any( - kw.arg is None - and isinstance(kw.value, ast.Name) - and kw.value.id == "kwargs" - for kw in node.keywords - ) - - if not has_kwargs: - node.keywords.append( - ast.keyword( - arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()) - ) - ) - - return node - - transformer = InjectKwargs() - - # ---- 3. Execute actions step by step ---- - for step_id, node in enumerate(body, start=1): - try: - node = transformer.visit(node) - ast.fix_missing_locations(node) - - step_mod = ast.Module([node], type_ignores=[]) - compiled = compile( - step_mod, filename=f"", mode="exec" - ) - - print( - f"\033[95m\nExecuting the current action {code} with **kwargs\033[0m" - ) - exec(compiled, ns, ns) - - except Exception as e: - raise RuntimeError( - f"Execution failed at step {step_id} with action {code}:\n{e}" - ) - - print("\033[95m\nThe current action step executed successfully.\033[0m") diff --git a/embodichain/agents/hierarchy/task_agent.py b/embodichain/agents/hierarchy/task_agent.py index fd7516c..8e2e81b 100644 --- a/embodichain/agents/hierarchy/task_agent.py +++ b/embodichain/agents/hierarchy/task_agent.py @@ -138,41 +138,3 @@ def generate(self, **kwargs) -> str: def act(self, *args, **kwargs): return super().act(*args, **kwargs) - def build_feedback_message( - self, last_plan: str, last_code: str, last_error: str - ) -> HumanMessage: - return HumanMessage( - content=( - "Your previous plan was:\n" - "```\n" + last_plan + "\n```\n\n" - "This plan led the code agent to generate the following code according to your plan:\n" - "```\n" + last_code + "\n```\n\n" - "When this code was executed in the test environment, it failed with the following error:\n" - "```\n" + last_error + "\n```\n\n" + USEFUL_INFO + "\n" - "Please analyze the failure, revise your plan, and provide sufficient instructions to correct the issue, " - "so that the code agent can generate a correct and executable solution based on your plan. " - "Your updated plan must strictly adhere to the atomic API functions and avoid ambiguous actions." - ) - ) - - def generate_for_correction(self, img_dir, **kwargs): - # Generate task plan via LLM - image_files = glob.glob(os.path.join(img_dir, "obs_step_*.png")) - if len(image_files) < 1: - raise ValueError("Need at least one observation images for validation.") - # sort by step index - image_files_sorted = sorted( - image_files, - key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), - ) - obs_image_path = image_files_sorted[-1] # the current image - prompt = getattr(TaskPrompt, self.prompt_name)( - obs_image_path=obs_image_path, **kwargs - ) - - response = self.llm.invoke(prompt).content - print(f"\033[94m\nTask agent output:\n{response}\n\033[0m") - - task_plan, plan_list, validation_list = extract_plan_and_validation(response) - - return task_plan, plan_list, validation_list From 0fc7d8f01f915cae708a06c075dee96f66744eab Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 18 Jan 2026 20:39:20 +0800 Subject: [PATCH 39/49] Add generated code and plan --- .../PourWaterAgent-v3/agent_generated_code.py | 59 +++++++++++++++++++ .../agent_generated_plan.txt | 27 +++++++++ .../agent_generated_code.py | 53 +++++++++++++++++ .../agent_generated_plan.txt | 19 ++++++ 4 files changed, 158 insertions(+) create mode 100644 embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py create mode 100644 embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py new file mode 100644 index 0000000..3399a23 --- /dev/null +++ b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py @@ -0,0 +1,59 @@ +# Step 1: Grasp the bottle +drive( + right_arm_action=grasp( + robot_name='right_arm', + obj_name='bottle', + pre_grasp_dis=0.10 + ), + left_arm_action=None +) + +# Step 2: Move the bottle to the pouring position relative to the cup +drive( + right_arm_action=move_relative_to_object( + robot_name='right_arm', + obj_name='cup', + x_offset=0.05, + y_offset=-0.10, + z_offset=0.125 + ), + left_arm_action=None +) + +# Step 3: Pour water into the cup +drive( + right_arm_action=rotate_eef( + robot_name='right_arm', + degree=-90 + ), + left_arm_action=None +) + +# Step 4: Return the bottle to its upright position +drive( + right_arm_action=rotate_eef( + robot_name='right_arm', + degree=90 + ), + left_arm_action=None +) + +# Step 5: Place the bottle at the specified location +drive( + right_arm_action=place_on_table( + robot_name='right_arm', + obj_name='bottle', + x=0.7, + y=-0.1, + pre_place_dis=0.08 + ), + left_arm_action=None +) + +# Step 6: Return the right arm to its initial pose +drive( + right_arm_action=back_to_initial_pose( + robot_name='right_arm' + ), + left_arm_action=None +) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt new file mode 100644 index 0000000..1eeaef8 --- /dev/null +++ b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt @@ -0,0 +1,27 @@ +**[PLANS]:** + +Step 1: Grasp the bottle — `grasp(robot_name='right_arm', obj_name='bottle', pre_grasp_dis=0.10)` + +Step 2: Move the bottle to the pouring position relative to the cup — `move_relative_to_object(robot_name='right_arm', obj_name='cup', x_offset=0.05, y_offset=-0.10, z_offset=0.125)` + +Step 3: Pour water into the cup — `rotate_eef(robot_name='right_arm', degree=-90)` + +Step 4: Return the bottle to its upright position — `rotate_eef(robot_name='right_arm', degree=90)` + +Step 5: Place the bottle at the specified location — `place_on_table(robot_name='right_arm', obj_name='bottle', x=0.7, y=-0.1, pre_place_dis=0.08)` + +Step 6: Return the right arm to its initial pose — `back_to_initial_pose(robot_name='right_arm')` + +**[VALIDATION_CONDITIONS]:** + +Step 1: The right arm should be holding the bottle securely. + +Step 2: The bottle should be positioned at an offset of [0.05, -0.10, 0.125] relative to the cup. + +Step 3: The bottle should be tilted, pouring water into the cup. + +Step 4: The bottle should be returned to an upright position, held by the right arm. + +Step 5: The bottle should be placed at the location [0.7, -0.1] on the table, and the right arm should release it. + +Step 6: The right arm should be in its initial pose, not holding any object. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py new file mode 100644 index 0000000..c8599a4 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py @@ -0,0 +1,53 @@ +# Step 1: Grasp the fork with the left arm and the spoon with the right arm +drive( + left_arm_action=grasp( + robot_name='left_arm', + obj_name='fork', + pre_grasp_dis=0.10 + ), + right_arm_action=grasp( + robot_name='right_arm', + obj_name='spoon', + pre_grasp_dis=0.10 + ) +) + +# Step 2: Reorient both end-effectors to a downward-facing pose +drive( + left_arm_action=orient_eef( + robot_name='left_arm', + direction='down' + ), + right_arm_action=orient_eef( + robot_name='right_arm', + direction='down' + ) +) + +# Step 3: Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center +drive( + left_arm_action=place_on_table( + robot_name='left_arm', + obj_name='fork', + x=0.0, + y=0.16, + pre_place_dis=0.08 + ), + right_arm_action=place_on_table( + robot_name='right_arm', + obj_name='spoon', + x=0.0, + y=-0.16, + pre_place_dis=0.08 + ) +) + +# Step 4: Return both arms to their initial poses +drive( + left_arm_action=back_to_initial_pose( + robot_name='left_arm' + ), + right_arm_action=back_to_initial_pose( + robot_name='right_arm' + ) +) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt new file mode 100644 index 0000000..d5ef2e7 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt @@ -0,0 +1,19 @@ +**[PLANS]:** + +Step 1: Grasp the fork with the left arm and the spoon with the right arm — `drive(left_arm_action=grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10), right_arm_action=grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10))` + +Step 2: Reorient both end-effectors to a downward-facing pose — `drive(left_arm_action=orient_eef(robot_name='left_arm', direction='down'), right_arm_action=orient_eef(robot_name='right_arm', direction='down'))` + +Step 3: Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center — `drive(left_arm_action=place_on_table(robot_name='left_arm', obj_name='fork', x=0.0, y=0.16, pre_place_dis=0.08), right_arm_action=place_on_table(robot_name='right_arm', obj_name='spoon', x=0.0, y=-0.16, pre_place_dis=0.08))` + +Step 4: Return both arms to their initial poses — `drive(left_arm_action=back_to_initial_pose(robot_name='left_arm'), right_arm_action=back_to_initial_pose(robot_name='right_arm'))` + +**[VALIDATION_CONDITIONS]:** + +Step 1: The left arm should be holding the fork, and the right arm should be holding the spoon. + +Step 2: Both end-effectors should be facing downward, with the fork and spoon still held. + +Step 3: The fork should be placed at y = +0.16, and the spoon should be placed at y = −0.16 relative to the plate’s center. Both arms should have released their objects. + +Step 4: Both arms should be in their initial poses, with no objects held. \ No newline at end of file From f9bbc1498b2ee27e5c74ba2fcd734909f677b908 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 19 Jan 2026 09:55:22 +0800 Subject: [PATCH 40/49] Remove the function of correction --- .../pour_water_agent_v3/agent_config.json | 2 +- .../agent_config_dual.json | 2 +- .../rearrangement_agent_v3/agent_config.json | 2 +- .../agents/hierarchy/validation_agent.py | 111 ------------------ embodichain/agents/mllm/prompt/task_prompt.py | 98 +--------------- .../envs/tasks/tableware/base_agent_env.py | 23 ---- 6 files changed, 7 insertions(+), 231 deletions(-) diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config.json b/configs/gym/agent/pour_water_agent_v3/agent_config.json index 3172f3c..db38818 100644 --- a/configs/gym/agent/pour_water_agent_v3/agent_config.json +++ b/configs/gym/agent/pour_water_agent_v3/agent_config.json @@ -1,5 +1,5 @@ { "TaskAgent": { - "prompt_name": "one_stage_prompt_for_correction" + "prompt_name": "one_stage_prompt" }, "CodeAgent": { "prompt_name": "one_stage_prompt_according_to_task_plan" diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json index 53e707c..415b7ac 100644 --- a/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json +++ b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json @@ -1,5 +1,5 @@ { "TaskAgent": { - "prompt_name": "one_stage_prompt_for_correction" + "prompt_name": "one_stage_prompt" }, "CodeAgent": { "prompt_name": "one_stage_prompt_according_to_task_plan" diff --git a/configs/gym/agent/rearrangement_agent_v3/agent_config.json b/configs/gym/agent/rearrangement_agent_v3/agent_config.json index 5394169..1907e98 100644 --- a/configs/gym/agent/rearrangement_agent_v3/agent_config.json +++ b/configs/gym/agent/rearrangement_agent_v3/agent_config.json @@ -1,5 +1,5 @@ { "TaskAgent": { - "prompt_name": "one_stage_prompt_for_correction" + "prompt_name": "one_stage_prompt" }, "CodeAgent": { "prompt_name": "one_stage_prompt_according_to_task_plan" diff --git a/embodichain/agents/hierarchy/validation_agent.py b/embodichain/agents/hierarchy/validation_agent.py index 7ad33a5..4cb1b99 100644 --- a/embodichain/agents/hierarchy/validation_agent.py +++ b/embodichain/agents/hierarchy/validation_agent.py @@ -217,114 +217,3 @@ def select_best_view_dir( raise ValueError(f"Invalid camera selection from LLM: {response}") return response - - def validate_single_action( - self, - current_action, - action_description, - valid_condition, - img_dir, - obj_position_info, - ): - # --- camera directories --- - img_dirs = { - "cam_1": img_dir / "cam_1", - "cam_2": img_dir / "cam_2", - "cam_3": img_dir / "cam_3", - } - - # === Stage 1: select best view === - selected_cam = self.select_best_view_dir( - img_dirs, action_description, valid_condition - ) - selected_dir = img_dirs[selected_cam] - print(f"\033[38;5;214mSelected camera for validation: {selected_cam}\033[0m") - - # === Stage 2: load FULL sequence from selected view === - image_files = glob.glob(os.path.join(selected_dir, "obs_step_*.png")) - if len(image_files) < 2: - raise ValueError("Need at least two observation images for validation.") - - # Sort images by step index - image_files_sorted = sorted( - image_files, - key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), - ) - - # Encode ALL images in sequence - encoded_images = [encode_image_from_path(p) for p in image_files_sorted] - - system_prompt = ( - "You are a helpful robot manipulation ACTION VALIDATOR.\n\n" - "ROLE:\n" - "- Judge ONLY the OBJECT-LEVEL outcome of ONE atomic action.\n" - "- Do NOT judge robot motion, planning, or execution quality.\n\n" - "CORE ASSUMPTIONS:\n" - "- The robot arm motion itself is correct by definition.\n" - "- Any failure must be due to incorrect OBJECT interaction or state.\n\n" - "COORDINATE RULE:\n" - "- The image is horizontally mirrored: image left ↔ robot right, image right ↔ robot left.\n" - "- Vertical direction is preserved.\n" - "- Use robot base frame terminology in your final judgment.\n\n" - "EVALUATION RULES:\n" - "- Focus on the FINAL image.\n" - "- Earlier images are context only.\n" - "- Do NOT infer numeric precision or motion quality.\n" - "- Ignore minor offsets or simulation noise.\n\n" - "DECISION POLICY:\n" - "- If visual evidence contradicts the expected object state → FAILURE.\n" - "- If visual evidence clearly matches the expected object state → SUCCESS.\n" - ) - - prompt = f""" - Validate the result of ONE atomic robot action. - - -------------------------------------------------- - ACTION - -------------------------------------------------- - {action_description} - - -------------------------------------------------- - EXPECTED OBJECT-LEVEL OUTCOME - -------------------------------------------------- - {valid_condition} - - -------------------------------------------------- - INPUT - -------------------------------------------------- - You are given an ordered image sequence. - - The FINAL image shows the state AFTER the action. - - -------------------------------------------------- - OUTPUT FORMAT (STRICT) - -------------------------------------------------- - Output EXACTLY one of the following. - - IMPORTANT: - - You MUST explicitly state the correctness of BOTH arms in the Evidence. - - [ACTION_SUCCESS] - - Evidence: - - [ACTION_FAILED] - - Reason: - - Evidence: - """ - - # Build multimodal message with ALL images - human_content = [{"type": "text", "text": prompt}] - for img_b64 in encoded_images: - human_content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{img_b64}"}, - } - ) - - messages = [ - SystemMessage(content=system_prompt), - HumanMessage(content=human_content), - ] - - llm_response = self.llm.invoke(messages) - return llm_response.content diff --git a/embodichain/agents/mllm/prompt/task_prompt.py b/embodichain/agents/mllm/prompt/task_prompt.py index 3ba547f..de85b8d 100644 --- a/embodichain/agents/mllm/prompt/task_prompt.py +++ b/embodichain/agents/mllm/prompt/task_prompt.py @@ -3,12 +3,13 @@ # # All rights reserved. # ---------------------------------------------------------------------------- +import torch +import numpy as np from langchain_core.messages import SystemMessage from langchain_core.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) -from embodichain.utils.utility import encode_image from embodichain.utils.utility import encode_image, encode_image_from_path @@ -21,7 +22,7 @@ def one_stage_prompt(observations, **kwargs): Step 2: LLM generates task instructions using only those IDs. """ # Encode image - observation = observations["rgb"] + observation = observations["rgb"].cpu().numpy() if isinstance(observations["rgb"], torch.Tensor) else observations["rgb"] kwargs.update({"observation": encode_image(observation)}) # Build hybrid prompt @@ -88,7 +89,7 @@ def two_stage_prompt(observations, **kwargs): ] ) - observation = observations["rgb"] + observation = observations["rgb"].cpu().numpy() if isinstance(observations["rgb"], torch.Tensor) else observations["rgb"] kwargs.update({"observation": encode_image(observation)}) # for LLM generate task descriptions prompt_query = ChatPromptTemplate.from_messages( @@ -123,94 +124,3 @@ def two_stage_prompt(observations, **kwargs): ) return [prompt.invoke(kwargs), {"prompt": prompt_query, "kwargs": kwargs}] - - @staticmethod - def one_stage_prompt_for_correction(obs_image_path, **kwargs): - """ - Hybrid one-pass prompt: - Step 1: VLM analyzes the image and extracts object IDs. - Step 2: LLM generates task instructions using only those IDs. - """ - # Encode image - kwargs.update({"observation": encode_image_from_path(obs_image_path)}) - - # Build hybrid prompt - prompt = ChatPromptTemplate.from_messages( - [ - SystemMessage( - content=( - "You are a robotic manipulation planner operating STRICTLY in the robot base coordinate frame.\n\n" - "COORDINATE FRAME RULE (NON-NEGOTIABLE):\n" - "- ALL spatial reasoning and motion descriptions (left/right/front/back/up/down, offsets, rotations)\n" - " are defined ONLY in the robot base coordinate frame, oriented from the base looking outward along +x (toward the end-effector).\n" - "- The camera is positioned in front of the robot, facing the arm and looking toward the robot base.\n" - "- Due to this viewpoint, the rendered image is HORIZONTALLY MIRRORED relative to the robot base frame.\n" - "- LEFT–RIGHT in the image MUST be inverted when reasoning:\n" - " * Image left → Robot right\n" - " * Image right → Robot left\n" - "- Vertical orientation is preserved:\n" - " * Image up → Robot up\n" - " * Image down → Robot down\n" - "- Always reason as if you are physically located at the robot base, facing along +x.\n" - "- For your output, you must use the robot base frame and explicitly account for this horizontal mirroring when interpreting the image " - "(e.g., What appears as “left” in the image corresponds to “right” in the robot base frame, and vice versa. " - "Vertical orientation is preserved: what appears as “up” in the image is also “up” in the robot base frame.).\n\n" - "HARD CONSTRAINT:\n" - "- Any reasoning based on image left/right, visual perspective, or camera orientation is VALID.\n" - "- If a direction cannot be inferred from the robot base frame, you must state it explicitly." - "- Each arm may execute at most one atomic action per step. If multiple atomic actions are required, " - "they must be distributed across multiple steps.\n" - "- Both arms may operate in the same step, but each arm may execute at most ONE atomic action per step. " - "If only one arm needs to act (e.g., a single-arm step or recovery), the other arm should remain idle.\n\n" - "TASK:\n" - "- Given the observation and task, produce a step-by-step plan using ONLY the provided atomic API.\n" - "- The plan must be executable without ambiguity.\n\n" - ) - ), - HumanMessagePromptTemplate.from_template( - [ - { - "type": "image_url", - "image_url": { - "url": "data:image/png;base64,{observation}", - }, - }, - { - "type": "text", - "text": ( - "Here is the latest camera observation.\n" - "IMPORTANT: The current image may NOT represent the initial state of the task. " - "It may correspond to an intermediate step where some actions have already been executed.\n\n" - "First, analyze the scene in the image to infer the current state.\n" - "Then, using the context below, produce the remaining actionable task plan from this state onward.\n\n" - "**Environment background:** \n" - "{basic_background}\n\n" - '**Task goal:** \n"' - '{task_prompt}"\n\n' - "**Available atomic actions:** \n" - "{atom_actions}\n" - "**Failed Task Plan (Reference)::**\n" - "{last_task_plan}\n\n" - "**Executed history (reference only):**\n" - "{last_executed_history}\n\n" - "**Most recent failure (CRITICAL):**\n" - "{last_executed_failure}\n\n" - "**REQUIRED OUTPUT**\n" - "[PLANS]:\n" - "Step 1: (...)\n" - "..." - "Step N: (...)\n\n" - "[VALIDATION_CONDITIONS]:\n" - "Step 1: \n" - "..." - "Step N: \n\n" - "VALIDATION_CONDITIONS MUST include the robot arm and relevant object(s), and whether the object(s) should be held or not.\n" - "Produce the COMPLETE remaining task plan." - ), - }, - ] - ), - ] - ) - - return prompt.invoke(kwargs) diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index 7aa1887..e901f50 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -185,29 +185,6 @@ def generate_code_for_actions(self, regenerate=False, **kwargs): # Task planning print(f"\033[92m\nStart task planning.\n\033[0m") - # Handle one_stage_prompt_for_correction which needs obs_image_path - if self.task_agent.prompt_name == 'one_stage_prompt_for_correction': - kwargs.setdefault("last_task_plan", "None.") - kwargs.setdefault("last_executed_failure", "None.") - kwargs.setdefault("last_executed_history", "None.") - - temp_img_dir = Path(tempfile.mkdtemp()) / "obs_images" - temp_img_dir.mkdir(parents=True, exist_ok=True) - - # Convert torch tensor to numpy array if needed - obs_image = self.get_obs_for_agent()["valid_rgb_1"] - if isinstance(obs_image, torch.Tensor): - obs_image = obs_image.cpu().numpy() - if obs_image.dtype in [np.float32, np.float64]: - obs_image = (obs_image * 255).astype(np.uint8) - - obs_image_path = save_obs_image( - obs_image=obs_image, - save_dir=temp_img_dir, - step_id=0 - ) - kwargs['obs_image_path'] = str(obs_image_path) - task_agent_input = self.task_agent.get_composed_observations( env=self, regenerate=regenerate, **kwargs ) From 11fc81838a47a86d2945a027462426c76862089b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 20 Jan 2026 10:32:50 +0800 Subject: [PATCH 41/49] Add script of running feedback agent --- embodichain/lab/scripts/run_agent_feedback.py | 340 +++++++++++++++ .../lab/scripts/run_agent_visual_feedback.py | 395 ++++++++++++++++++ 2 files changed, 735 insertions(+) create mode 100644 embodichain/lab/scripts/run_agent_feedback.py create mode 100644 embodichain/lab/scripts/run_agent_visual_feedback.py diff --git a/embodichain/lab/scripts/run_agent_feedback.py b/embodichain/lab/scripts/run_agent_feedback.py new file mode 100644 index 0000000..ba0f5a2 --- /dev/null +++ b/embodichain/lab/scripts/run_agent_feedback.py @@ -0,0 +1,340 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import gymnasium +import numpy as np +import argparse +import os +import torch +import json + +from threading import Thread +from tqdm import tqdm +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.lab.scripts.generate_video import visualize_data_dict +from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, +) +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import traceback + +def test_code(env, code_file_path, check_num=10, kwargs=None): + """Test the generated code multiple times and evaluate task success rate. + + Uses env.code_agent.act() to execute the code, which handles all the + necessary imports and execution logic. + """ + # ====== Read code content for display ====== + with open(code_file_path, "r", encoding="utf-8") as f: + code_content = f.read() + + # ====== Initialize kwargs ====== + if kwargs is None: + kwargs = {} + if "env" not in kwargs: + kwargs["env"] = env + + # ====== Initialize counters ====== + epid, suc_num, fail_num = 0, 0, 0 + run_records = [] + + # Error categories (same style as previous run() function) + error_list = [ + "Code can not run", # 0 + "Task executed but failed", # 1 + "No error occurred" # 2 + ] + error_num = [0, 0, 0] + + print("\033[93m" + "[Start Testing Task Success Rate]" + "\033[0m") + + # ====== Print generated source ====== + print("\n\033[92m=== generated source code ===\033[0m") + print(code_content) + print("\033[92m=== End ===\033[0m\n") + + # ====== Main loop ====== + for epid in range(check_num): + env.reset() + kwargs['current_check_num'] = epid + error_id = None + + try: + # Use code_agent.act() to execute the code + # This method handles all imports and execution logic + env.get_wrapper_attr("code_agent").act(code_file_path, **kwargs) + + # Check result + if env.get_wrapper_attr("is_task_success")().item(): + print(f"simulate data episode {suc_num} success! (seed = {epid})") + suc_num += 1 + run_records.append("Success!") + else: + print(f"simulate data episode {suc_num} fail! (seed = {epid})") + fail_num += 1 + error_id = 1 + run_records.append(error_list[1]) + + except Exception as e: + # Execution error + exec_trace = traceback.format_exc() + error_list[0] = exec_trace # store full traceback for summary + error_id = 0 + fail_num += 1 + + run_records.append(f"Code can not run, error: {exec_trace}") + + print("-------------") + print(f"simulate data episode {suc_num} fail! (seed = {epid})") + print("Error:", exec_trace) + print("-------------") + + # Count error category + if error_id is not None: + error_num[error_id] += 1 + + # ====== Find most frequent error ====== + if sum(error_num) == 0: + max_error_index = 2 # no errors, fallback to "NO error" + max_error_count = 0 + else: + max_error_index = error_num.index(max(error_num)) + max_error_count = error_num[max_error_index] + + # ====== Summary ====== + print(f'\nComplete test, success rate: {suc_num}/{check_num}') + print(f'Error message list: {error_list}') + print(f'Error count: {error_num}') + print(f'Run records: {run_records}') + + return suc_num / check_num, error_list[max_error_index], max_error_count, run_records + + +def generate_function( + env, + generated_codes, + error_messages, + log_dir=None, +): + # Initialize env + env.reset() + + # First attempt case - create initial code file + if len(error_messages) == 0: + code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")(regenerate=True, log_dir=log_dir) + # Generate code based on error status + else: + code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")( + regenerate=True, log_dir=log_dir, generated_codes=generated_codes, error_messages=error_messages) + + try: + # Update this section to match the new return values of the run function + success_rate, error_message, error_count, run_records = test_code(env, code_file_path, check_num=5, kwargs=kwargs) + generated_codes.append(code) + error_messages.append(error_message) + return code, success_rate, error_message, error_count, run_records + except KeyboardInterrupt: + print("Test interrupted by user") + return code, 0, "Test interrupted by user", 10, None + except Exception as e: + import traceback + error_trace = traceback.format_exc() + print(f"Error occurred during testing: {e}\n{error_trace}") + return code, 0, f"Error occurred during testing: {e}", 10, None + + +def main(args, env, gym_config): + + log_info("Start agent data generation with feedback.", color="green") + + # Initialize variables + generate_num = 5 + success_threshold = 0.6 + suc_list = [] + + # Store each round's code and error + error_messages = [] + generated_codes = [] + + # Store the best code and its success rate + best_code = None + best_success_rate = 0 + best_run_records = None + + # Create log file name with timestamp + import datetime + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + log_dir = Path(database_agent_prompt_dir) / args.task_name / "feedback_logs" / timestamp + os.makedirs(log_dir, exist_ok=True) + log_filename = f"{log_dir}/{args.task_name}.log" + + # Store all attempt records + all_attempts = [] + + # Try multiple generations until success or limit reached + for id in range(generate_num): + log_info(f"Generate code for task: {args.task_name} ({id + 1}/{generate_num})", color='green') + + # Generate and test code + code, success_rate, error_message, error_count, run_records = generate_function( + env, generated_codes, error_messages, log_dir) + + # Track success rates + suc_list.append(success_rate) + + # Record this attempt + attempt_record = { + "attempt_id": id + 1, + "success_rate": success_rate, + "error_message": error_message, + "error_count": error_count, + "code": code, + "run_records": run_records + } + all_attempts.append(attempt_record) + + # Save best code + if success_rate > best_success_rate: + best_success_rate = success_rate + best_code = code + best_run_records = run_records + print(f"New best code found, success rate: {best_success_rate}") + + # Check if generation was successful + if success_rate >= success_threshold: + print(f"Successfully generated code for task: {args.task_name}") + break + + # Handle failure case + log_warning(f"The generated code fail for task: {args.task_name} (attempt {id+1}) with succuss rate {success_rate}\nError message: \n{error_message}") + + # Ensure the final saved code is the best one + if best_code is not None: + file_name = log_dir / "agent_generated_code.py" + print(f"Saving best code, success rate: {best_success_rate}") + with open(file_name, 'w') as file: + file.write(best_code) + + print(f"Best success rate: {best_success_rate}") + print(f"All success rates: {suc_list}") + + # Save log data to file + with open(log_filename, 'w') as log_file: + log_data = { + "task_name": args.task_name, + "best_success_rate": best_success_rate, + "success_rates": suc_list, + "best_code": best_code, + "best_run_records": best_run_records, + "all_attempts": all_attempts + } + json.dump(log_data, log_file, indent=2) + + print(f"Log has been saved to: {log_filename}") + + if args.headless: + env.reset(options={"final": True}) + + +if __name__ == "__main__": + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--render_backend", + help="The rendering backend to use for the simulation.", + default="egl", + type=str, + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--save_video", + help="Whether to save data as video.", + default=False, + action="store_true", + ) + parser.add_argument( + "--save_path", help="path", default="./outputs/thirdviewvideo", type=str + ) + parser.add_argument( + "--debug_mode", + help="Enable debug mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + parser.add_argument("--online_config", type=str, help="online_config", default="") + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument("--task_name", type=str, help="Name of the task.", required=True) + + # Agent related configs + parser.add_argument("--agent_config", type=str, help="agent_config", default=None, required=True) + parser.add_argument("--regenerate", type=bool, help="Whether regenerate code if already existed.", default=False) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + agent_config = load_json(args.agent_config) + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make(id=gym_config["id"], cfg=cfg, agent_config=agent_config, task_name=args.task_name) + main(args, env, gym_config) diff --git a/embodichain/lab/scripts/run_agent_visual_feedback.py b/embodichain/lab/scripts/run_agent_visual_feedback.py new file mode 100644 index 0000000..50be9e7 --- /dev/null +++ b/embodichain/lab/scripts/run_agent_visual_feedback.py @@ -0,0 +1,395 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import gymnasium +import numpy as np +import argparse +import os +import torch +import json + +from threading import Thread +from tqdm import tqdm +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.lab.scripts.generate_video import visualize_data_dict +from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, +) +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import traceback +import glob + +def test_code(env, code_file_path, check_num=10, kwargs=None): + """Test the generated code multiple times and evaluate task success rate. + + Uses env.code_agent.act() to execute the code, which handles all the + necessary imports and execution logic. + """ + # ====== Read code content for display ====== + with open(code_file_path, "r", encoding="utf-8") as f: + code_content = f.read() + + # ====== Initialize kwargs ====== + if kwargs is None: + kwargs = {} + if "env" not in kwargs: + kwargs["env"] = env + + # ====== Initialize counters ====== + epid, suc_num, fail_num = 0, 0, 0 + run_records = [] + + # Error categories (same style as previous run() function) + error_list = [ + "Code can not run", # 0 + "Task executed but failed", # 1 + "No error occurred" # 2 + ] + error_num = [0, 0, 0] + + print("\033[93m" + "[Start Testing Task Success Rate]" + "\033[0m") + + # ====== Print generated source ====== + print("\n\033[92m=== generated source code ===\033[0m") + print(code_content) + print("\033[92m=== End ===\033[0m\n") + + # ====== Main loop ====== + for epid in range(check_num): + env.reset() + kwargs['current_check_num'] = epid + error_id = None + + try: + # Use code_agent.act() to execute the code + # This method handles all imports and execution logic + env.get_wrapper_attr("code_agent").act(code_file_path, **kwargs) + + # Check result + if env.get_wrapper_attr("is_task_success")().item(): + print(f"simulate data episode {epid} success!") + suc_num += 1 + run_records.append("Success!") + else: + print(f"simulate data episode {epid} fail!") + fail_num += 1 + error_id = 1 + run_records.append(error_list[1]) + + except Exception as e: + # Execution error + exec_trace = traceback.format_exc() + error_list[0] = exec_trace # store full traceback for summary + error_id = 0 + fail_num += 1 + + run_records.append(f"Code can not run, error: {exec_trace}") + + print("-------------") + print(f"simulate data episode {epid} fail!") + print("Error:", exec_trace) + print("-------------") + + # Count error category + if error_id is not None: + error_num[error_id] += 1 + + # ====== Find most frequent error ====== + if sum(error_num) == 0: + max_error_index = 2 # no errors, fallback to "NO error" + max_error_count = 0 + else: + max_error_index = error_num.index(max(error_num)) + max_error_count = error_num[max_error_index] + + # ====== Observe at the most frequently occurred error ====== + observation_feedback = None + if max_error_count > 0: + observe_index = 0 + highest_priority = len(error_list) + + for i, record in enumerate(run_records): + if record == "Success!": + continue + + current_priority = len(error_list) + for p, error_pattern in enumerate(error_list): + if error_pattern in record: + current_priority = p + break + + if current_priority < highest_priority: + highest_priority = current_priority + observe_index = i + + if highest_priority == len(error_list) and len(run_records) > 0: + observe_index = 0 + + print(f"Selected observation index observe_index={observe_index}, corresponding error: {run_records[observe_index]}") + + log_dir = kwargs["log_dir"] # require log_dir + gen_id = kwargs.get("id", "unknown") # fallback to a safe string + episode_id = observe_index + save_dir = log_dir / "camera_images" / f"{gen_id}_generate_num" / f"episode{episode_id}" + print(f"Looking for images in: {save_dir}") + + image_files = sorted(glob.glob(os.path.join(save_dir, f"*.png"))) + + # Extract step names from image filenames + step_names = [] + for f in image_files: + filename = os.path.basename(f) + first_underscore_pos = filename.find('_') + if first_underscore_pos != -1: + step_name = filename[first_underscore_pos + 1:].rsplit('.', 1)[0] + step_names.append(step_name) + else: + step_names.append(filename.rsplit('.', 1)[0]) + print(f"Image search pattern: episode{episode_id}_*.png, number of files found: {len(image_files)}") + + observation_feedback = env.get_wrapper_attr("validation_agent").validate(step_names, code_content, error_list[observe_index], image_files) + log_info(f"Observation feedback: {observation_feedback}") + + # ====== Summary ====== + print(f'\nComplete test, success rate: {suc_num}/{check_num}') + print(f'Error message list: {error_list}') + print(f'Error count: {error_num}') + print(f'Run records: {run_records}') + + return suc_num / check_num, error_list[max_error_index], observation_feedback, max_error_count, run_records + + +def generate_function( + env, + generated_codes, + error_messages, + observation_feedbacks, + log_dir=None, + id=0, +): + # Initialize env + env.reset() + + # First attempt case - create initial code file + if len(error_messages) == 0: + code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")(regenerate=True, log_dir=log_dir, id=id) + # Generate code based on error status + else: + code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")( + regenerate=True, log_dir=log_dir, generated_codes=generated_codes, error_messages=error_messages, + observation_feedbacks=observation_feedbacks, id=id) + + try: + # Update this section to match the new return values of the run function + success_rate, error_message, observation_feedback, error_count, run_records = test_code(env, code_file_path, check_num=5, kwargs=kwargs) + generated_codes.append(code) + error_messages.append(error_message) + observation_feedbacks.append(observation_feedback) + return code, success_rate, error_message, observation_feedback, error_count, run_records + except KeyboardInterrupt: + print("Test interrupted by user") + return code, 0, "Test interrupted by user", None, 10, None + except Exception as e: + import traceback + error_trace = traceback.format_exc() + print(f"Error occurred during testing: {e}\n{error_trace}") + return code, 0, f"Error occurred during testing: {e}", None, 10, None + + +def main(args, env, gym_config): + + log_info("Start agent data generation with visual feedback.", color="green") + + # Initialize variables + generate_num = 5 + success_threshold = 0.6 + suc_list = [] + + # Store each round's code and error + generated_codes = [] + error_messages = [] + observation_feedbacks = [] + + # Store the best code and its success rate + best_code = None + best_success_rate = 0 + best_run_records = None + + # Create log file name with timestamp + import datetime + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + log_dir = Path(database_agent_prompt_dir) / args.task_name / "visual_feedback_logs" / timestamp + os.makedirs(log_dir, exist_ok=True) + log_filename = f"{log_dir}/{args.task_name}.log" + + # Store all attempt records + all_attempts = [] + + # Try multiple generations until success or limit reached + for id in range(generate_num): + log_info(f"Generate code for task: {args.task_name} ({id + 1}/{generate_num})", color='green') + + # Generate and test code + code, success_rate, error_message, observation_feedback, error_count, run_records = generate_function( + env, generated_codes, error_messages, observation_feedbacks, log_dir, id=id+1) + + # Track success rates + suc_list.append(success_rate) + + # Record this attempt + attempt_record = { + "attempt_id": id + 1, + "success_rate": success_rate, + "error_message": error_message, + "observation_feedback": observation_feedback, + "error_count": error_count, + "code": code, + "run_records": run_records + } + all_attempts.append(attempt_record) + + # Save best code + if success_rate > best_success_rate: + best_success_rate = success_rate + best_code = code + best_run_records = run_records + print(f"New best code found, success rate: {best_success_rate}") + + # Check if generation was successful + if success_rate >= success_threshold: + print(f"Successfully generated code for task: {args.task_name}") + break + + # Handle failure case + log_warning(f"The generated code fail for task: {args.task_name} (attempt {id+1}) with succuss rate {success_rate}\nError message: \n{error_message}") + + # Ensure the final saved code is the best one + if best_code is not None: + file_name = log_dir / "agent_generated_code.py" + print(f"Saving best code, success rate: {best_success_rate}") + with open(file_name, 'w') as file: + file.write(best_code) + + print(f"Best success rate: {best_success_rate}") + print(f"All success rates: {suc_list}") + + # Save log data to file + with open(log_filename, 'w') as log_file: + log_data = { + "task_name": args.task_name, + "best_success_rate": best_success_rate, + "success_rates": suc_list, + "best_code": best_code, + "best_run_records": best_run_records, + "all_attempts": all_attempts + } + json.dump(log_data, log_file, indent=2) + + print(f"Log has been saved to: {log_filename}") + + if args.headless: + env.reset(options={"final": True}) + + +if __name__ == "__main__": + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--render_backend", + help="The rendering backend to use for the simulation.", + default="egl", + type=str, + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--save_video", + help="Whether to save data as video.", + default=False, + action="store_true", + ) + parser.add_argument( + "--save_path", help="path", default="./outputs/thirdviewvideo", type=str + ) + parser.add_argument( + "--debug_mode", + help="Enable debug mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + parser.add_argument("--online_config", type=str, help="online_config", default="") + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument("--task_name", type=str, help="Name of the task.", required=True) + + # Agent related configs + parser.add_argument("--agent_config", type=str, help="agent_config", default=None, required=True) + parser.add_argument("--regenerate", type=bool, help="Whether regenerate code if already existed.", default=False) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + agent_config = load_json(args.agent_config) + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make(id=gym_config["id"], cfg=cfg, agent_config=agent_config, task_name=args.task_name) + main(args, env, gym_config) From 4092d3169cb1e63936629a09682c655c8898f82e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 20 Jan 2026 11:05:56 +0800 Subject: [PATCH 42/49] Keep components in sim the same --- .../20260120_105727/agent_generated_code.py | 62 + .../20260120_105727/agent_generated_plan.txt | 38 + .../20260120_110033/agent_generated_code.py | 58 + .../20260120_110033/agent_generated_plan.txt | 49 + .../20260120_110212/agent_generated_code.py | 53 + .../20260120_110212/agent_generated_plan.txt | 47 + embodichain/lab/gym/robots/interface.py | 2 +- embodichain/lab/sim/articulation_entity.py | 1391 -------------- embodichain/lab/sim/end_effector/__init__.py | 9 - .../lab/sim/end_effector/end_effector.py | 552 ------ embodichain/lab/sim/end_effector/utility.py | 148 -- embodichain/lab/sim/robots/__init__.py | 1 - embodichain/lab/sim/robots/robot.py | 1177 ------------ embodichain/lab/sim/utility/sim_utils copy.py | 285 --- .../lab/sim/utility/workspace_analyzer_new.py | 1617 ----------------- embodichain/toolkits/interfaces.py | 3 +- 16 files changed, 309 insertions(+), 5183 deletions(-) create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py create mode 100644 embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt delete mode 100644 embodichain/lab/sim/articulation_entity.py delete mode 100644 embodichain/lab/sim/end_effector/__init__.py delete mode 100644 embodichain/lab/sim/end_effector/end_effector.py delete mode 100644 embodichain/lab/sim/end_effector/utility.py delete mode 100644 embodichain/lab/sim/robots/robot.py delete mode 100644 embodichain/lab/sim/utility/sim_utils copy.py delete mode 100644 embodichain/lab/sim/utility/workspace_analyzer_new.py diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py new file mode 100644 index 0000000..c4fd2fd --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py @@ -0,0 +1,62 @@ +# Step 1 — Grasp the Fork and Spoon Simultaneously +drive( + left_arm_action=grasp( + robot_name='left_arm', + obj_name='fork', + pre_grasp_dis=0.10 + ), + right_arm_action=grasp( + robot_name='right_arm', + obj_name='spoon', + pre_grasp_dis=0.10 + ) +) + +# Step 2 — Reorient End-Effectors to Downward-Facing Pose +drive( + left_arm_action=orient_eef( + robot_name='left_arm', + direction='down' + ), + right_arm_action=orient_eef( + robot_name='right_arm', + direction='down' + ) +) + +# Step 3 — Place the Fork and Spoon on Opposite Sides of the Plate +drive( + left_arm_action=move_relative_to_object( + robot_name='left_arm', + obj_name='plate', + x_offset=0, + y_offset=0.16, + z_offset=0 + ), + right_arm_action=move_relative_to_object( + robot_name='right_arm', + obj_name='plate', + x_offset=0, + y_offset=-0.16, + z_offset=0 + ) +) + +drive( + left_arm_action=open_gripper( + robot_name='left_arm' + ), + right_arm_action=open_gripper( + robot_name='right_arm' + ) +) + +# Step 4 — Return Both Arms to Initial Pose +drive( + left_arm_action=back_to_initial_pose( + robot_name='left_arm' + ), + right_arm_action=back_to_initial_pose( + robot_name='right_arm' + ) +) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt new file mode 100644 index 0000000..a53d6c6 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt @@ -0,0 +1,38 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned on either side of the table, ready to interact with the objects. + +### Task Plan + +1. **Grasp the Fork and Spoon Simultaneously** + - **Left Arm (Fork):** + - Use `grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10)` to grasp the fork. + - **Right Arm (Spoon):** + - Use `grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10)` to grasp the spoon. + - Execute: `drive(left_arm_action=grasp_left, right_arm_action=grasp_right)` + +2. **Reorient End-Effectors to Downward-Facing Pose** + - **Left Arm:** + - Use `orient_eef(robot_name='left_arm', direction='down')` to reorient the left arm. + - **Right Arm:** + - Use `orient_eef(robot_name='right_arm', direction='down')` to reorient the right arm. + - Execute: `drive(left_arm_action=orient_left, right_arm_action=orient_right)` + +3. **Place the Fork and Spoon on Opposite Sides of the Plate** + - **Left Arm (Fork):** + - Use `move_relative_to_object(robot_name='left_arm', obj_name='plate', x_offset=0, y_offset=0.16, z_offset=0)` to position the fork. + - Use `open_gripper(robot_name='left_arm')` to release the fork. + - **Right Arm (Spoon):** + - Use `move_relative_to_object(robot_name='right_arm', obj_name='plate', x_offset=0, y_offset=-0.16, z_offset=0)` to position the spoon. + - Use `open_gripper(robot_name='right_arm')` to release the spoon. + - Execute: `drive(left_arm_action=move_left, right_arm_action=move_right)` + - Execute: `drive(left_arm_action=open_left, right_arm_action=open_right)` + +4. **Return Both Arms to Initial Pose** + - **Left Arm:** + - Use `back_to_initial_pose(robot_name='left_arm')` to return the left arm. + - **Right Arm:** + - Use `back_to_initial_pose(robot_name='right_arm')` to return the right arm. + - Execute: `drive(left_arm_action=back_left, right_arm_action=back_right)` + +This plan ensures the fork and spoon are rearranged on opposite sides of the plate, with the robot arms returning to their initial positions at the end. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py new file mode 100644 index 0000000..c099786 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py @@ -0,0 +1,58 @@ +# Step 1 — Grasp the fork and spoon +drive( + left_arm_action=grasp( + robot_name='left_arm', + obj_name='fork', + pre_grasp_dis=0.10 + ), + right_arm_action=grasp( + robot_name='right_arm', + obj_name='spoon', + pre_grasp_dis=0.10 + ) +) + +# Step 2 — Reorient end-effectors to downward-facing pose +drive( + left_arm_action=orient_eef( + robot_name='left_arm', + direction='down' + ), + right_arm_action=orient_eef( + robot_name='right_arm', + direction='down' + ) +) + +# Step 3 — Place fork and spoon on opposite sides of the plate +drive( + left_arm_action=move_relative_to_object( + robot_name='left_arm', + obj_name='plate', + y_offset=0.16 + ), + right_arm_action=move_relative_to_object( + robot_name='right_arm', + obj_name='plate', + y_offset=-0.16 + ) +) + +drive( + left_arm_action=open_gripper( + robot_name='left_arm' + ), + right_arm_action=open_gripper( + robot_name='right_arm' + ) +) + +# Step 4 — Return arms to initial pose +drive( + left_arm_action=back_to_initial_pose( + robot_name='left_arm' + ), + right_arm_action=back_to_initial_pose( + robot_name='right_arm' + ) +) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt new file mode 100644 index 0000000..185df53 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt @@ -0,0 +1,49 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned with the left arm near the fork and the right arm near the spoon. + +### Task Plan + +1. **Grasp the Fork and Spoon:** + - Use the left arm to grasp the fork. + - Use the right arm to grasp the spoon. + + ```python + left_grasp_action = grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10) + right_grasp_action = grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10) + drive(left_arm_action=left_grasp_action, right_arm_action=right_grasp_action) + ``` + +2. **Reorient End-Effectors to Downward-Facing Pose:** + - Reorient both end-effectors to face downward. + + ```python + left_orient_action = orient_eef(robot_name='left_arm', direction='down') + right_orient_action = orient_eef(robot_name='right_arm', direction='down') + drive(left_arm_action=left_orient_action, right_arm_action=right_orient_action) + ``` + +3. **Place Fork and Spoon on Opposite Sides of the Plate:** + - Place the fork at y = +0.16 relative to the plate’s center. + - Place the spoon at y = −0.16 relative to the plate’s center. + + ```python + left_place_action = move_relative_to_object(robot_name='left_arm', obj_name='plate', y_offset=0.16) + right_place_action = move_relative_to_object(robot_name='right_arm', obj_name='plate', y_offset=-0.16) + drive(left_arm_action=left_place_action, right_arm_action=right_place_action) + + left_release_action = open_gripper(robot_name='left_arm') + right_release_action = open_gripper(robot_name='right_arm') + drive(left_arm_action=left_release_action, right_arm_action=right_release_action) + ``` + +4. **Return Arms to Initial Pose:** + - Return both arms to their initial configurations. + + ```python + left_initial_action = back_to_initial_pose(robot_name='left_arm') + right_initial_action = back_to_initial_pose(robot_name='right_arm') + drive(left_arm_action=left_initial_action, right_arm_action=right_initial_action) + ``` + +This plan ensures the fork and spoon are rearranged on opposite sides of the plate, with the fork on the left and the spoon on the right, relative to the robot base frame. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py new file mode 100644 index 0000000..65dc8b5 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py @@ -0,0 +1,53 @@ +# Step 1 — Grasp the Fork and Spoon Simultaneously +drive( + left_arm_action=grasp( + robot_name='left_arm', + obj_name='fork', + pre_grasp_dis=0.10 + ), + right_arm_action=grasp( + robot_name='right_arm', + obj_name='spoon', + pre_grasp_dis=0.10 + ) +) + +# Step 2 — Reorient Both End-Effectors to a Downward-Facing Pose +drive( + left_arm_action=orient_eef( + robot_name='left_arm', + direction='down' + ), + right_arm_action=orient_eef( + robot_name='right_arm', + direction='down' + ) +) + +# Step 3 — Place the Fork and Spoon at Specified Positions +drive( + left_arm_action=place_on_table( + robot_name='left_arm', + obj_name='fork', + x=None, + y=0.16, + pre_place_dis=0.08 + ), + right_arm_action=place_on_table( + robot_name='right_arm', + obj_name='spoon', + x=None, + y=-0.16, + pre_place_dis=0.08 + ) +) + +# Step 4 — Return Both Arms to Initial Pose +drive( + left_arm_action=back_to_initial_pose( + robot_name='left_arm' + ), + right_arm_action=back_to_initial_pose( + robot_name='right_arm' + ) +) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt new file mode 100644 index 0000000..7d5b20e --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt @@ -0,0 +1,47 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned above the table, ready to interact with the objects. + +### Task Plan + +1. **Grasp the Fork and Spoon Simultaneously:** + - Use the left arm to grasp the fork. + - Use the right arm to grasp the spoon. + + ```python + left_grasp_action = grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10) + right_grasp_action = grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10) + drive(left_arm_action=left_grasp_action, right_arm_action=right_grasp_action) + ``` + +2. **Reorient Both End-Effectors to a Downward-Facing Pose:** + - Reorient the left arm's end-effector to face downward. + - Reorient the right arm's end-effector to face downward. + + ```python + left_orient_action = orient_eef(robot_name='left_arm', direction='down') + right_orient_action = orient_eef(robot_name='right_arm', direction='down') + drive(left_arm_action=left_orient_action, right_arm_action=right_orient_action) + ``` + +3. **Place the Fork and Spoon at Specified Positions:** + - Place the fork at y = +0.16 relative to the plate’s center using the left arm. + - Place the spoon at y = −0.16 relative to the plate’s center using the right arm. + + ```python + left_place_action = place_on_table(robot_name='left_arm', obj_name='fork', x=None, y=0.16, pre_place_dis=0.08) + right_place_action = place_on_table(robot_name='right_arm', obj_name='spoon', x=None, y=-0.16, pre_place_dis=0.08) + drive(left_arm_action=left_place_action, right_arm_action=right_place_action) + ``` + +4. **Return Both Arms to Initial Pose:** + - Return the left arm to its initial pose. + - Return the right arm to its initial pose. + + ```python + left_initial_action = back_to_initial_pose(robot_name='left_arm') + right_initial_action = back_to_initial_pose(robot_name='right_arm') + drive(left_arm_action=left_initial_action, right_arm_action=right_initial_action) + ``` + +This plan ensures that both arms perform the required actions simultaneously, maintaining synchronization throughout the task. \ No newline at end of file diff --git a/embodichain/lab/gym/robots/interface.py b/embodichain/lab/gym/robots/interface.py index 5f0a289..c41b1b4 100644 --- a/embodichain/lab/gym/robots/interface.py +++ b/embodichain/lab/gym/robots/interface.py @@ -14,7 +14,7 @@ from gymnasium import spaces from embodichain.data.enum import ControlParts, EndEffector, JointType -from embodichain.lab.sim.robots import Robot +from embodichain.lab.sim.objects import Robot from embodichain.utils import logger from embodichain.data.enum import JointType, EefType, ActionMode diff --git a/embodichain/lab/sim/articulation_entity.py b/embodichain/lab/sim/articulation_entity.py deleted file mode 100644 index 9ce21f5..0000000 --- a/embodichain/lab/sim/articulation_entity.py +++ /dev/null @@ -1,1391 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import numpy as np -import typing -from typing import List, Tuple, Union, Dict, Any -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass, field - -import dexsim -from dexsim.models import Entity -from dexsim.engine import Articulation -from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag -from embodichain.utils import logger - -# Try to import DriveController, but make it optional -try: - from rlia.kit.drive_controllers import DriveController -except ImportError: - # If rlia is not available, use Any as a fallback type - DriveController = Any - -from dexsim.utility import inv_transform -from dexsim.utility.env_utils import load_first_environment - -__all__ = ["ArticulationEntity"] - - -@dataclass -class ArticulationPosition: - r"""Represents the position of an articulation in a robotic system. - - Attributes: - init_qpos (Union[np.ndarray, Dict[str, np.ndarray]]): - The initial joint positions of the articulation, which can be a - NumPy array or a dictionary mapping joint names to their initial - positions. - - init_base_xpos (Union[np.ndarray, Dict[str, np.ndarray]], optional): - The initial base position of the articulation, which can also be a - NumPy array or a dictionary mapping base names to their initial - positions. Defaults to None. - """ - - init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = field(default_factory=dict) - init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None - - -@dataclass -class ArticulationControl: - r"""Controls the behavior of an articulation in a robotic system. - - Attributes: - speed_ratio (float): - The ratio of speed for the articulation control. Default is 0.5. - - time_step (float): - The time step for control updates in seconds. Default is 0.02. - - drive_type (DriveType): - The type of drive used for the articulation control. Default is 'TARGET'. - """ - - speed_ratio: float = 0.5 - time_step: float = 0.02 - drive_type: "DriveType" = "TARGET" - - -@dataclass -class ArticulationJointConfiguration: - link_names: List[str] = field(default_factory=list) - joint_names: List[str] = field(default_factory=list) - - root_link_name: str = field(default_factory=dict) - end_link_name: str = field(default_factory=dict) - - -class ArticulationEntity(metaclass=ABCMeta): - r""" - Abstract class for articulation entity in simulation. - """ - - def __init__( - self, - urdf_path: Union[str, List[str]] = dict(), - init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = dict(), - init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None, - speed_ratio: float = 0.5, - time_step: float = 0.02, - drive_type: DriveType = DriveType.FORCE, - env: dexsim.environment.Arena = None, - **kwargs, - ): - r"""Initialize the articulation entity. - - Args: - urdf_path (str): urdf file path of robot - init_qpos (np.ndarray, optional): [dof] of double. Init robot joint state(home joint state). - init_base_xpos (np.ndarray, optional): [4, 4] of double. Robot base pose in arena coordinate system. - speed_ratio (float, optional): 0 ~ 1. Robot speed ratio. - time_step (float, optional): wait time between two update. Defaults to 1/50. - drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. Defaults to DriveType.FORCE. - env (Arena, optional): dexsim.environment.Arena. Load the first world(None defaults). - kwargs(optional): Accepts additional keyword arguments. - """ - # placeholder for articulations to be created to the robot. - # a robot can have multiple articulations, for example, - # 1. a arm with a gripper (manipulator) - # 2. two arms - # 3. mobile manipulator - self.articulation = None - - ## Additional variable for DualManipulator, Humanoids and DexterousHands: - # Dictionary to map child to its parent articulation "self.articulation" - self.child_articulations: Dict[str, Articulation] = dict() - - # URDF file path(s) for the robot - self.urdf_path = urdf_path - - # initial joint positions of the robot. - self.init_qpos = init_qpos - - # initial base pose of the robot in arena coordinate system. - self.init_base_xpos = init_base_xpos - - # Dictionary to store degrees of freedom for each articulation - self._dof: Dict[str, int] = dict() - - # Dictionary for actual control joint indices of articulations - self._joint_ids: Dict[str, np.ndarray] = dict() - - # self._actived_joint_names = dict() - - # TODO: Maybe turn to dict stored joint pos, vel, acc limits. - # List to store the limits for each joint's motion. - self._joint_limit = [] - - # placeholder for actors to attach to the robot. - self.attached_actors: Dict[str, Entity] = dict() - - # Dictionary to map control group names to their corresponding root link names, - # used for accessing the base position of each control group. - self.root_link_names: Dict[str] = kwargs.get("root_link_names", {}) - - # Dictionary to map control group names to their corresponding end link names, - # used for accessing the terminal position of each control group. - self.end_link_names: Dict[str] = kwargs.get("end_link_names", {}) - - # Speed ratio for the robot's movement - self.speed_ratio = speed_ratio - - # Time step for control updates - self.time_step = time_step - - # Validate and set the drive type - if drive_type not in [DriveType.FORCE, DriveType.FORCE]: - logger.log_error(f"Invalid drive type: {drive_type}.") - self.drive_type = drive_type - - # Dictionary to map child to its parent init_base_xpos "self.init_base_xpos" - self.child_init_base_xpos = dict() - - # Dictionaries for drive and task controllers - self.drive_controllers: Dict[str, DriveController] = dict() - - # Load the first environment if not provided - self._env, self._world = load_first_environment(env) - - def get_articulation(self, uid: str = None) -> dexsim.engine.Articulation: - r"""Get articulation based on its unique identifier (uid). - - This method returns the articulation associated with the provided uid. - If uid is not specified (None), it returns all articulations. If the - uid is invalid, a warning is logged, and None is returned. - - Args: - uid (str, optional): The unique identifier for the articulation. If None, all articulations will be returned. - - Returns: - dexsim.engine.Articulation or Dict: The articulation corresponding to the provided uid, or a dictionary of all articulations if uid is None. Returns None if the uid is invalid. - """ - - if uid is None or uid == self.uid: - return self.articulation - - if uid in self.child_articulations: - return self.child_articulations[uid] - else: - logger.log_warning( - f"Current uid {self.uid} cannot find the corresponding Articulation." - ) - return None - - def _setup_child_articulations(self, uid: str, control_parts: Dict): - r"""Initialize child articulations and establish a mapping between parent and child articulations. - - This method sets up child articulations associated with a parent articulation identified by its UID. - It verifies the existence of the parent articulation before proceeding to initialize the child articulations. - - Args: - uid (str): The unique identifier (UID) of the parent articulation. - control_parts (Dict): A dictionary of control parts to initialize as child articulations. - - Returns: - bool: True if the child articulations were successfully set up; False otherwise. - """ - # Use a list comprehension to filter valid control parts and log warnings for the invalid ones - control_parts_dict = {} - - # Check if the articulation is valid and if the provided UID matches the instance's UID - if self.articulation is None or uid != self.uid: - logger.log_warning(f"Articulation with UID '{uid}' not found.") - return False - - # Iterate over control parts to set up child articulations - for control_part in control_parts: - # Add to child articulations - control_parts_dict[control_part] = self.articulation - - # Establish the relationship between the child articulations and their parent - self.child_articulations = control_parts_dict - - return True - - @property - def default_physical_attrs(self) -> PhysicalAttr: - physical_attr = PhysicalAttr() - if self.drive_type == DriveType.FORCE: - physical_attr.static_friction = 1.0 - physical_attr.dynamic_friction = 0.9 - physical_attr.linear_damping = 0.7 - physical_attr.angular_damping = 0.7 - physical_attr.contact_offset = 0.005 - physical_attr.rest_offset = 0.001 - physical_attr.restitution = 0.05 - physical_attr.has_gravity = True - physical_attr.max_linear_velocity = 4000 - physical_attr.max_angular_velocity = 25 - physical_attr.max_depenetration_velocity = 1e1 - else: # DriveType.FORCE and so on - physical_attr.static_friction = 1.0 - physical_attr.dynamic_friction = 0.9 - physical_attr.linear_damping = 0.7 - physical_attr.angular_damping = 0.7 - physical_attr.contact_offset = 0.005 - physical_attr.rest_offset = 0.001 - physical_attr.restitution = 0.05 - physical_attr.has_gravity = False - physical_attr.max_linear_velocity = 1e6 - physical_attr.max_angular_velocity = 1e6 - physical_attr.max_depenetration_velocity = 1e1 - return physical_attr - - @property - def default_drive_param(self) -> Dict: - # Stiffness: - # Recommended range: 2000 N/m to 10000 N/m - # Note: Higher stiffness is suitable for tasks that require precise position control, - # such as gripping and assembly. You can start with 5000 N/m and fine-tune based on feedback from the actual application. - # Damping: - # Recommended range: 200 Ns/m to 1000 Ns/m - # Note: Damping values ​​should be high enough to dampen oscillations, - # but not too high to excessively hinder motion. You can start with 500 Ns/m and adjust based on dynamic performance. - # Max force: - # Recommended range: 10000 N to 100000 N - # Note: The maximum force should be set according to the load capacity of the robot arm - # to ensure that it does not exceed its load capacity when working. You can start with 50000 N, depending on the specific task load. - if self.drive_type == DriveType.FORCE: - param = {"stiffness": 2e3, "damping": 2e2, "max_force": 2e4} - elif self.drive_type == DriveType.FORCE: - param = {"stiffness": 1e8, "damping": 1e6, "max_force": 1e10} - return param - - def set_uid(self, uid: str) -> None: - r"""Set unique id of the robot. - - Args: - uid (str): Unique id of the robot. - """ - if uid == self.uid: - logger.log_warning( - f"The uid: {uid} is the same as the current: {self.uid}." - ) - else: - self.uid = uid - - def get_urdf_path(self) -> str: - r"""Provides the file path to the Unified Robot Description Format (URDF) file. - - Returns: - str: A string representing the file path to the robot's URDF file. - """ - return self.urdf_path - - def get_dof(self, name: str = None) -> Union[int, Dict[str, int]]: - r"""Get degree of freedom (DoF) of the robot. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - Union[int, Dict[str, int]]: - - If `name` is None, returns the total DoF of the robot as an integer. - - If `name` is provided and found, returns the DoF of the specified articulation as an integer. - - If `name` is provided but not found, logs a warning and returns 0. - """ - # TODO: Need to clarify behavior. - if name is None: - if isinstance(self._dof, dict): - return sum(self._dof.values()) - else: - return ( - self._dof - ) # Assuming _dof is an integer representing the total DoF - elif name in self._dof: - return self._dof[ - name - ] # Assuming _dof[name] is an integer representing the DoF of the specified articulation - - logger.log_warning(f"Articulation '{name}' not found.") - return 0 - - def _convert_pose(self, pose: np.ndarray, is_to_arena: bool) -> np.ndarray: - r"""Convert a given pose to the specified coordinate system. - - Args: - pose (np.ndarray): A [4, 4] transformation matrix representing the pose to be converted. - is_to_arena (bool): If True, convert to arena coordinate system; otherwise, convert to world coordinate system. - - Returns: - np.ndarray: A [4, 4] transformation matrix representing the pose in the specified coordinate system. - """ - if pose is None: - return np.eye(4) - - pose_array = np.array(pose) - - if pose_array.shape == (4, 4): - poses_to_convert = [pose_array] - elif pose_array.ndim == 3 and pose_array.shape[1:] == (4, 4): - poses_to_convert = pose_array - else: - logger.log_warning(f"Invalid shape for pose: {pose.shape}") - return np.eye(4) - - # Retrieve the world pose of the arena's root node - arena_root_pose = self._env.get_root_node().get_world_pose() - - # Determine the transformation logic based on the value of is_to_arena - if is_to_arena: - # Apply the inverse transformation to convert to the arena coordinate system - inv_arena_root_pose = np.linalg.inv(arena_root_pose) - converted_poses = [inv_arena_root_pose @ p for p in poses_to_convert] - else: - # Directly apply the transformation to convert to the world coordinate system - converted_poses = [arena_root_pose @ p for p in poses_to_convert] - - # Return the result in the same format as the input - if pose_array.shape == (4, 4): - return converted_poses[0] # Return single pose - else: - return np.array(converted_poses) # Return list/array of poses - - def set_joint_ids(self, joint_ids: np.ndarray, uid: str = None): - r"""Set joint IDs for the given UID. - - Args: - joint_ids (np.ndarray): Joint IDs to set. - uid (str, optional): The unique identifier for the joint. Defaults to None. - """ - uid = uid or self.uid - self._joint_ids[uid] = joint_ids - - def get_joint_ids(self, name: str = None) -> List: - r"""Gets joint IDs from the internal storage. - - Args: - name (str, optional): The name of the joint to look up. - If None, all joint IDs are returned. - - Returns: - List: A list of joint IDs associated with the specified name, - or a dictionary of all joint IDs if no name is given. - Returns an empty list if the name is not found. - """ - if name is None: - return {key: value for key, value in self._joint_ids.items()} - if name in self._joint_ids: - return self._joint_ids[name] - else: - logger.log_warning( - f"Joint ids with name '{name}' not found in self._joint_ids." - ) - return [] - - def get_joint_limits( - self, name: str = None - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get joint limits for the specified articulation. - - Args: - name (str): Name of the articulation. Defaults to None. - - Returns: - np.ndarray: [dof, 2] of float. Lower and upper joint limits. - Dict[str, np.ndarray]: [dof, 2] of float. Lower and upper joint limits for all articulations. - """ - limits = self.articulation.get_joint_limits() - - if name is None: - return limits - else: - if self.uid == name: - return limits[self._joint_ids[name]] - - if name not in self.child_articulations: - logger.log_warning(f"Articulation '{name}' not found.") - return None - return limits[self._joint_ids[name]] - - def get_link_names( - self, name: str = None - ) -> Union[List[str], Dict[str, List[str]]]: - r"""Gets the list of link names for a given articulation. - - Args: - name (str, optional): The name of the articulation. If None, returns link names for all articulations. - - Returns: - List[str]: A list of link names for the specified articulation if `name` is provided. - Dict[str, List[str]]: A dictionary mapping articulation names to their respective link name lists if `name` is None. - None: Returns None if the specified articulation name is not found. - """ - # todo: Articulation needs to distinguish between some parents and children. - link_names = self.articulation.get_link_names() - - if name is None or name == self.uid: - # Return a dictionary of link names for all articulations - return link_names - else: - if name in self.child_articulations: - return link_names[self._joint_ids[name]] - - def _get_link_velocity( - self, name: str = None, is_linear: bool = True, is_root: bool = False - ) -> Union[np.ndarray, None]: - r"""Get the link velocity of the specified articulation. - - Args: - name (str, optional): Name of the articulation. If None, retrieves velocities for all articulations. - is_linear (bool, optional): If True, retrieves linear velocity; otherwise, retrieves angular velocity. - is_root (bool, optional): If True, returns the root link velocity as a flattened array. - - Returns: - Union[np.ndarray, None]: Returns the velocity of the specified joint as a numpy array, or None if not found. - """ - - def _get_link_velocity_helper( - name: str, is_linear: bool = True, is_root: bool = False - ) -> typing.Optional[np.ndarray]: - """Helper function to get the link velocity for a specific articulation.""" - if name == self.uid: - link_general_vel = self.articulation.get_link_general_velocities() - link_velocity = ( - link_general_vel[:, :3] if is_linear else link_general_vel[:, 3:] - ) - return link_velocity[0].reshape(-1) if is_root else link_velocity - elif name in self.child_articulations: - link_general_vel = self.child_articulations[ - name - ].get_link_general_velocities() - link_velocity = ( - link_general_vel[:, :3] if is_linear else link_general_vel[:, 3:] - ) - return link_velocity[0].reshape(-1) if is_root else link_velocity - else: - return None - - if name is None: - link_velocity = _get_link_velocity_helper( - name=self.uid, is_linear=is_linear, is_root=is_root - ) - else: - link_velocity = _get_link_velocity_helper( - name=name, is_linear=is_linear, is_root=is_root - ) - - return link_velocity - - def get_body_link_linear_velocity( - self, - name: str = None, - ) -> Union[np.ndarray, None]: - r"""Get body link linear velocity in coordinate frame. - - Args: - name (str, optional): The name of the articulation. - If None, retrieves the velocity of all articulations. - - Returns: - Union[np.ndarray, None]: - If a name is provided, returns an array of shape [link_num, 3] - representing the linear velocity of the specified articulation. - If name is None, returns a dictionary mapping articulation names - to their corresponding linear velocities. - """ - return self._get_link_velocity(name=name, is_linear=True, is_root=False) - - def get_body_link_angular_velocity( - self, - name: str = None, - ) -> Union[np.ndarray, None]: - r"""Get body link angular velocity in coordinate frame. - - Args: - name (str, optional): The name of the articulation. - If None, retrieves the velocity of all articulations. - - Returns: - Union[np.ndarray, None]: - If a name is provided, returns an array of shape [link_num, 3] - representing the angular velocity of the specified articulation. - If name is None, returns a dictionary mapping articulation names - to their corresponding angular velocities. - """ - return self._get_link_velocity(name=name, is_linear=False, is_root=False) - - def get_root_link_linear_velocity( - self, - name: str = None, - ) -> Union[np.ndarray, None]: - r"""Get root link linear velocity in coordinate frame. - - Args: - name (str, optional): The name of the articulation. - If None, retrieves the velocity of all articulations. - - Returns: - Union[np.ndarray, None]: - If a name is provided, returns an array of shape [3] - representing the linear velocity of the root link. - If name is None, returns a dictionary mapping articulation names - to their corresponding linear velocities. - """ - return self._get_link_velocity(name=name, is_linear=True, is_root=True) - - def get_root_link_angular_velocity( - self, - name: str = None, - ) -> Union[np.ndarray, None]: - r"""Get root link angular velocity in coordinate frame. - - Args: - name (str, optional): The name of the articulation. - If None, retrieves the velocity of all articulations. - - Returns: - Union[np.ndarray, None]: - If a name is provided, returns an array of shape [3] - representing the angular velocity of the root link. - If name is None, returns a dictionary mapping articulation names - to their corresponding angular velocities. - """ - return self._get_link_velocity(name=name, is_linear=False, is_root=True) - - def _set_articulation_property( - self, - name: str, - property_name: str, - value: Union[np.ndarray, Dict[str, np.ndarray]], - use_params: bool = True, - **params, - ) -> bool: - r"""Helper function to set a property for a specific articulation. - - This function attempts to set a specified property (e.g., position, velocity) - for the articulation identified by 'name'. It first checks if the articulation - is a child articulation and then checks the main articulations. If the - articulation is found and the property exists, the function sets the property - with the provided value. - - Args: - name (str): The name of the articulation to set the property for. - property_name (str): The name of the property to set. - value (Union[np.ndarray, Dict[str, np.ndarray]]): The value to set the property to. - use_params (bool): Whether to use params when calling the property method. - - Returns: - bool: True if the property was successfully set, False otherwise. - """ - # Use self._joint_ids[name] if params is empty - if use_params and not params: - params = {"joint_ids": self._joint_ids[name]} - - # Check in child articulations first - if name in self.child_articulations: - child_articulation = self.child_articulations[name] - if hasattr(child_articulation, property_name): - # Call the property method with or without params - if use_params: - getattr(child_articulation, property_name)(value, **params) - else: - getattr(child_articulation, property_name)(value) - return True - - # Check the main articulation - if name == self.uid: - if hasattr(self.articulation, property_name): - # Call the property method with or without params - if use_params: - getattr(self.articulation, property_name)(value, **params) - else: - getattr(self.articulation, property_name)(value) - return True - - logger.log_warning(f"Articulation '{name}' not found.") - return False - - def get_current_xpos( - self, name: str = None, is_world_coordinates: bool = True - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get the current pose of the articulations. - - This method retrieves the current pose of specified articulation(s) in either world - or base coordinates. It handles both single articulations and hierarchical structures - with parent-child relationships. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - is_world_coordinates (bool, optional): - Whether to use the arena(world) coordinate system(WCS) or the Base - coordinate system(BCS). Defaults to True. - - Returns: - Union[np.ndarray, Dict[str, np.ndarray]]: - Returns the xpos for the specified articulation if `name` is provided and found. - If `name` is None, returns xpos for all articulations. - Returns None if `name` is provided but not found. - """ - - # Function to calculate the current position based on qpos - def calculate_xpos( - key: str, qpos: np.ndarray, parent_key: str = None - ) -> np.ndarray: - if key == self.uid: - articulation = self.articulation - else: - if key in self.child_articulations: - articulation = self.child_articulations.get(key, None) - if articulation is None: - return None # Articulation not found - - # Case 1: Use parent's drive controller for forward kinematics - if ( - parent_key - and (parent_key in self.drive_controllers) - and hasattr(self.drive_controllers[parent_key], "get_fk") - and (self.drive_controllers.get(key, None) is None) - ): - end_link_name = self.end_link_names.get(key, None) - if end_link_name is None: - end_link_index = -1 - else: - end_link_index = self.drive_controllers[parent_key].get_link_orders( - end_link_name - ) - - _, xpos = self.drive_controllers[parent_key].get_fk( - qpos, index=end_link_index - ) - # Case 2: Use articulation's own drive controller - elif (key in self.drive_controllers) and hasattr( - self.drive_controllers[key], "get_fk" - ): - if len(qpos) != self.drive_controllers[key]: - qpos = qpos[self._joint_ids[key]] - end_link_name = self.end_link_names.get(key) - if end_link_name is None: - end_link_index = -1 - else: - end_link_index = self.drive_controllers[key].get_link_orders( - end_link_name - ) - - _, xpos = self.drive_controllers[key].get_fk(qpos, index=end_link_index) - # Case 3: Fallback to direct world pose - else: - xpos = self._convert_pose( - articulation.get_world_pose(), is_to_arena=True - ) - return xpos - - # Get the base xpos for the articulation - # If parent_key exists, use it; otherwise use the current key - base_xpos = self.get_base_xpos(parent_key if parent_key else key) - - # Get initial transformation matrix, default to identity if not found - initial_xpos = self.init_base_xpos.get(key, np.eye(4)) - - if is_world_coordinates: - # Special handling for root links which require different transformation logic - if self.root_link_names.get(key, None) is not None: - if key not in self.drive_controllers: - # For articulations without drive controllers, - # transform using base transformation matrix - return base_xpos @ xpos - else: - # For articulations with drive controllers, - # get an up-to-date base transformation and apply it - root_base_xpos = self.get_base_xpos(key) - return root_base_xpos @ xpos - else: - # Handle non-root links - # TODO: judge by num of drive_controllers - return ( - (initial_xpos @ xpos) - if parent_key is not None - else (base_xpos @ xpos) - ) - - return xpos - - # If name is None, calculate for all articulations - if name is None: - current_xpos = {} - qpos = self.get_current_qpos(self.uid) # Get qpos once for all - - # Calculate for all main articulations - xpos = calculate_xpos(self.uid, qpos) - if xpos is not None: - current_xpos[self.uid] = xpos - - # Calculate for child articulations using parent drive controller - for child_key in self.child_articulations: - xpos = calculate_xpos(child_key, qpos, self.uid) - if xpos is not None: - current_xpos[child_key] = xpos - - return current_xpos - - # Check for articulation in child articulations - if name in self.child_articulations: - if self.uid in self._actived_joint_names: - xpos = calculate_xpos(name, self.get_current_qpos()[self.uid], self.uid) - else: - xpos = calculate_xpos(name, self.get_current_qpos(self.uid), self.uid) - if xpos is not None: - return xpos - - # Check for articulation in main articulation - xpos = calculate_xpos(name, self.get_current_qpos(name)) - if xpos is not None: - return xpos - - logger.log_warning(f"Articulation '{name}' not found.") - return None - - def get_base_xpos( - self, name: str = None, is_init: bool = False - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get current robot base pose in arena coordinate system. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - is_init (bool, optional): Init base xpos or current base xpos. Current base xpos defaults. - - Returns: - np.ndarray: Joint positions for the specified articulation if `name` is provided and found. - Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - - if is_init: - # Return initial base positions - return self.init_base_xpos.get(name) if name else self.init_base_xpos - - # Initialize a dictionary for current base positions - current_base_xpos_dict = {} - - # Get the current base xpos for the main articulation - base_xpos = self.articulation.get_link_pose(self.root_link_names[self.uid]) - current_base_xpos_dict[self.uid] = self._convert_pose( - base_xpos, is_to_arena=True - ) - - # Populate the dictionary with joint positions for all child articulations - for key in self.child_articulations: - if ( - self.root_link_names.get(key, None) - in self.child_articulations[key].get_link_names() - ): - child_base_xpos = self._get_articulation_property( - key, "get_link_pose", link_name=self.root_link_names[key] - ) - current_base_xpos_dict[key] = self._convert_pose( - child_base_xpos, is_to_arena=True - ) - - if name is None: - return current_base_xpos_dict - - # If a specific articulation name is provided - if name == self.uid: - return self._convert_pose(base_xpos, is_to_arena=True) - - # Get the base xpos for the specified articulation - current_base_xpos = self._get_articulation_property( - name, "get_link_pose", link_name=self.root_link_names[name] - ) - return self._convert_pose(current_base_xpos, is_to_arena=True) - - def set_base_xpos( - self, name: str = None, base_xpos: np.ndarray = np.eye(4) - ) -> None: - r"""Set the robot's base pose. - - Args: - name (str, optional): - Name of the articulation. If specified, the function will - apply the base pose to the articulation with this name. - Defaults to None, which means the base pose will be set for - the entire robot. - - base_xpos (np.ndarray, optional): - A [4, 4] matrix representing the transformation matrix that - defines the base pose of the robot. The matrix should - contain rotation and translation information. Defaults to - the identity matrix (np.eye(4)), indicating no change in pose. - """ - if base_xpos is None: - logger.log_warning("base_xpos is None, no action taken.") - return False - - if name is None or name == self.uid: - if isinstance(base_xpos, dict): - failed_cases = [] - for articulation_name, pos in base_xpos.items(): - if not self._set_articulation_property( - articulation_name, - "set_world_pose", - self._convert_pose(pos, is_to_arena=False), - False, - ): - failed_cases.append(articulation_name) - if failed_cases: - logger.log_warning( - f"Failed to set base xpos for articulations: {failed_cases}" - ) - return False - return True - elif isinstance(base_xpos, (list, np.ndarray)): - self._set_articulation_property( - name, - "set_world_pose", - self._convert_pose(base_xpos, is_to_arena=False), - False, - ) - return True - else: - logger.log_warning( - f"Expected base xpos to be dict for articulations, got {type(base_xpos)}." - ) - return False - else: - if isinstance(base_xpos, (list, np.ndarray)): - return self._set_articulation_property( - name, - "set_world_pose", - self._convert_pose(base_xpos, is_to_arena=False), - False, - ) - else: - logger.log_warning( - f"Expected base xpos to be np.ndarray for articulation '{name}', got {type(base_xpos)}." - ) - return False - - def get_current_joint_poses( - self, name: str = None - ) -> Union[List[np.ndarray], Dict[str, List[np.ndarray]]]: - r"""Get current robot joint poses. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - List[np.ndarray]: List of [4, 4]. Joint poses for the specified articulation if `name` is provided and found. - Dict[str, List[np.ndarray]]: Joint poses for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - name = name or self.uid - - if name == self.uid: - current_joint_poses = dict() - if hasattr(self.articulation, "get_joint_poses"): - current_joint_poses = self._convert_pose( - self.articulation.get_joint_poses(self._joint_ids[self.uid]), - is_to_arena=True, - ) - - return current_joint_poses - else: - if name in self.child_articulations: - logger.log_warning(f"Articulation {name} not found.") - return None - - return self._convert_pose( - self.child_articulations[name].get_joint_poses(self._joint_ids[name]), - is_to_arena=True, - ) - - def get_init_qpos(self, name: str = None) -> None: - r"""Get robot initial joint positions. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - np.ndarray: initial joint positions for the specified articulation if `name` is provided and found. - Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - if name is None: - return self.init_qpos - - if name in self.child_articulations or name == self.uid: - return self.init_qpos[name] - - logger.log_warning(f"Articulation {name} not found.") - return None - - def set_init_qpos( - self, name: str = None, qpos: Union[np.ndarray, Dict[str, np.ndarray]] = [] - ) -> None: - r"""Set initial joint positions. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - qpos (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot initial joint positions. - """ - if qpos is None: - logger.log_warning("qpos is None, no action taken.") - return - - if name is None or name == self.uid: - if isinstance(qpos, dict): - for articulation_name, pos in qpos.items(): - if articulation_name in self.init_qpos: - self.init_qpos[articulation_name] = pos - else: - logger.log_warning( - f"Articulation '{articulation_name}' not found in init_qpos." - ) - elif isinstance(qpos, (list, np.ndarray)): - self.init_qpos[self.uid] = qpos - else: - logger.log_warning( - f"Unsupported qpos type: {type(qpos)}, expected np.ndarray or dict." - ) - else: - if not isinstance(qpos, (list, np.ndarray)): - logger.log_warning( - f"Expected qpos to be np.ndarray for articulation '{name}', got {type(qpos)}." - ) - return - - if name in self.init_qpos: - self.init_qpos[name] = qpos - else: - logger.log_warning(f"Articulation '{name}' not found in init_qpos.") - - def _get_articulation_property( - self, name: str, property_name: str, **params - ) -> Union[np.ndarray, None]: - r"""Helper function to get a property for a specific articulation. - - This function retrieves the value of a specified property (e.g., position, - velocity) for the articulation identified by 'name'. It first checks if the - articulation is a main articulation and then checks child articulations. If - the articulation is found and the property exists, the function returns the - property's value. - - Args: - name (str): The name of the articulation to get the property from. - property_name (str): The name of the property to retrieve. - - Returns: - Union[np.ndarray, None]: The value of the property if found, None otherwise. - """ - # Use self._joint_ids[name] if params is empty - if not params: - if name in self._joint_ids: - params = {"joint_ids": self._joint_ids[name]} - else: - logger.log_warning(f"Joint_id '{name}' not found.") - has_similar_name = False - for key, val in self._joint_ids.items(): - if name in key: - params = {"joint_ids": val} - logger.log_warning(f"Joint_id '{key}' is used for {name}.") - name = key - has_similar_name = True - break - if not has_similar_name: - return None - - if name == self.uid: - return getattr(self.articulation, property_name)(**params) - - if len(self._joint_ids[name]): - if name in self.child_articulations: - child_articulation = self.child_articulations[name] - return getattr(child_articulation, property_name)(**params) - else: - return None - - logger.log_warning(f"Articulation '{name}' not found.") - return None - - def set_current_qpos( - self, name: str = None, qpos: Union[np.ndarray, Dict[str, np.ndarray]] = None - ): - r"""Set current robot joint positions. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - qpos (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot current joint positions. - - Returns: - bool: True if the positions were successfully set, False otherwise. - """ - if qpos is None: - logger.log_warning("qpos is None, no action taken.") - return False - - if name is None or name == self.uid: - if isinstance(qpos, dict): - failed_cases = [] - for articulation_name, pos in qpos.items(): - if not self._set_articulation_property( - articulation_name, "set_current_qpos", pos - ): - failed_cases.append(articulation_name) - if failed_cases: - logger.log_warning( - f"Failed to set qpos for articulations: {failed_cases}" - ) - return False - return True - elif isinstance(qpos, (list, np.ndarray)): - return self._set_articulation_property(name, "set_current_qpos", qpos) - else: - logger.log_warning( - f"Expected qpos to be dict for articulations, got {type(qpos)}." - ) - return False - else: - if isinstance(qpos, (list, np.ndarray)): - return self._set_articulation_property(name, "set_current_qpos", qpos) - else: - logger.log_warning( - f"Expected qpos to be np.ndarray for articulation '{name}', got {type(qpos)}." - ) - return False - - def get_current_qpos( - self, name: str = None - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get current robot joint positions. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - np.ndarray: Joint positions for the specified articulation if `name` is provided and found. - Dict[str, np.ndarray]: Joint positions for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - # Validate the name parameter - if name is not None and not isinstance(name, str): - logger.log_warning( - f"The 'name' parameter must be a string or None, got {type(name)}." - ) - return None - - if name is None: - # Initialize a dictionary to hold joint positions for all articulations - current_qpos_dict = {} - - # Get the current joint positions for the main articulation - qpos = self.articulation.get_current_qpos() - current_qpos_dict[self.uid] = qpos - - # Populate the dictionary with joint positions for all child articulations - for key in self.child_articulations: - current_qpos_dict[key] = self._get_articulation_property( - key, "get_current_qpos" - ) - - return current_qpos_dict - else: - - return self._get_articulation_property(name, "get_current_qpos") - - def set_current_qvel( - self, name: str = None, qvel: Union[np.ndarray, Dict[str, np.ndarray]] = None - ): - r"""Set the current joint velocities of the robot. - - Args: - name (str, optional): - Name of the articulation. If None, the velocities will be set - for all articulations. - - qvel (Union[np.ndarray, Dict[str, np.ndarray]], optional): - Joint velocities. This can be a NumPy array for a single - articulation or a dictionary mapping articulation names to - their respective velocities. - - Returns: - bool: Returns True if the joint velocities were successfully set, - otherwise returns False if no action was taken or if there - were errors in the input. - """ - if qvel is None: - logger.log_warning("qvel is None, no action taken.") - return False - - if name is None or name == self.uid: - if isinstance(qvel, dict): - failed_cases = [] - for articulation_name, vel in qvel.items(): - if not self._set_articulation_property( - articulation_name, "set_current_qvel", vel - ): - failed_cases.append(articulation_name) - if failed_cases: - logger.log_warning( - f"Failed to set qvel for articulations: {failed_cases}" - ) - return False - return True - else: - logger.log_warning( - f"Expected qvel to be dict for articulations, got {type(qvel)}." - ) - return False - else: - if isinstance(qvel, (list, np.ndarray)): - return self._set_articulation_property(name, "set_current_qvel", qvel) - else: - logger.log_warning( - f"Expected qvel to be np.ndarray for articulation '{name}', got {type(qvel)}." - ) - return False - - def get_current_qvel( - self, name: str = None - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get current robot joint velocities. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - np.ndarray: Joint velocities for the specified articulation if `name` is provided and found. - Dict[str, np.ndarray]: Joint velocities for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - if name is None: - # Initialize a dictionary to hold joint velocities for all articulations - current_qvel_dict = {} - - # Get the current joint velocities for the main articulation - qvel = self.articulation.get_current_qvel() - # Store the velocity of the main articulation in the dictionary using its unique ID - current_qvel_dict[self.uid] = qvel - - # Iterate over child articulations to get their velocities - for key in self.child_articulations: - # Retrieve and store the joint velocity for the child articulation in the dictionary - current_qvel_dict[key] = self._get_articulation_property( - key, "get_current_qvel" - ) - - # Return the dictionary containing velocities for all articulations - return current_qvel_dict - else: - return self._get_articulation_property(name, "get_current_qvel") - - def set_current_qf( - self, - name: str = None, - qf: Union[np.ndarray, Dict[str, np.ndarray]] = None, - ): - r"""Set current robot joint force. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - qf (Union[np.ndarray, Dict[str, np.ndarray]]): [dof] of float. Robot current joint force. - - """ - if qf is None: - logger.log_warning("joint_force is None, no action taken.") - return False - - if name is None: - if isinstance(qf, dict): - failed_cases = [] - for articulation_name, force in qf.items(): - if not self._set_articulation_property( - articulation_name, "set_current_qf", force - ): - failed_cases.append(articulation_name) - if failed_cases: - logger.log_warning( - f"Failed to set joint force for articulations: {failed_cases}" - ) - return False - return True - else: - logger.log_warning( - f"Expected joint_force to be dict for articulations, got {type(qf)}." - ) - return False - else: - if isinstance(qf, (list, np.ndarray)): - return self._set_articulation_property(name, "set_current_qf", qf) - else: - logger.log_warning( - f"Expected joint_force to be np.ndarray for articulation '{name}', got {type(qf)}." - ) - return False - - def get_current_qf( - self, name: str = None - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - r"""Get current robot joint force. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - np.ndarray: Joint force for the specified articulation if `name` is provided and found. - Dict[str, np.ndarray]: Joint force for all articulations if `name` is None. - None: If `name` is provided but not found. - """ - if name is None: - # Initialize a dictionary to hold joint forces for all articulations - current_qf_dict = {} - - # Get the current joint forces for the main articulation - qvel = self.articulation.get_current_qvel() - # Store the velocity of the main articulation in the dictionary using its unique ID - current_qf_dict[self.uid] = qvel - - # Iterate over child articulations to get their forces - for key in self.child_articulations: - # Retrieve and store the joint velocity for the child articulation in the dictionary - current_qf_dict[key] = self._get_articulation_property( - key, "get_current_qvel" - ) - - # Return the dictionary containing forces for all articulations - return current_qf_dict - else: - return self._get_articulation_property(name, "get_current_qf") - - @staticmethod - def is_approx(qpos1: np.ndarray, qpos2: np.ndarray, eps: float = 1e-5): - r"""Evaluate whether qpos1 and qpos2 are 'close'. - - Args: - qpos1 (np.ndarray): a object of joint - qpos2 (np.ndarray): a object of other joint - - Returns: - bool: is close - """ - qpos1 = np.array(qpos1) - qpos2 = np.array(qpos2) - if qpos1.shape != qpos2.shape: - logger.log_warning( - "qpos1 shape {} does not match qpos2 shape {}, qpos1: {}, qpos2: {}.".format( - qpos1.shape, qpos2.shape, qpos1, qpos2 - ) - ) - return False - - dis = np.linalg.norm(qpos1 - qpos2, ord=1) - return dis < eps - - def create_physical_visible_node( - self, name: str, rgba: np.array = None, link_name: str = None - ) -> bool: - r"""Create a physical visible node for the articulation. - - Args: - name (str): - The name/identifier of the articulation to create the visible node for. - Must match either the main articulation's UID or a child articulation's name. - - rgba (np.ndarray, optional): - An array of 4 float values representing the RGBA color values: - - Red component (0.0 to 1.0) - - Green component (0.0 to 1.0) - - Blue component (0.0 to 1.0) - - Alpha/transparency (0.0 to 1.0) - Defaults to [0.0, 1.0, 0.0, 0.6] (semi-transparent green). - - link_name (str, optional): - The specific link name of the articulation to create the visible node for. - If None, visible nodes will be created for all links of the articulation. - Defaults to None. - - Returns: - bool: - True if the visible node was successfully created. - False if: - - The articulation name was not found - - The link name was invalid - - The creation process failed - """ - if rgba is None: - rgba = np.array([0.0, 1.0, 0.0, 0.6]) - else: - rgba = np.array(rgba) - - assert rgba.shape == (4,), "RGBA array must have 4 elements." - - # Prepare parameters for the node creation - params = {"rgba": rgba} - - # Add link_name to parameters if provided - if link_name is not None: - params["link_name"] = link_name - - # Check if the name matches the uid and create the node - if name == self.uid: - return self.articulation.create_physical_visible_node(**params) - elif name in self.child_articulations: - # Otherwise, create the node for the specified child articulation - return self.child_articulations[name].create_physical_visible_node(**params) - - logger.log_warning(f"Articulation '{name}' not found.") - return False - - def set_physical_visible( - self, - name: str, - is_physic_visible: bool, - is_render_body_visible: bool = True, - link_name: str = None, - ) -> bool: - r"""Set whether the current physical collision is visible. - - Args: - name (str): The name of the articulation. - is_physic_visible (bool): Whether the current physical node is visible. - is_render_body_visible (bool, optional): Whether the render body is visible. Defaults to True. - link_name (str, optional): The link name of the articulation. If None, set all articulation visible. Defaults to None. - - Returns: - bool: Returns True if the setting is successful, False otherwise. - """ - # Prepare parameters for setting visibility - params = { - "is_physic_visible": is_physic_visible, - "is_render_body_visible": is_render_body_visible, - } - - # Add link_name to parameters if provided - if link_name is not None: - params["link_name"] = link_name - - # Check if the name matches the uid and set visibility - if name == self.uid: - self.articulation.set_physical_visible(**params) - return True - - # Check if the name is in child articulations and set visibility for it - elif name in self.child_articulations: - self.child_articulations[name].set_physical_visible(**params) - return True - - # Log a warning if the articulation name is not found - logger.log_warning(f"Articulation '{name}' not found.") - return False diff --git a/embodichain/lab/sim/end_effector/__init__.py b/embodichain/lab/sim/end_effector/__init__.py deleted file mode 100644 index 69ee8c8..0000000 --- a/embodichain/lab/sim/end_effector/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- -from .end_effector import EndEffector -from .utility import * - -del end_effector, utility diff --git a/embodichain/lab/sim/end_effector/end_effector.py b/embodichain/lab/sim/end_effector/end_effector.py deleted file mode 100644 index 708e6ce..0000000 --- a/embodichain/lab/sim/end_effector/end_effector.py +++ /dev/null @@ -1,552 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- -import typing -import dexsim.engine -import numpy as np -import dexsim.environment -from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag, ActorType - -from embodichain.lab.sim.end_effector.utility import ( - load_model_from_file, - inv_transform, -) -from abc import ABC, abstractmethod -from embodichain.lab.sim.articulation_entity import ArticulationEntity -from embodichain.utils import logger -import dexsim -import time - - -class EndEffector(ArticulationEntity, ABC): - r""" - Abstract class for end effector in simulation. - """ - - def __init__( - self, - env: dexsim.environment.Arena, - file: str, - drive_type: DriveType = DriveType.FORCE, - **kwargs, - ) -> None: - """init end effector - - Args: - env (dexsim.environment.Arena): dexsim environment. - file (str): input file (urdf or mesh file) - drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. Defaults to DriveType.FORCE. - kwargs(optional): Accepts additional keyword arguments. - """ - urdf_path = load_model_from_file(file_path=file) - - super().__init__( - urdf_path=urdf_path, - init_qpos=None, - init_base_xpos=np.eye(4), - speed_ratio=0.5, - time_step=0.02, - drive_type=drive_type, - env=env, - **kwargs, - ) - - self._init_end_effector(**kwargs) - - self.articulation.set_physical_attr(self.default_physical_attrs) - self.articulation.set_drive( - drive_type=self.drive_type, **self.default_drive_param - ) - - @abstractmethod - def _init_end_effector(self, **kwargs) -> None: - r"""Initializes the robot using the URDF path with necessary parameters.""" - pass - - def _set_ee_control_data(self, **kwargs): - self._dof = self.articulation.get_dof() - self._actived_joint_names = self.articulation.get_actived_joint_names() - self._root_link_name = self.articulation.get_root_link_name() - self._attached_nodes = dict() # {node_name: [dexsim.engine.Node, ActorType]} - self._leaf_link_names = self.articulation.get_leaf_link_names() - - if self._dof > 0: - # ignore mimic information for 0-dof articulation - self._joint_ids[self.uid] = np.arange(self._dof) - self._joint_limit = self.articulation.get_joint_limits() - self._set_mimic() - - self.attach_robot_uid = None # if end-effector is attach to robot. - - # KWARGS. If true, set object to be dynamic when release object, otherwise do nothing. - self._is_release_dynamic = kwargs.get("is_release_dynamic", True) - - # open state sample num - self._open_state_sample_num = kwargs.get("open_state_sample_num", 30) - - # open state and close state - self.open_state = np.array( - [ - 1.0, - ] - ) - self.close_state = np.array( - [ - 0.0, - ] - ) - - @property - def actived_joint_names(self) -> typing.List[str]: - return self._actived_joint_names - - def _set_to_init_qpos(self): - self._init_qpos = np.array([]) - if self._dof > 0: - self._init_qpos = self._joint_limit[:, 0] - self.articulation.set_current_qpos( - self._init_qpos, self._joint_ids[self.uid] - ) - - def get_init_qpos(self) -> np.ndarray: - return self._init_qpos - - @property - def release_dynamic(self) -> bool: - """get is release dynamic - - Returns: - bool: If true, set object to be dynamic when release object, otherwise do nothing. - """ - return self._is_release_dynamic - - @release_dynamic.setter - def release_dynamic(self, is_release_dynamic: bool): - """set is release dynamic - - Args: - is_release_dynamic (bool): If true, set object to be dynamic when release object, otherwise do nothing. - """ - self._is_release_dynamic = is_release_dynamic - - def _set_mimic(self) -> None: - r"""Sets up the mimic configuration for the articulation. - - Attributes Updated: - - self._mimic_joint_ids: Array of joint IDs that are mimicked. - - self._mimic_master_ids: Array of master joint IDs that control the mimicked joints. - - self._mimic_multipliers: Array of multipliers for the mimicked joints. - - self._mimic_offsets: Array of offsets for the mimicked joints. - - self._control_joint_ids: Array of joint IDs that are not mimicked and can be controlled. - - self._control_limit: Joint limits for the controllable joints. - - self._control_num: Number of controllable joints. - """ - mimic_info = self.articulation.get_mimic_info() - - self._mimic_joint_ids = mimic_info.mimic_id - self._mimic_master_ids = mimic_info.mimic_parent - self._mimic_multipliers = mimic_info.mimic_multiplier - self._mimic_offsets = mimic_info.mimic_offset - - # Using set for faster membership testing - mimic_joint_set = set(self._mimic_joint_ids) - - # List comprehension for better readability and performance - self._control_joint_ids = np.array( - [i for i in range(self._dof) if i not in mimic_joint_set] - ) - - self._control_limit = self._joint_limit[self._control_joint_ids] - self._control_num = self._control_joint_ids.shape[0] - - def _qpos_to_control_state(self, qpos: np.ndarray) -> np.ndarray: - """full joint state to control joint state - - Args: - qpos (np.ndarray): [dof] of float. Full joint state. - - Returns: - np.ndarray: [control_joint_num] of float. control joint state - """ - return qpos[self._control_joint_ids] - - def _control_state_to_qpos(self, control_state: np.ndarray) -> np.ndarray: - """control joint state to full joint state - - Args: - control_state (np.ndarray): [control_joint_num] of float. control joint state - - Returns: - np.ndarray: [dof] of float. Full joint state. - """ - qpos = np.empty(shape=(self._dof,), dtype=float) - qpos[self._control_joint_ids] = control_state - qpos[self._mimic_joint_ids] = ( - qpos[self._mimic_master_ids] * self._mimic_multipliers + self._mimic_offsets - ) - return qpos - - def _qpos_to_control_state_path(self, qpos_path: np.ndarray): - return qpos_path[:, self._control_joint_ids] - - def _control_state_to_qpos_path(self, control_state_path: np.ndarray): - waypoint_num = control_state_path.shape[0] - qpos_path = np.empty(shape=(waypoint_num, self._dof), dtype=float) - qpos_path[:, self._control_joint_ids] = control_state_path - qpos_path[:, self._mimic_joint_ids] = ( - qpos_path[:, self._mimic_master_ids] * self._mimic_multipliers - + self._mimic_offsets - ) - return qpos_path - - def _to_arena_pose(self, pose: np.ndarray) -> np.ndarray: - return inv_transform(self._env.get_root_node().get_world_pose()) @ pose - - def get_xpos(self) -> np.ndarray: - """get gripper root link pose - - Returns: - np.ndarray: [4, 4] of float. root link 6d pose - """ - return self._to_arena_pose( - self.articulation.get_link_pose(self._root_link_name) - ) - - def set_xpos(self, pose: np.ndarray) -> None: - """directly set gripper world pose - - Args: - pose (np.ndarray): [4, 4] of float. root link 6d pose - """ - # TODO: When gripper attach to robot base, this function result can be wild. - assert pose.shape == (4, 4) - self.set_world_pose(self._to_arena_pose(pose)) - - def set_world_pose(self, pose: np.ndarray) -> None: - """Set the world pose of the end effector.""" - assert pose.shape == (4, 4), "Pose must be a 4x4 transformation matrix." - self.articulation.set_world_pose(pose) - - def get_qpos(self) -> np.ndarray: - """get robot joint state array - - Returns: - np.ndarray: (joint_num, ) of float. joint state array - """ - return np.array(self.articulation.get_current_qpos(self._joint_ids[self.uid])) - - def set_qpos(self, qpos: np.ndarray) -> None: - """set gripper joint state array - - Args: - qpos (np.ndarray): (joint_num, ) of float. joint state array - """ - assert qpos.shape == (self._dof,) - self.articulation.set_current_qpos(qpos, self._joint_ids[self.uid]) - - def get_control_qpos(self) -> np.ndarray: - """get control joint state - - Returns: - np.ndarray: (control_joint_num, ) of float. - """ - return self._qpos_to_control_state(self.get_qpos()) - - def set_control_qpos(self, control_state: np.ndarray) -> None: - """set control joint state - - Args: - control_state (np.ndarray): (control_joint_num, ) of float - """ - assert control_state.shape == self._control_joint_ids.shape - qpos = self._control_state_to_qpos(control_state) - self.articulation.set_current_qpos(qpos, self._joint_ids[self.uid]) - - def move_qpos(self, qpos_path: np.ndarray, is_wait=True, move_time: float = 1): - assert qpos_path.shape[1] == self._dof - self.move_joints( - qpos_path, - is_wait=is_wait, - joint_ids=self._joint_ids[self.uid], - move_time=move_time, - ) - - def get_leaf_link_pose(self) -> dict: - """get leaf link pose. - - Returns: - dict: {"link_name", np.ndarray [4, 4]} pose of each leaf link - """ - leaf_link_poses = dict() - for leaf_link_name in self._leaf_link_names: - leaf_link_pose = self.articulation.get_link_pose(leaf_link_name) - leaf_link_poses[leaf_link_name] = leaf_link_pose - return leaf_link_poses - - def get_leaf_contact(self, is_flatten: bool = False) -> dict: - """Get leaf link contacts. - Leaf link: 1. has physical body; 2. no child link; 3. parent link is not fixed. - - Args: - is_flatten (bool): get flatten - - Returns: - is_flatten == False: - dict: { - "link_name": { - "nodes": [dexsim.engine.Node, ...], - "contact_positions": [link_contact_num, 3] of float. np.ndarray, - "contact_normals": [link_contact_num, 3] of float. np.ndarray, - "contact_distances": [link_contact_num] of float. np.ndarray, - }, - ... - } - - is_flatten == True: - ContactInfo - - ContactInfo.nodes(List[dexsim.engine.Node]): List of Contact object node ptr - ContactInfo.link_name(List[str]): List of contact link name - ContactInfo.contact_positions(np.ndarray): [contact_num, 3] of float, matrix of contact_positions. - ContactInfo.contact_normals(np.ndarray): [contact_num, 3] of float, matrix of contact normal. - ContactInfo.contact_distances(np.ndarray): [contact_num] of float. Contact distance. Negetive for peneration and postive for surface distance. - """ - contact_info = self.articulation.get_leaf_contacts() - if is_flatten: - return contact_info - link_contact_all_id = np.arange(len(contact_info.nodes)) - - contact_info_dict = dict() - # Tricky implementation. save str ing np.ndarray, and select link name by mask - contact_link_names = np.array(contact_info.link_name) - contact_link_name_unique = np.unique(contact_link_names) - # unpack contact info - for link_name in contact_link_name_unique: - contact_info_dict[link_name] = dict() - link_contact_mask = contact_link_names == link_name - link_contact_ids = link_contact_all_id[link_contact_mask] - contact_info_dict[link_name]["nodes"] = [] - for link_contact_idx in link_contact_ids: - contact_info_dict[link_name]["nodes"].append( - contact_info.nodes[link_contact_idx] - ) - contact_info_dict[link_name][ - "contact_positions" - ] = contact_info.contact_positions[link_contact_ids] - contact_info_dict[link_name][ - "contact_normals" - ] = contact_info.contact_normals[link_contact_ids] - contact_info_dict[link_name][ - "contact_distances" - ] = contact_info.contact_distances[link_contact_ids] - return contact_info_dict - - def get_cpp_articulation(self): - return self.articulation - - def attach(self, node: dexsim.engine.Node) -> str: - """attach certain actor to current end-effector - (will attach to root link) - - Args: - node (dexsim.engine.Node): dexsim actor - - Returns: - str: Name of the attached actor, return none str if will attach wrong actor. - """ - node_name = node.get_name() - original_actor_type = node.get_actor_type() - - if original_actor_type == ActorType.STATIC: - logger.log_info( - "Skipping attachment to static object, its name: {}.".format(node_name) - ) - return "" - if original_actor_type == ActorType.DYNAMIC: - # TODO: tricky implemetation. Fix dynamic actor to kinematic - node.set_actor_type(ActorType.KINEMATIC) - # node.enable_collision(False) - - node_pose = node.get_local_pose() - self_pose = self.get_xpos() - relative_pose = inv_transform(self_pose) @ node_pose - - self.articulation.attach_node( - obj=node, link_name=self._root_link_name, relative_pose=relative_pose - ) - - self._attached_nodes[node_name] = [node, original_actor_type] - return node_name - - def detach(self, node_name: str) -> bool: - """detach certain actor to current suctor - - Args: - actor (dexsim.models.Entity): dexsim actor - - Returns: - bool: is_success - """ - if node_name in self._attached_nodes: - node = self._attached_nodes[node_name][0] - original_actor_type = self._attached_nodes[node_name][1] - arena_root_node = self._env.get_root_node() - node.attach_node(arena_root_node) - if original_actor_type != ActorType.STATIC and self._is_release_dynamic: - node.set_actor_type(ActorType.DYNAMIC) - # node.enable_collision(True) - self._attached_nodes.pop(node_name) - return True - else: - logger.log_warning(f"Actor {node_name} to be detach is not attached yet.") - return False - - @abstractmethod - def get_control_state(self, **kwargs) -> np.ndarray: - """get control state of end-effector - - Returns: - np.ndarray: [state_dof] of float. Control state array - """ - - @abstractmethod - def get_open_state(self, **kwargs) -> np.ndarray: - """get control state of end-effector - - Returns: - np.ndarray: [state_dof] of float. Open state array - """ - - @abstractmethod - def set_open_state(self, open_state: np.ndarray, **kwargs): - """set control state of end-effector - - Args: - open_state (np.ndarray): [state_dof] of float. Open state - """ - - def to_target_open_state_path( - self, - target_open_state: np.ndarray, - start_open_state: np.ndarray = None, - step_num: int = None, - step_size: float = None, - **kwargs, - ) -> np.ndarray: - """Generate a path from the start open state to the target open state for a gripper or a robotic hand. - - An "open state" refers to the configuration of the gripper or robotic hand at a given moment, - which can include the positions of fingers, joints, and any gripping mechanisms. - The "target state" is the desired configuration that the gripper or hand should achieve after - the motion, typically used for grasping or releasing an object. - - Args: - target_open_state (np.ndarray): Target open state, shape [state_dof]. - start_open_state (np.ndarray, optional): Starting open state, shape [state_dof]. Default is None, which uses the current open state. - step_num (int, optional): Number of interpolation points. Default is None. - step_size (float, optional): Step size for interpolation. Default is None. - - Returns: - np.ndarray: Path as an array of shape [waypoint_num, state_dof]. - """ - - if start_open_state is None: - start_open_state = self.get_open_state() - - if step_num is not None and step_size is not None: - logger.log_warning( - "Please provide either 'step_num' or 'step_size', not both." - ) - return [] - - if step_num is not None: - step_num = max(step_num, 1) - elif step_size is not None: - distance = np.linalg.norm(target_open_state - start_open_state) - step_num = int(np.ceil(distance / step_size)) - else: - state_range = np.abs(start_open_state - target_open_state).max() - step_num = int(np.round(self._open_state_sample_num * state_range)) - - open_state_path = np.linspace(start_open_state, target_open_state, step_num) - - return open_state_path - - def open(self, **kwargs): - """open end-effector. only for demo""" - if self._world is not None: - if self._world.is_physics_manually_update(): - logger.log_warning("Cannot call open in physics manually update mode.") - return - open_state_path = self.to_target_open_state_path(self.open_state) - for i in range(open_state_path.shape[0]): - self.set_open_state(open_state_path[i]) - time.sleep(0.02) - - def close(self, **kwargs): - """close end-effector. only for demo""" - if self._world is not None: - if self._world.is_physics_manually_update(): - logger.log_warning("Cannot call close in physics manually update mode.") - return - open_state_path = self.to_target_open_state_path(self.close_state) - for i in range(open_state_path.shape[0]): - self.set_open_state(open_state_path[i]) - time.sleep(0.02) - - @property - def default_physical_attrs(self) -> PhysicalAttr: - physical_attr = PhysicalAttr() - if self.drive_type == DriveType.FORCE: - physical_attr.mass = 0.01 # TODO: mass setting is not activated currently - physical_attr.static_friction = 2.0 - physical_attr.dynamic_friction = 1.5 - physical_attr.linear_damping = 0.7 - physical_attr.angular_damping = 0.7 - physical_attr.contact_offset = 0.005 - physical_attr.rest_offset = 0.001 - physical_attr.restitution = 0.05 - physical_attr.has_gravity = True - physical_attr.max_linear_velocity = 4000 - physical_attr.max_angular_velocity = 25 - physical_attr.max_depenetration_velocity = 1e1 - else: # DriveType.FORCE and so on - physical_attr.mass = 0.01 # TODO: mass setting is not activated currently - physical_attr.static_friction = 2.0 - physical_attr.dynamic_friction = 1.5 - physical_attr.linear_damping = 0.7 - physical_attr.angular_damping = 0.7 - physical_attr.contact_offset = 0.005 - physical_attr.rest_offset = 0.001 - physical_attr.restitution = 0.05 - physical_attr.has_gravity = False - physical_attr.max_linear_velocity = 1e6 - physical_attr.max_angular_velocity = 1e6 - physical_attr.max_depenetration_velocity = 1e1 - return physical_attr - - @property - def default_drive_param(self) -> typing.Dict: - # Stiffness: - # Recommended range: 2000 N/m to 10000 N/m - # Note: Higher stiffness is suitable for tasks that require precise position control, - # such as gripping and assembly. You can start with 5000 N/m and fine-tune based on feedback from the actual application. - # Damping: - # Recommended range: 200 Ns/m to 1000 Ns/m - # Note: Damping values ​​should be high enough to dampen oscillations, - # but not too high to excessively hinder motion. You can start with 500 Ns/m and adjust based on dynamic performance. - # Max force: - # Recommended range: 10000 N to 100000 N - # Note: The maximum force should be set according to the load capacity of the robot arm - # to ensure that it does not exceed its load capacity when working. You can start with 50000 N, depending on the specific task load. - if self.drive_type == DriveType.FORCE: - if hasattr(self, "max_force"): - max_force = self.max_force - else: - max_force = 1e3 - param = {"stiffness": 1e2, "damping": 1e1, "max_force": max_force} - elif self.drive_type == DriveType.FORCE: - param = {"stiffness": 1e8, "damping": 1e6, "max_force": 1e10} - return param diff --git a/embodichain/lab/sim/end_effector/utility.py b/embodichain/lab/sim/end_effector/utility.py deleted file mode 100644 index d1a61e6..0000000 --- a/embodichain/lab/sim/end_effector/utility.py +++ /dev/null @@ -1,148 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import os -import typing -import pathlib -import hashlib -import numpy as np -import open3d as o3d -from dexsim.kit.meshproc import convex_decomposition_coacd -from dexsim.kit.meshproc.utility import mesh_list_to_file -from embodichain.utils import logger - - -def load_model_from_file(**kwargs) -> typing.Optional[str]: - """Loads a model from the specified file path. - - This function checks the provided file path to determine if it is a URDF file - or a mesh file (STL, OBJ, PLY). If it is a URDF file, it is loaded directly. - If it is a mesh file, a URDF file is generated from the mesh. - - Args: - file_path (str): The path to the input file (URDF or mesh file). - - Returns: - Optional[str]: The path to the loaded URDF file, or None if the file path is not provided or unsupported. - """ - file_path = kwargs.get("file_path", None) - - if file_path is None: - logger.log_warning("No file path provided for the model.") - return None - - file_suffix = pathlib.Path(file_path).suffix - mesh_suffix_list = [".stl", ".obj", ".ply"] - - if file_suffix == ".urdf": - # Load the URDF file directly - urdf_path = file_path - elif file_suffix in mesh_suffix_list: - # Generate URDF from the mesh file - urdf_path = generate_gripper_urdf_from_meshpath(file_path) - else: - logger.log_warning( - f"Unsupported file extension {file_suffix} for model file {file_path}." - ) - return None # Return None for unsupported file types - - return urdf_path - - -def generate_gripper_urdf_from_meshpath( - mesh_file: str, cache_dir: str = None, max_convex_hull_num: int = 16 -) -> str: - r"""Generate URDF for a gripper given a mesh file path. - - Args: - mesh_file (str): The path of mesh file. - cache_dir (str, optional): Cache directory. Defaults to None. - max_convex_hull_num (int, optional): The maximum convex hull number. Defaults to 16. - - Returns: - str: Urdf file path. - """ - mesh_md5_key = hashlib.md5(open(mesh_file, "rb").read()).hexdigest() - - # Set cache directory - save_dir = ( - pathlib.Path(cache_dir) - if cache_dir - else pathlib.Path.home() / "urdf_generate_cache" - ) - # Create the directory if it doesn't exist - save_dir.mkdir(parents=True, exist_ok=True) - - # Define cache file names - acd_file = f"{mesh_md5_key}_acd_{max_convex_hull_num}.obj" - visual_file = f"{mesh_md5_key}_visual.obj" - acd_cache_path = save_dir / acd_file - visual_cache_path = save_dir / visual_file - - # Generate convex decomposition cache if not exists - if not acd_cache_path.is_file() or not visual_cache_path.is_file(): - try: - in_mesh = o3d.t.io.read_triangle_mesh(mesh_file) - _, out_mesh_list = convex_decomposition_coacd( - in_mesh, max_convex_hull_num=max_convex_hull_num - ) - - # Write approximate convex decomposition result - mesh_list_to_file(str(acd_cache_path), out_mesh_list) - # Write visual mesh - o3d.t.io.write_triangle_mesh(str(visual_cache_path), in_mesh) - except Exception as e: - raise RuntimeError(f"Error during mesh processing: {e}") - - # Create URDF string - urdf_str = f""" - - - - - - - - - - - - - - - -""" - - urdf_cache_path = save_dir / f"{mesh_md5_key}.urdf" - - try: - with open(urdf_cache_path, "w") as writer: - writer.write(urdf_str) - except IOError as e: - raise RuntimeError(f"Failed to write URDF file: {e}") - - return str(urdf_cache_path) - - -def inv_transform(transform: np.ndarray) -> np.ndarray: - r"""Compute the inverse transformation. - - Args: - transform (np.ndarray): A [4 x 4] transformation matrix. - - Returns: - np.ndarray: The inverse transformation matrix. - """ - r = transform[:3, :3] - t = transform[:3, 3].T - inv_r = r.T - inv_t = -inv_r @ t - - inv_pose = np.eye(4, dtype=np.float32) - inv_pose[:3, :3] = inv_r - inv_pose[:3, 3] = inv_t - - return inv_pose diff --git a/embodichain/lab/sim/robots/__init__.py b/embodichain/lab/sim/robots/__init__.py index 95db284..de4c08a 100644 --- a/embodichain/lab/sim/robots/__init__.py +++ b/embodichain/lab/sim/robots/__init__.py @@ -15,5 +15,4 @@ # ---------------------------------------------------------------------------- from .dexforce_w1 import * -from .robot import Robot from .cobotmagic import CobotMagicCfg diff --git a/embodichain/lab/sim/robots/robot.py b/embodichain/lab/sim/robots/robot.py deleted file mode 100644 index 5d0bc7e..0000000 --- a/embodichain/lab/sim/robots/robot.py +++ /dev/null @@ -1,1177 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import numpy as np -from typing import List, Tuple, Union, Dict, Any, Optional -from abc import ABC, abstractmethod -from copy import deepcopy -import open3d as o3d -import os -from pathlib import Path -import pytorch_kinematics as pk -from matplotlib import colormaps - -import dexsim -from dexsim.models import Entity, MeshObject - -# from dexsim.engine import Articulation -from dexsim.types import DriveType, PhysicalAttr, ArticulationFlag, PrimitiveType - - -from embodichain.utils import logger -from dexsim.utility.env_utils import create_point_cloud_from_o3d_pcd - -# Try to import DriveController, but make it optional -try: - from rlia.kit.drive_controllers import DriveController -except ImportError: - # If rlia is not available, create a dummy type for type checking - DriveController = None - -from embodichain.lab.sim.end_effector import EndEffector - -# from dexsim.utility import inv_transform -from dexsim.sensor import Sensor, MonocularCam, BinocularCam -from embodichain.lab.sim.articulation_entity import ArticulationEntity - -__all__ = ["Robot"] - - -class Robot(ArticulationEntity, ABC): - r""" - Abstract class for robot in simulation. - """ - - def __init__( - self, - urdf_path: Union[str, List[str]] = dict(), - init_qpos: Union[np.ndarray, Dict[str, np.ndarray]] = dict(), - init_base_xpos: Union[np.ndarray, Dict[str, np.ndarray]] = None, - speed_ratio: float = 0.5, - time_step: float = 0.02, - drive_type: DriveType = DriveType.FORCE, - env: dexsim.environment.Arena = None, - **kwargs, - ): - r"""Initialize the robot. - - Args: - urdf_path (str): urdf file path of robot - init_qpos (np.ndarray, optional): [dof] of double. Init robot joint state(home joint state). - init_base_xpos (np.ndarray, optional): [4, 4] of double. Robot base pose in arena coordinate system. - speed_ratio (float, optional): 0 ~ 1. Robot speed ratio. - time_step (float, optional): wait time between two update. Defaults to 1/50. - drive_type (DriveType, optional): DriveType.FORCE or DriveType.FORCE. - env (Arena, optional): dexsim.environment.Arena. Load the first world(None defaults). - kwargs(optional): Accepts additional keyword arguments. - """ - # unique name of the robot. - self.uid = kwargs.get("uid", "Robot") - - super().__init__( - urdf_path=urdf_path, - init_qpos=init_qpos, - init_base_xpos=init_base_xpos, - speed_ratio=speed_ratio, - time_step=time_step, - drive_type=drive_type, - env=env, - **kwargs, - ) - - # Initialize the robot - self._init_robot(**kwargs) - - # Disable self-collision avoidance for the articulation - self.set_enable_self_collision_flag(self.uid, False) - - # Additional parameters - self.attach_end_effectors = {} - - # Build pk_serial_chain - self.pk_serial_chain = self.build_pk_serial_chain() - - def set_enable_self_collision_flag(self, name: str = None, is_enable: bool = False): - r"""Set the self-collision flag for the specified articulation - or all articulations. - - Args: - name (str, optional): Name of the articulation. - If None, apply to all articulations. Defaults to None. - is_enable (bool, optional): Flag to enable - or disable self-collision. Defaults to False. - """ - if name is None or name == self.uid: - self.articulation.set_articulation_flag( - ArticulationFlag.DISABLE_SELF_COLLISION, not is_enable - ) - else: - if name in self.child_articulations: - self.child_articulations[name].set_articulation_flag( - ArticulationFlag.DISABLE_SELF_COLLISION, not is_enable - ) - else: - logger.log_warning(f"Articulation '{name}' not found.") - - @abstractmethod - def _init_robot(self, **kwargs) -> None: - r"""Initializes the robot using the URDF path with necessary parameters.""" - pass - - def get_end_effector(self, uid: str = None): - r"""Get the end effector by its unique identifier. - - Args: - uid (str): Unique identifier for the end effector to be attached. - If None, returns a dictionary of all end effectors. - - Returns: - EndEffector: The end effector associated with the given uid, or None if not found. - """ - if uid is None: - return self.attach_end_effectors - - end_effector = self.attach_end_effectors.get(uid) - return end_effector - - def attach_end_effector( - self, - uid: str, - end_effector: EndEffector, - robot_uid: str = None, - attach_xpos: np.ndarray = np.eye(4), - ee_link_name: str = "ee_link", - **kwargs, - ): - r"""Attach an end effector to the robotic system. - - Args: - uid (str): Unique identifier for the end effector to be attached. - end_effector (EndEffector): An instance of the EndEffector class representing the end effector to be attached. - robot_uid (str, optional): Unique identifier for the robot to which the end effector is to be attached. Defaults to None. - attach_xpos (np.ndarray, optional): 4x4 transformation matrix (homogeneous transformation matrix) representing the pose - at which the end effector should be attached. Defaults to identity matrix. - ee_link_name (str, optional): The link string that represents the end effector link in the robot. Defaults to "ee_link". - **kwargs: Additional keyword arguments for extended functionality (if applicable). - Returns: - tuple: A tuple containing a boolean and a value: - - (bool) False if the end effector is already attached, True otherwise. - - (None) Always returns None as the second element. - """ - # If robot_uid is not provided, use the current object's uid - robot_uid = robot_uid or self.uid - - # Check if the end effector is already attached to the robot - if robot_uid == self.uid or robot_uid in self.child_articulations: - target_articulation = ( - self.articulation - if robot_uid == self.uid - else self.child_articulations[robot_uid] - ) - - # Get degrees of freedom for the target articulation and the end effector - arm_dof = target_articulation.get_dof() - ef_dof = end_effector.get_dof() - - # Get the root link name of the end effector - ef_root_link_name = end_effector.articulation.get_root_link_name() - ef_link_names = end_effector.articulation.get_link_names() - end_effector.drive_type = self.drive_type - - end_effector_joint_names = ( - end_effector.articulation.get_actived_joint_names() - ) - - # Load the end effector's URDF into the target articulation at the specified position - target_articulation.load_urdf( - end_effector.get_urdf_path(), ee_link_name, attach_xpos - ) - - # Remove the previous articulation of the end effector - ef_articulation = end_effector.get_articulation(end_effector.uid) - self._env.remove_articulation(ef_articulation) - - # Assign the target articulation to the end effector - end_effector.articulation = target_articulation - - target_articulation_joint_names = ( - target_articulation.get_actived_joint_names() - ) - - # Update joint indices for the end effector - ef_joint_ids = arm_dof + np.arange(ef_dof) - end_effector.set_joint_ids(ef_joint_ids) - - # Combine initial positions - ef_init_qpos = end_effector._init_qpos - joint_name_to_idx = { - name: idx for idx, name in enumerate(target_articulation_joint_names) - } - ef_ids = np.array( - [joint_name_to_idx[name] for name in end_effector_joint_names] - ) - - robot_ids = np.arange(arm_dof) - target_articulation.set_current_qpos( - self.get_init_qpos(robot_uid), joint_ids=robot_ids - ) - target_articulation.set_current_qpos(ef_init_qpos, joint_ids=ef_ids) - - # Set physical attributes for the target articulation - target_articulation.set_physical_attr(self.default_physical_attrs) - target_articulation.set_drive( - drive_type=self.drive_type, **self.default_drive_param - ) - - # Store end effector details in the class attributes - self.child_articulations[uid] = end_effector.articulation - self._dof[uid] = ef_dof - self._joint_ids[uid] = ef_ids - self.init_qpos[uid] = ef_init_qpos - self.root_link_names[uid] = ef_root_link_name - end_effector.attach_robot_uid = robot_uid - - end_effector._joint_ids[end_effector.uid] = ef_ids - - # TODO: update robot, etc. - # Update the joint ids for other end effector - for ee_name, ee in self.attach_end_effectors.items(): - ee_idx_list = np.array( - [joint_name_to_idx[name] for name in ee.actived_joint_names] - ) - - self._joint_ids[ee_name] = ee_idx_list - ee._joint_ids[ee.uid] = ee_idx_list - - # ee_init_qpos = self.init_qpos[ee_name] - # Update the initial positions in the class attributes - # self.init_qpos[ee.uid] = ee_init_qpos[ee_idx_list] - - # Keep a reference of the attached end effector - self.attach_end_effectors[uid] = end_effector - - # set end-effector physical param and drive param - for link_name in ef_link_names: - target_articulation.set_physical_attr( - attrib=end_effector.default_physical_attrs, - link_name=link_name, - is_replace_inertial=True, - ) - target_articulation.set_drive( - drive_type=self.drive_type, - joint_ids=ef_joint_ids, - **end_effector.default_drive_param, - ) - # end_effector.set_drive(end_effector.drive_type) - return True, end_effector - else: - logger.log_warning(f"Articulation '{uid}' not found.") - return False, None - - def attach_sensor( - self, - sensor: Sensor, - robot_uid: str = None, - attach_xpos: np.ndarray = np.eye(4), - link_name: str = "ee_link", - ): - r"""Attach a sensor to a robot. - - Note: - Currently, this function is only available for Monocular and Binocular sensors. - - Args: - sensor (Sensor): The sensor object to be attached. It can be a MonocularCam or BinocularCam. - robot_uid (str, optional): Unique identifier for the robot to which the sensor will be attached. Defaults to None, which refers to the current robot. - attach_xpos (np.ndarray, optional): 4x4 transformation matrix (homogeneous transformation matrix) representing the pose - at which the sensor should be attached. Defaults to the identity matrix. - link_name (str, optional): The link string that represents the attachment point on the robot. Defaults to "ee_link". - - Returns: - None: This function does not return a value but logs warnings for unsupported sensor types or invalid robot identifiers. - """ - robot_uid = robot_uid or self.uid - - # Check if the robot_uid matches the current robot or a child articulation - if robot_uid == self.uid or robot_uid in self.child_articulations: - target_articulation = ( - self.articulation - if robot_uid == self.uid - else self.child_articulations[robot_uid] - ) - - # Attach the sensor based on its type - if isinstance(sensor, MonocularCam): - target_articulation.attach_node( - obj=sensor.get_node(), - link_name=link_name, - relative_pose=attach_xpos, - ) - elif isinstance(sensor, BinocularCam): - # Attach the left camera node - if sensor._coordinate_system == "center": - relative_pose = sensor._relativate_T_l - else: - relative_pose = sensor.get_relative_transform() - relative_pose[:3, 3] = relative_pose[:3, 3] * -0.5 - target_articulation.attach_node( - obj=sensor.get_node(is_left=True), - link_name=link_name, - relative_pose=attach_xpos @ relative_pose, - ) - # Attach the right camera node - target_articulation.attach_node( - obj=sensor.get_node(is_left=False), - link_name=link_name, - relative_pose=attach_xpos @ np.linalg.inv(relative_pose), - ) - else: - logger.log_warning("Unsupported sensor type: %s", type(sensor).__name__) - else: - logger.log_warning(f"Articulation '{robot_uid}' not found.") - - # @deprecated(reason="Currently unable to detach this component.") - def detach_end_effector( - self, - uid: str, - robot_uid: str = None, - ): - r"""Detach an end effector from the robotic system. - - Args: - uid (str): Unique identifier for the end effector to be detached. - robot_uid (str, optional): Unique identifier for the robot from which the end effector is to be detached. - - Returns: - bool: True if the end effector was successfully detached, False otherwise. - """ - if uid not in self.child_articulations: - logger.log_warning(f"End effector {uid} already detached.") - return False - - robot_uid = robot_uid or self.uid - if robot_uid is not self.uid: - logger.log_warning(f"Articulation with UID '{robot_uid}' not found.") - return False - - if uid in self.init_qpos: - del self.init_qpos[uid] - if uid in self.init_base_xpos: - del self.init_base_xpos[uid] - self.child_articulations[uid].detach_parent() - self.child_articulations.pop(uid) - return True - - def close(self, uid: str = None, target: float = 1.0) -> bool: - r"""Closes the attached end effector, if this manipulator has one. If no UID is provided, - it will close all end effectors associated with the manipulator. - - Args: - uid (str, optional): - A unique identifier for the specific end effector to be closed. - If None, the method will attempt to close all end effectors. - Defaults to None. - target (float, optional): - The target position for the close operation, typically representing - the closure position of the end effector. - Defaults to 1.0 (fully closed). - - Returns: - bool: - Returns True if the end effector(s) were closed successfully, - and False otherwise. If no end effector is found with the given UID, - a warning is logged. - """ - is_success = False - if uid is None or uid == self.uid: - for key, value in self.attach_end_effectors.items(): - if isinstance(value, EndEffector): - value.close(target=target) - is_success = True # Mark success if any end effector is closed - else: - if uid in self.attach_end_effectors: - self.attach_end_effectors[uid].close(target=target) - is_success = True - else: - logger.log_warning(f"End effector with UID '{uid}' not found.") - - return is_success - - def open(self, uid: str = None, target: float = 0.0) -> bool: - r""" - Opens the attached end effector, if this manipulator has one. If no UID is provided, - it will open all end effectors associated with the manipulator. - - Args: - uid (str, optional): - A unique identifier for the specific end effector to be opened. - If None, the method will attempt to open all end effectors. - Defaults to None. - target (float, optional): - The target position for the open operation, typically representing - the opening position of the end effector. - Defaults to 0.0 (fully opened). - - Returns: - bool: - Returns True if the end effector(s) were opened successfully, - and False otherwise. If no end effector is found with the given UID, - a warning is logged. - """ - is_success = False - if uid is None or uid == self.uid: - for key, value in self.attach_end_effectors.items(): - if isinstance(value, EndEffector): - value.open(target=target) - is_success = True # Mark success if any end effector is opened - else: - if uid in self.attach_end_effectors: - self.attach_end_effectors[uid].open(target=target) - is_success = True - else: - logger.log_warning(f"End effector with UID '{uid}' not found.") - - return is_success - - def set_controller(self, controller=None, uid: str = None, **kwargs): - r"""Set a drive or task controller to the robot. - - Args: - controller (DriveController, optional): - The controller instance to be added to the robot. Can be either: - - DriveController: For low-level joint control - uid (str, optional): - Unique identifier for the articulation to be controlled. - If None, uses the robot's main articulation ID. - - Returns: - bool: True if controller was successfully set, False otherwise. - """ - uid = uid or self.uid - - # Check if the robot_uid matches the current robot or a child articulation - if uid == self.uid or uid in self.child_articulations: - target_articulation = ( - self.articulation if uid == self.uid else self.child_articulations[uid] - ) - - if DriveController is not None and isinstance(controller, DriveController) and any( - isinstance(controller, ctl_type) - for ctl_type in self.supported_drive_controller_types.values() - ): - if hasattr(controller, "set_init_qpos"): - controller.set_init_qpos(self.init_qpos[uid]) - controller.set_articulation(target_articulation) - controller.set_control_q_ids(self._joint_ids[uid]) - self.drive_controllers[uid] = controller - else: - logger.log_warning(f"Controller type '{type(controller)}' not support.") - return False - else: - logger.log_warning(f"Articulation '{uid}' not found.") - return False - - return True - - def set_speed_ratio(self, speed_ratio: float, uid: str = None): - r"""Set speed ratio of the robot. - - Args: - speed_ratio (float): 0.0~1.0. robot speed ratio. - uid (str): Uid of the articulation. - """ - uid = uid or self.uid - - if uid == self.uid or uid in self.child_articulations: - self.speed_ratio = speed_ratio - return True - else: - logger.log_warning( - f"Drive controller with UID '{uid}' not found. Please add the drive controller before set speed ratio." - ) - return False - - def get_speed_ratio(self, uid: str = None): - r"""Get speed ratio of the robot. - - Args: - uid (str): Uid of the articulation. - """ - uid = uid or self.uid - - if uid == self.uid or uid in self.child_articulations: - return self.speed_ratio - else: - logger.log_warning( - f"Drive controller with UID '{uid}' not found. Please add the drive controller before set speed ratio." - ) - return None - - @abstractmethod - def get_fk(self, qpos: np.ndarray, uid: str = None) -> np.ndarray: - r"""Get forward kinematic of given joints - - Args: - qpos (np.ndarray): [dof] of float. - uid (str, optional): uid of the articulation. Defaults to None. - - Returns: - np.ndarray: Pose of the end-effector. - """ - pass - - @abstractmethod - def get_ik(self, xpos: np.ndarray, uid: str = None, **kwargs) -> np.ndarray: - r"""Get inverse kinematic of given end-effector pose. - - Args: - xpos (np.ndarray): [4, 4] of matrix. - uid (str, optional): uid of the articulation. Defaults to None. - **kwargs: Other parameters. which can be used to specify the IK method. - - Returns: - np.ndarray: [dof] of float. - """ - pass - - @abstractmethod - def move( - self, - path: Union[np.ndarray, List[np.ndarray]], - is_joint: bool = False, - is_wait: bool = True, - **kwargs, - ) -> bool: - r"""Move the robot to the given path. - - Args: - path (np.ndarray): [4, 4] | [waypoint_num, 4, 4] | [dof] of float or - [waypoint_num, dof] of float. Path in cartesian space or joint space. - is_joint (bool, optional): Whether the path is in joint space. Defaults to False. - is_wait (bool, optional): Whether to synchronize the robot movement. Defaults to True. - **kwargs: Other parameters. - - Returns: - bool: is_move_success - """ - pass - - def get_dof(self, name: str = None) -> Union[int, Dict[str, int]]: - r"""Get degree of freedom (DoF) of the robot. - - Args: - name (str, optional): Name of the articulation. Defaults to None. - - Returns: - Union[int, Dict[str, int]]: - - If `name` is None, returns the total DoF of the robot as an integer. - - If `name` is provided and found, returns the DoF of the specified articulation as an integer. - - If `name` is provided but not found, logs a warning and returns 0. - """ - # TODO: Need to clarify behavior. - if name is None: - if isinstance(self._dof, dict): - return sum(self._dof.values()) - else: - return ( - self._dof - ) # Assuming _dof is an integer representing the total DoF - elif name in self._dof: - return self._dof[ - name - ] # Assuming _dof[name] is an integer representing the DoF of the specified articulation - - logger.log_warning(f"Articulation '{name}' not found.") - return 0 - - def get_proprioception(self, remove_index: bool = True) -> Dict[str, Any]: - r"""Gets robot proprioception information, primarily for agent state representation in robot learning scenarios. - - The default proprioception information includes: - - xpos: End-effector pose in the robot base coordinate system. - - qpos: Joint positions. - - qvel: Joint velocities. - - qf (effort): Joint forces. - - Args: - remove_index (bool, optional): - If True, the suffix index of the UID will be removed. - Defaults to True. - - Returns: - Dict[str, Any]: - A dictionary containing the robot's proprioception information, - where keys are the UID or modified UID and values are dictionaries - containing the proprioception data. - """ - obs = {} - - # Helper function to populate proprioception data for a given name - def populate_proprioception(name: str): - return { - "xpos": self.get_current_xpos(name=name, is_world_coordinates=False), - "qpos": self.get_current_qpos(name=name), - "qvel": self.get_current_qvel(name=name), - "qf": self.get_current_qf(name=name), - } - - # Process the main UID - base_name = self.uid.split("_")[0] if remove_index else self.uid - obs[base_name] = populate_proprioception(self.uid) - - # Process child articulations - for child_name in self.child_articulations: - if remove_index: - import re - - modified_name = re.sub(r"(_\d+)$", "", child_name) - else: - modified_name = child_name - - if modified_name in obs: - if isinstance(obs[modified_name], list): - obs[modified_name].append(populate_proprioception(child_name)) - else: - obs[modified_name] = [ - obs[modified_name], - populate_proprioception(child_name), - ] - else: - obs[modified_name] = populate_proprioception(child_name) - - return obs - - def attach_actor( - self, actor: Entity, relative_xpos: np.ndarray, uid: str = None, **kwargs - ) -> Entity: - r"""Attach an actor to the robot. - - Args: - actor (Entity): - The actor to be attached to the robot. - relative_xpos (np.ndarray): - A [4, 4] matrix representing the relative pose of the actor to the robot. - uid (str, optional): - Unique identifier of the articulation. If None, defaults to the robot's UID. - **kwargs: - Additional parameters for future extension. - - Returns: - Entity: - The attached actor, or None if the attachment failed. - """ - uid = uid or self.uid - - # Define a function to attach the actor to the specified articulation - def attach_to_articulation(articulation): - actor_name = actor.get_name() - self.attached_actors[actor_name] = actor - articulation.attach_node(actor.node, "ee_link", relative_xpos) - return actor - - # Check if UID matches the robot's UID - if uid == self.uid: - return attach_to_articulation(self.articulation) - - # Check if UID matches any child articulation - elif uid in self.child_articulations: - return attach_to_articulation(self.child_articulations[uid]) - - # Log a warning if the articulation is not found - logger.log_warning(f"Articulation with UID '{uid}' not found.") - return None - - def remove_actor(self, actor_name: str, delete: bool = False) -> None: - r"""Remove the attached actor from the robot. - - Args: - actor_name (str): Name of the actor to be removed. - delete (bool, optional): Whether to delete the actor from the simulation. Defaults to False. - """ - if actor_name in self.attached_actors: - for key, value in self.child_articulations.items(): - if isinstance(value, EndEffector): - value.detach(actor_name) - self.attached_actors.pop(actor_name) - if delete: - self._env.remove_actor(actor_name) - - def get_attached_actor_names(self) -> List[str]: - r"""Get names of all attached actors. - - Returns: - List[str]: Names of all attached actors. - """ - return list(self.attached_actors.keys()) - - def compute_qpos_reachability( - self, - name: str, - resolution: float = np.radians(50), - qpos_limits: np.ndarray = None, - cache_mode: str = "memory", - visualize: bool = False, - batch_size: int = 100000, - use_cached: bool = True, - **kwargs, - ) -> Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: - """Compute the robot's reachable workspace by joint space sampling. - - Samples points in joint space and optionally visualizes the resulting end-effector positions - as a colored point cloud. If `visualize` is True, points closer to the robot base are colored green, - transitioning to red for points further away. If `visualize` is False, only the sampling is performed - without any visualization. - - - Args: - name (str): Identifier of the robot drive controller to analyze - resolution (float, optional): Angular resolution for joint space sampling in radians. - Lower values provide finer sampling but increase computation time. - Defaults to 50 degrees (≈0.873 radians) - qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). - If None, uses limits from drive controller or articulation. - Defaults to None - cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". - Defaults to "memory". - visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. - Defaults to False. - batch_size (int, optional): Number of samples to process in each batch. - Defaults to 100000. - use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. - Ignored for "memory" mode. Defaults to True. - - - Returns: - Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: - The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. - The second element is a point cloud handle if visualization is enabled and successful, otherwise None. - """ - from embodichain.lab.sim.utility.workspace_analyzer import ( - WorkspaceAnalyzer, - ) - from embodichain.lab.sim import REACHABLE_XPOS_DIR - - if name not in self.drive_controllers: - logger.log_warning(f"Drive controller '{name}' not found") - return None, None - - # try: - # Get robot configuration - base_xpos = self.get_base_xpos(name=name) - drive_controller = self.drive_controllers[name] - - if qpos_limits is None: - if hasattr(drive_controller, "get_joint_limits"): - res, upper_limits, lower_limits = self.drive_controllers[ - name - ].get_joint_limits() - if not res: - logger.log_warning("Failed to get joint limits") - return None, None - joint_ranges = np.column_stack((lower_limits, upper_limits)) - else: - joint_limits = self.articulation.get_joint_limits() - joint_ranges = joint_limits[self._joint_ids[name]] - else: - joint_ranges = qpos_limits - paths = self.get_urdf_path() - urdf_path = paths if isinstance(paths, str) else paths[self.uid] - robot_name = os.path.splitext(os.path.basename(urdf_path))[0] - # Initialize workspace analyzer - analyzer = WorkspaceAnalyzer( - robot=self, name=name, resolution=resolution, joint_ranges=joint_ranges - ) - # Format resolution to avoid issues with decimal points in paths - resolution_str = f"{resolution:.2f}".replace(".", "_") - # Join into one directory name - save_dir = REACHABLE_XPOS_DIR / f"{robot_name}_{name}_{resolution_str}" - # Sample workspace points - sampled_xpos = analyzer.sample_qpos_workspace( - cache_mode=cache_mode, - save_dir=save_dir, - batch_size=batch_size, - use_cached=use_cached, - ) - if visualize == True: - # Create and configure point cloud visualization - # all_positions = [xpos[:3, 3] for xpos in sampled_xpos] - N = len(sampled_xpos) - all_pos = np.empty((N, 3), dtype=np.float16) - for i, mat in enumerate(sampled_xpos): - all_pos[i] = mat[:3, 3].astype(np.float16) - pcd = analyzer._process_point_cloud(positions=all_pos) - # Transfer to World Coordinate - pcd.transform(base_xpos) - pcd_handle = create_point_cloud_from_o3d_pcd(pcd=pcd, env=self._env) - else: - return sampled_xpos, None - - return sampled_xpos, pcd_handle - - # except Exception as e: - # logger.log_warning(f"Failed to visualize qpos workspace: {str(e)}") - # return None, None - - def compute_xpos_reachability( - self, - name: str, - ref_xpos: np.ndarray, - xpos_resolution: float = 0.2, - qpos_resolution: float = np.radians(60), - pos_eps: float = 5e-4, - rot_eps: float = 5e-4, - max_iterations: int = 1500, - num_samples: int = 5, - batch_size: int = 100000, - save_threshold: int = 10000000, - qpos_limits: np.ndarray = None, - cache_mode: str = "memory", - visualize: bool = True, - use_cached: bool = True, - **kwargs, - ) -> Tuple[ - Optional[list[np.ndarray]], # First return: list of sampled 4x4 poses - Optional[ - dexsim.models.PointCloud - ], # Second return: point cloud handle if visualization is enabled - ]: - """Compute the robot's reachable workspace by Cartesian space sampling. - - Samples points in Cartesian space and checks reachability using inverse kinematics. - If `visualize` is True, visualizes reachable positions as a colored point cloud; - Otherwise, only performs the sampling result as open3d PointCloud. - - - Args: - name (str): Identifier of the robot drive controller to analyze - ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the - orientation for IK solutions - xpos_resolution (float, optional): Cartesian space sampling resolution in meters. - Smaller values provide finer sampling but increase - computation time. Defaults to 0.2 meters. - qpos_resolution (float, optional): Angular resolution for initial joint space - sampling in radians. Used to determine workspace - bounds. Defaults to 60 degrees. - pos_eps (float, optional): Position tolerance for IK solutions in meters. - Defaults to 2e-4 meters. - rot_eps (float, optional): Rotation tolerance for IK solutions in radians. - Defaults to 2e-4 radians. - max_iterations (int, optional): Maximum number of IK iterations per sample. - Defaults to 2000. - num_samples (int, optional): Number of samples to generate in Cartesian space. - Defaults to 10. - qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). - If None, uses limits from drive controller or - articulation. Defaults to None - cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". - Defaults to "memory". - visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. - Defaults to True. - use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. - Ignored for "memory" mode. Defaults to True. - - Returns: - Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: - The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. - The second element is a point cloud handle if visualization is enabled and successful, otherwise None. - """ - from embodichain.lab.sim.utility.workspace_analyzer import ( - WorkspaceAnalyzer, - ) - from embodichain.lab.sim import REACHABLE_XPOS_DIR - - if name not in self.drive_controllers: - logger.log_warning(f"Drive controller '{name}' not found") - return None, None - - # try: - # Get robot configuration - base_xpos = self.get_base_xpos(name=name) - ref_xpos_robot = dexsim.utility.inv_transform(base_xpos) @ ref_xpos - drive_controller = self.drive_controllers[name] - - if qpos_limits is None: - if hasattr(drive_controller, "get_joint_limits"): - res, upper_limits, lower_limits = self.drive_controllers[ - name - ].get_joint_limits() - if not res: - logger.log_warning("Failed to get joint limits") - return None, None - joint_ranges = np.column_stack((lower_limits, upper_limits)) - else: - joint_limits = self.articulation.get_joint_limits() - joint_ranges = joint_limits[self._joint_ids[name]] - else: - joint_ranges = qpos_limits - - paths = self.get_urdf_path() - urdf_path = paths if isinstance(paths, str) else paths[self.uid] - robot_name = os.path.splitext(os.path.basename(urdf_path))[0] - - qpos_resolution_str = f"{qpos_resolution:.2f}".replace(".", "_") - xpos_resolution_str = f"{xpos_resolution:.2f}".replace(".", "_") - # Join into one directory name - save_dir = ( - REACHABLE_XPOS_DIR - / f"{robot_name}_{name}_{qpos_resolution_str}_{xpos_resolution_str}" - ) - - # Initialize workspace analyzer - analyzer = WorkspaceAnalyzer( - robot=self, - name=name, - resolution=qpos_resolution, - joint_ranges=joint_ranges, - ) - # Sample workspace points - sampled_xpos = analyzer.sample_xpos_workspace( - ref_xpos=ref_xpos_robot, - xpos_resolution=xpos_resolution, - qpos_resolution=qpos_resolution, - cache_mode=cache_mode, - batch_size=batch_size, - save_dir=save_dir, - save_threshold=save_threshold, - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - use_cached=use_cached, - ) - - if visualize == visualize: - if sampled_xpos is None: - logger.log_warning("No reachable positions found.") - return None, None - all_positions = [xpos[:3, 3] for xpos in sampled_xpos] - pcd = analyzer._process_point_cloud( - positions=all_positions, is_voxel_down=False - ) - # Transfer to World Coordinate - pcd.transform(base_xpos) - # Create and configure point cloud visualization - pcd_handle = create_point_cloud_from_o3d_pcd(pcd=pcd, env=self._env) - else: - return sampled_xpos, None - - return sampled_xpos, pcd_handle - - def compute_voxel_reachability( - self, - name: str, - voxel_size: float = 0.04, - num_directions: int = 50, - num_yaws=6, - pos_eps: float = 5e-4, - rot_eps: float = 5e-4, - max_iterations: int = 1500, - num_samples: int = 5, - qpos_limits: np.ndarray = None, - cache_mode: str = "memory", - visualize: bool = False, - use_cached: bool = True, - **kwargs, - ) -> Tuple[Optional[List[np.ndarray]], Optional[List[MeshObject]]]: - """ - Compute the robot's reachable workspace by voxel-based sampling. - - Samples voxel centers within a sphere around the robot’s end-effector base - and checks reachability via inverse kinematics. - If `visualize` is True, spawns a colored sphere actor at each voxel center - to indicate success rate; otherwise returns only the sampled poses. - - Args: - name (str): Identifier of the drive controller to analyze. - voxel_size (float, optional): Edge length of each cubic voxel (m). - Smaller values give finer resolution but increase computation time. - Defaults to 0.04. - num_directions (int, optional): Number of sample directions per voxel. - Defaults to 50. - num_yaws (int, optional): Number of discrete yaw rotations **around the local Z-axis** - to try for each sample direction when solving IK. A higher value can - increase rotational coverage but incurs more IK calls. Defaults to 6. - qpos_limits (np.ndarray, optional): Custom joint limits array of shape - (n_joints, 2). If None, retrieves limits from the controller or - articulation. Defaults to None. - cache_mode (str, optional): “memory” or “disk” mode for caching IK - results. Defaults to "memory". - visualize (bool, optional): If True, returns a list of DexSim actor - handles for visualization; otherwise returns None for actors. - Defaults to False. - use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. - Ignored for "memory" mode. Defaults to True. - - Returns: - Tuple[Optional[List[np.ndarray]], Optional[List[MeshObject]]]: - - List of sampled end-effector poses (4×4 matrices), or None on failure. - - List of sphere actor handles if visualize=True, else None. - """ - from embodichain.lab.sim.utility.workspace_analyzer import ( - WorkspaceAnalyzer, - ) - from embodichain.lab.sim import REACHABLE_XPOS_DIR - - # 1) Validate drive controller - if name not in self.drive_controllers: - logger.log_warning(f"Drive controller '{name}' not found") - return None, None - - try: - drive_controller = self.drive_controllers[name] - - # 2) Determine joint limits - if qpos_limits is None: - if hasattr(drive_controller, "get_joint_limits"): - res, upper, lower = drive_controller.get_joint_limits() - if not res: - logger.log_warning("Failed to get joint limits") - return None, None - joint_ranges = np.column_stack((lower, upper)) - else: - all_limits = self.articulation.get_joint_limits() - joint_ranges = all_limits[self._joint_ids[name]] - else: - joint_ranges = qpos_limits - - # 3) Prepare save directory - urdf_paths = self.get_urdf_path() - urdf_path = ( - urdf_paths if isinstance(urdf_paths, str) else urdf_paths[self.uid] - ) - robot_name = os.path.splitext(os.path.basename(urdf_path))[0] - - vs_str = f"{voxel_size:.2f}".replace(".", "_") - nd_str = str(num_directions) - save_dir = ( - REACHABLE_XPOS_DIR / f"Voxel_{robot_name}_{name}_{vs_str}_{nd_str}" - ) - - # 4) Set up workspace analyzer - analyzer = WorkspaceAnalyzer( - robot=self, name=name, joint_ranges=joint_ranges - ) - - # 5) Sample voxels and IK - ( - voxel_centers, - voxel_success_counts, - sampled_xpos, - ) = analyzer.sample_voxel_workspace( - voxel_size=voxel_size, - num_directions=num_directions, - num_yaws=num_yaws, - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - cache_mode=cache_mode, - batch_size=5000, - save_dir=save_dir, - save_threshold=10_000_000, - use_cached=use_cached, - ) - - # 6) Visualization (optional) - if visualize: - colormap = colormaps.get_cmap("jet") - actor_handles: List[MeshObject] = [] - - for idx, (center, count) in enumerate( - zip(voxel_centers, voxel_success_counts), start=1 - ): - # map success rate to color - frac = count / num_directions - color = colormap(1.0 - frac)[:3] - - # build and color sphere mesh - sphere = o3d.geometry.TriangleMesh.create_sphere(voxel_size / 2) - sphere.paint_uniform_color(color) - - verts = np.asarray(sphere.vertices) - inds = np.asarray(sphere.triangles) - cols = np.asarray(sphere.vertex_colors) - cols4 = np.ones((cols.shape[0], 4), dtype=float) - cols4[:, :3] = cols - - # create uniquely named actor e.g. "sphere1", "sphere2", … - actor_name = f"sphere{idx}" - actor = self._env.create_actor(actor_name, True, True) - actor.set_mesh( - vertices=verts, - indices=inds, - shape=PrimitiveType.TRIANGLES, - smooth_angle=-1, - colors=cols4, - ) - actor.set_location(*center) - - actor_handles.append(actor) - - return sampled_xpos, actor_handles - - # 7) Return only sampled poses - return sampled_xpos, None - - except Exception as e: - print(f"Failed to visualize voxel workspace: {e}") - return None, None - - def destroy(self) -> None: - r"""Release the resources of the robot.""" - # Safely handle drive_controllers - if hasattr(self, "drive_controllers") and isinstance( - self.drive_controllers, dict - ): - for key in self.drive_controllers.keys(): - self.drive_controllers[key] = None - - # Safely handle task_controllers - if hasattr(self, "task_controllers") and isinstance( - self.task_controllers, dict - ): - for key in self.task_controllers.keys(): - self.task_controllers[key] = None - - # Safely handle articulation - if hasattr(self, "articulation"): - self.articulation = None - - # Safely handle child_articulations - if hasattr(self, "child_articulations") and isinstance( - self.child_articulations, dict - ): - for key in self.child_articulations.keys(): - if self.child_articulations[key] is not None: - if hasattr(self.child_articulations[key], "get_articulation"): - self._env.remove_articulation( - self.child_articulations[key].get_articulation() - ) - else: - self._env.remove_articulation(self.child_articulations[key]) - - self.child_articulations[key] = None - - @staticmethod - def build_pk_serial_chain(**kwargs) -> Dict[str, pk.SerialChain]: - """Build the serial chain from the URDF file. - - Args: - **kwargs: Additional arguments for building the serial chain. - - Returns: - Dict[str, pk.SerialChain]: The serial chain of the robot. - """ - # paths = self.get_urdf_path() - # urdf_path = paths if isinstance(paths, str) else paths[self.uid] - # chain = pk.build_chain_from_urdf(open(urdf_path, mode="rb").read()) - - # articulation = robot.get_articulation(self.uid) - # link_names = articulation.get_link_names() - # serial_chain = pk.SerialChain(chain, link_names[-1], link_names[0]) - - # return {self.uid: serial_chain} - return {} diff --git a/embodichain/lab/sim/utility/sim_utils copy.py b/embodichain/lab/sim/utility/sim_utils copy.py deleted file mode 100644 index 93d6e48..0000000 --- a/embodichain/lab/sim/utility/sim_utils copy.py +++ /dev/null @@ -1,285 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import os -import dexsim -import open3d as o3d - -from typing import List, Union, Optional - -from dexsim.types import DriveType, ArticulationFlag, LoadOption, RigidBodyShape -from dexsim.engine import Articulation -from dexsim.environment import Env, Arena -from dexsim.models import MeshObject - -from embodichain.lab.sim.cfg import ArticulationCfg, RigidObjectCfg, SoftObjectCfg -from embodichain.lab.sim.shapes import MeshCfg, CubeCfg, SphereCfg -from embodichain.utils import logger -from dexsim.kit.meshproc import get_mesh_auto_uv -import numpy as np - - -def get_dexsim_arenas() -> List[dexsim.environment.Arena]: - """Get all arenas in the default dexsim world. - - Returns: - List[dexsim.environment.Arena]: A list of arenas in the default world, or an empty list if no world is found. - """ - world = dexsim.default_world() - if world is None: - logger.log_warning(f"No default world found. Returning empty arena list.") - return [] - - env = world.get_env() - arenas = env.get_all_arenas() - if len(arenas) == 0: - return [env] - return arenas - - -def get_dexsim_arena_num() -> int: - """Get the number of arenas in the default dexsim world. - - Returns: - int: The number of arenas in the default world, or 0 if no world is found. - """ - arenas = get_dexsim_arenas() - return len(arenas) - - -def get_dexsim_drive_type(drive_type: str) -> DriveType: - """Get the dexsim drive type from a string. - - Args: - drive_type (str): The drive type as a string. - - Returns: - DriveType: The corresponding DriveType enum. - """ - if drive_type == "force": - return DriveType.FORCE - elif drive_type == "acceleration": - return DriveType.ACCELERATION - else: - logger.error(f"Invalid dexsim drive type: {drive_type}") - - -def set_dexsim_articulation_cfg(arts: List[Articulation], cfg: ArticulationCfg) -> None: - """Set articulation configuration for a list of dexsim articulations. - - Args: - arts (List[Articulation]): List of dexsim articulations to configure. - cfg (ArticulationCfg): Configuration object containing articulation settings. - """ - - def get_drive_type(drive_pros): - if isinstance(drive_pros, dict): - return drive_pros.get("drive_type", None) - return getattr(drive_pros, "drive_type", None) - - drive_pros = getattr(cfg, "drive_pros", None) - drive_type = get_drive_type(drive_pros) if drive_pros is not None else None - - if drive_type == "force": - drive_type = DriveType.FORCE - elif drive_type == "acceleration": - drive_type = DriveType.ACCELERATION - else: - logger.log_error(f"Unknow drive type {drive_type}") - - for i, art in enumerate(arts): - art.set_physical_attr(cfg.attrs.attr()) - art.set_articulation_flag(ArticulationFlag.FIX_BASE, cfg.fix_base) - art.set_articulation_flag( - ArticulationFlag.DISABLE_SELF_COLLISION, cfg.disable_self_collision - ) - art.set_solver_iteration_counts( - min_position_iters=cfg.min_position_iters, - min_velocity_iters=cfg.min_velocity_iters, - ) - link_names = art.get_link_names() - for name in link_names: - physical_body = art.get_physical_body(name) - inertia = physical_body.get_mass_space_inertia_tensor() - inertia = np.maximum(inertia, 1e-4) - physical_body.set_mass_space_inertia_tensor(inertia) - - if i == 0 and cfg.compute_uv: - render_body = art.get_render_body(name) - if render_body: - render_body.set_projective_uv() - - # TODO: will crash when exit if not explicitly delete. - # This may due to the destruction of render body order when exiting. - del render_body - - -def is_rt_enabled() -> bool: - """Check if Ray Tracing rendering backend is enabled in the default dexsim world. - - Returns: - bool: True if Ray Tracing rendering is enabled, False otherwise. - """ - config = dexsim.get_world_config() - - return config.renderer == dexsim.types.Renderer.FASTRT - - -def create_cube( - envs: List[Union[Env, Arena]], size: List[float], uid: str = "cube" -) -> List[MeshObject]: - """Create cube objects in the specified environments or arenas. - - Args: - envs (List[Union[Env, Arena]]): List of environments or arenas to create cubes in. - size (List[float]): Size of the cube as [length, width, height] in meters. - uid (str, optional): Unique identifier for the cube objects. Defaults to "cube". - - Returns: - List[MeshObject]: List of created cube mesh objects. - """ - cubes = [] - for i, env in enumerate(envs): - cube = env.create_cube(size[0], size[1], size[2]) - cube.set_name(f"{uid}_{i}") - cubes.append(cube) - return cubes - - -def create_sphere( - envs: List[Union[Env, Arena]], - radius: float, - resolution: int = 20, - uid: str = "sphere", -) -> List[MeshObject]: - """Create sphere objects in the specified environments or arenas. - - Args: - envs (List[Union[Env, Arena]]): List of environments or arenas to create spheres in. - radius (float): Radius of the sphere in meters. - resolution (int, optional): Resolution of the sphere mesh. Defaults to 20. - uid (str, optional): Unique identifier for the sphere objects. Defaults to "sphere". - - Returns: - List[MeshObject]: List of created sphere mesh objects. - """ - spheres = [] - for i, env in enumerate(envs): - sphere = env.create_sphere(radius, resolution) - sphere.set_name(f"{uid}_{i}") - spheres.append(sphere) - return spheres - - -def load_mesh_objects_from_cfg( - cfg: RigidObjectCfg, env_list: List[Arena], cache_dir: Optional[str] = None -) -> List[MeshObject]: - """Load mesh objects from configuration. - - Args: - cfg (RigidObjectCfg): Configuration for the rigid object. - env_list (List[Arena]): List of arenas to load the objects into. - - cache_dir (Optional[str], optional): Directory for caching convex decomposition files. Defaults to None - Returns: - List[MeshObject]: List of loaded mesh objects. - """ - obj_list = [] - body_type = cfg.to_dexsim_body_type() - if isinstance(cfg.shape, MeshCfg): - - option = LoadOption() - option.rebuild_normals = cfg.shape.load_option.rebuild_normals - option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent - option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal - option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent - option.smooth = cfg.shape.load_option.smooth - - cfg: RigidObjectCfg - max_convex_hull_num = cfg.max_convex_hull_num - fpath = cfg.shape.fpath - - compute_uv = cfg.shape.compute_uv - - for i, env in enumerate(env_list): - if max_convex_hull_num > 1: - obj = env.load_actor_with_coacd( - fpath, - duplicate=True, - attach_scene=True, - option=option, - cache_path=cache_dir, - actor_type=body_type, - max_convex_hull_num=max_convex_hull_num, - ) - else: - obj = env.load_actor( - fpath, duplicate=True, attach_scene=True, option=option - ) - obj.add_rigidbody(body_type, RigidBodyShape.CONVEX) - obj.set_name(f"{cfg.uid}_{i}") - obj_list.append(obj) - - if compute_uv: - vertices = obj.get_vertices() - triangles = obj.get_triangles() - - o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) - _, uvs = get_mesh_auto_uv( - o3d_mesh, np.array(cfg.shape.project_direction) - ) - obj.set_uv_mapping(uvs) - - elif isinstance(cfg.shape, CubeCfg): - from embodichain.lab.sim.utility.sim_utils import create_cube - - obj_list = create_cube(env_list, cfg.shape.size, uid=cfg.uid) - for obj in obj_list: - obj.add_rigidbody(body_type, RigidBodyShape.BOX) - - elif isinstance(cfg.shape, SphereCfg): - from embodichain.lab.sim.utility.sim_utils import create_sphere - - obj_list = create_sphere( - env_list, cfg.shape.radius, cfg.shape.resolution, uid=cfg.uid - ) - for obj in obj_list: - obj.add_rigidbody(body_type, RigidBodyShape.SPHERE) - else: - logger.log_error( - f"Unsupported rigid object shape type: {type(cfg.shape)}. Supported types: MeshCfg, CubeCfg, SphereCfg." - ) - return obj_list - - -def load_soft_object_from_cfg( - cfg: SoftObjectCfg, env_list: List[Arena] -) -> List[MeshObject]: - obj_list = [] - - option = LoadOption() - option.rebuild_normals = cfg.shape.load_option.rebuild_normals - option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent - option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal - option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent - option.smooth = cfg.shape.load_option.smooth - option.share_mesh = False - - for i, env in enumerate(env_list): - obj = env.load_actor( - fpath=cfg.shape.fpath, duplicate=True, attach_scene=True, option=option - ) - obj.add_softbody(cfg.voxel_attr.attr(), cfg.physical_attr.attr()) - if cfg.shape.compute_uv: - vertices = obj.get_vertices() - triangles = obj.get_triangles() - - o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) - _, uvs = get_mesh_auto_uv(o3d_mesh, cfg.shape.project_direction) - obj.set_uv_mapping(uvs) - obj.set_name(f"{cfg.uid}_{i}") - obj_list.append(obj) - return obj_list diff --git a/embodichain/lab/sim/utility/workspace_analyzer_new.py b/embodichain/lab/sim/utility/workspace_analyzer_new.py deleted file mode 100644 index 8443d0f..0000000 --- a/embodichain/lab/sim/utility/workspace_analyzer_new.py +++ /dev/null @@ -1,1617 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import gc -import os -import time -import numpy as np -import open3d as o3d -import torch -import dexsim - -from dataclasses import dataclass -from typing import List, Tuple, Optional, Union, Dict, Sequence -from itertools import product, islice -from tqdm import tqdm - -from embodichain.utils import logger -from embodichain.lab.sim.objects import Robot -from scipy.spatial.transform import Rotation as R - - -@dataclass -class JointConfig: - """Joint configuration parameters""" - - range: Tuple[float, float] # Joint motion range - samples: int # Number of samples - - -@dataclass -class JointSamplingConfig: - """Joint space sampling configuration""" - - joints: List[JointConfig] # List of joint configurations - - -def batched(iterable, n): - """Yield successive n-sized batches from iterable.""" - it = iter(iterable) - while True: - batch = list(islice(it, n)) - if not batch: - break - yield batch - - -class WorkspaceAnalyzer: - def __init__( - self, - robot: Robot, - name: str, - joint_ranges: np.ndarray, - resolution: float = np.radians(35), - ): - self.robot = robot - self.solver = self.robot.get_solver(name) - self.control_part = name - self.resolution = resolution - self.joint_ranges = np.array(joint_ranges) - self.device = "cpu" - - self._sampling_configs = self._init_sampling_configs() - - self.control_part_base_xpos = self.robot.get_control_part_base_pose( - name=name, to_matrix=True - ) - - def _get_fk_result(self, qpos: np.ndarray) -> Tuple[bool, np.ndarray]: - r"""Calculate forward kinematics - - Computes the end-effector pose given joint angles. - - Args: - qpos: Joint angles array - - Returns: - tuple: (success, pose) - - success (bool): True if calculation succeeded - - pose (np.ndarray): 4x4 homogeneous transformation matrix - """ - try: - result = self.robot.compute_fk(name=self.control_part, qpos=qpos) - - # Default values - success = False - xpos = np.eye(4) - - # Handle different return types - if isinstance(result, tuple): - if len(result) >= 2: - success, xpos = result[:2] - else: - if result is None: - success = False - else: - success = True - xpos = result - - return success, xpos - - except Exception as e: - logger.log_warning(f"FK calculation failed: {str(e)}") - return False, np.eye(4) - - def _get_ik_result( - self, xpos: np.ndarray, qpos_seed: Optional[np.ndarray] = np.array([]) - ) -> Tuple[bool, np.ndarray]: - """Calculate inverse kinematics - - Computes joint angles that achieve the desired end-effector pose. - - Args: - xpos: Target 4x4 homogeneous transformation matrix - qpos_seed: Initial joint angles for IK solver (optional) - - Returns: - tuple: (success, joint_angles) - - success (bool): True if solution found - - joint_angles (np.ndarray): Solution joint angles - """ - # try: - # Call robot's IK solver - result = self.robot.get_ik( - uid=self.control_part, xpos=xpos, qpos_seed=qpos_seed - ) - - # Default values - success = False - q_sol = np.zeros(self.robot.get_dof(self.control_part)) - - # Process IK result - if isinstance(result, tuple): - if len(result) >= 2: - success, q_sol = result[:2] - else: - if result is None: - success = False - else: - success = True - q_sol = result - - return success, q_sol - - # except Exception as e: - # logger.log_warning(f"IK calculation failed: {str(e)}") - # return False, None - - def _init_sampling_configs(self) -> Dict[str, JointSamplingConfig]: - r"""Initialize joint space sampling configurations - - Returns: - Dictionary mapping config names to sampling configurations - """ - original_ranges = self.joint_ranges.copy() - - self.joint_ranges = np.clip(self.joint_ranges, -np.pi, np.pi) - - clipped_joints = [] - for i, (orig, clipped) in enumerate(zip(original_ranges, self.joint_ranges)): - if not np.allclose(orig, clipped): - clipped_joints.append(i) - - if clipped_joints: - logger.log_info("Some joint ranges were clipped to [-π, π]:") - for joint_idx in clipped_joints: - orig_range = original_ranges[joint_idx] - new_range = self.joint_ranges[joint_idx] - logger.log_info( - f"Joint {joint_idx}: [{orig_range[0]:.3f}, {orig_range[1]:.3f}] -> " - f"[{new_range[0]:.3f}, {new_range[1]:.3f}] rad" - ) - - # Calculate joint range sizes - joint_ranges_size = np.abs(self.joint_ranges[:, 1] - self.joint_ranges[:, 0]) - - # Calculate number of samples per joint - samples = [ - max(3, int(np.ceil(range_size / self.resolution))) - for range_size in joint_ranges_size - ] - - # Create default sampling configuration - sampling_config = JointSamplingConfig( - joints=[ - JointConfig(range=joint_range, samples=sample_num) - for joint_range, sample_num in zip(self.joint_ranges, samples) - ], - ) - - # Log sampling configuration info - logger.log_info(f"Analyze control part: [{self.control_part}]") - logger.log_info( - f"Angular Resolution: {self.resolution:.3f} rad ({np.degrees(self.resolution):.1f}°)" - ) - for i, (joint_range, num_samples) in enumerate(zip(self.joint_ranges, samples)): - range_size = abs(joint_range[1] - joint_range[0]) - actual_resolution = range_size / (num_samples - 1) if num_samples > 1 else 0 - logger.log_info( - f"- Joint {i+1}: Range={range_size:.2f}rad, Samples={num_samples}, " - f"Actual Resolution={actual_resolution:.3f}rad ({np.degrees(actual_resolution):.1f}°)" - ) - - return sampling_config - - def _generate_combinations(self, joint_values): - r"""Generator function to produce joint angle combinations one at a time - - This avoids generating all combinations at once to save memory - """ - if not joint_values: - yield [] - else: - for first in joint_values[0]: - for rest in self._generate_combinations(joint_values[1:]): - yield [first] + rest - - def _process_batch( - self, batch: List[np.ndarray], timeout: float = 10.0 - ) -> List[np.ndarray]: - r"""Process a batch of joint configurations - - Args: - batch: List of joint configurations to process - timeout: Batch processing timeout in seconds - - Returns: - List of end effector XYZ positions - """ - positions = [] - start_time = time.time() - - for qpos in batch: - if time.time() - start_time > timeout: - logger.log_warning(f"Batch processing timeout ({timeout}s)") - break - - try: - qpos = np.array(qpos) - res, xpos = self._get_fk_result(qpos=qpos) - if res: - # Only save XYZ position - positions.append(xpos[:3, 3]) - except Exception as e: - logger.log_warning(f"Error processing joint configuration: {str(e)}") - continue - - return positions - - def _validate_params(self, cache_mode: str, save_dir: str): - r"""Validate input parameters""" - if cache_mode not in ["memory", "disk"]: - raise ValueError("cache_mode must be 'memory' or 'disk'") - - if cache_mode == "disk" and save_dir is None: - raise ValueError("save_dir must be provided when cache_mode is 'disk'") - - def _init_joint_values(self, config: JointSamplingConfig) -> List[np.ndarray]: - r"""Initialize joint sampling values""" - return [ - np.linspace(joint.range[0], joint.range[1], joint.samples) - for joint in config.joints - ] - - def _save_batch_results( - self, positions: List[np.ndarray], save_dir: str, batch_id: int - ): - r"""Save results for a single batch - - Args: - positions: List of XYZ positions - save_dir: Directory to save results - batch_id: Batch identifier - """ - - batch_dir = os.path.join(save_dir, "batches") - # Ensure directory exists - os.makedirs(batch_dir, exist_ok=True) - # Save numpy array - batch_path = os.path.join(batch_dir, f"batch_{batch_id:04d}.npy") - np.save(batch_path, np.array(positions)) - logger.log_info( - f"Saved batch {batch_id}: {len(positions)} points -> {batch_path}" - ) - - def _process_point_cloud( - self, - positions: List[np.ndarray], - voxel_size: float = 0.05, - nb_neighbors: int = 20, - std_ratio: float = 2.0, - is_voxel_down: bool = True, - ) -> o3d.geometry.PointCloud: - r"""Process sampled point cloud data - - Args: - positions: List of XYZ positions - voxel_size: Voxel size (m) - nb_neighbors: Number of neighbors for statistical filter - std_ratio: Standard deviation ratio for statistical filter - - Returns: - o3d.geometry.PointCloud: Processed point cloud - """ - # Create point cloud object - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(np.array(positions)) - - logger.log_info(f"Point cloud processing:") - - if is_voxel_down: - # 1. Voxel downsampling - logger.log_info( - f"- Performing voxel downsampling (voxel_size={voxel_size}m)" - ) - pcd_down = pcd.voxel_down_sample(voxel_size=voxel_size) - - # 2. Statistical outlier removal - logger.log_info( - f"- Removing outliers (neighbors={nb_neighbors}, std_ratio={std_ratio})" - ) - cl, ind = pcd_down.remove_statistical_outlier( - nb_neighbors=nb_neighbors, std_ratio=std_ratio - ) - pcd_clean = pcd_down.select_by_index(ind) - else: - pcd_clean = pcd - - # 3. Estimate normals - logger.log_info("- Estimating point cloud normals") - pcd_clean.estimate_normals( - search_param=o3d.geometry.KDTreeSearchParamHybrid( - radius=voxel_size * 2, max_nn=30 - ) - ) - - # 4. Orient normals consistently - logger.log_info("- Orienting normals consistently") - - pcd_clean.orient_normals_to_align_with_direction() - - # 5. Add color based on distance to origin - points = np.asarray(pcd_clean.points) - - # Calculate distances to origin - distances = np.linalg.norm(points, axis=1) - - # Find the centroid - center = np.mean(points, axis=0) - - # Calculate distances to the centroid - distances_to_center = np.linalg.norm(points - center, axis=1) - - # Normalize distances - max_dist = np.max(distances_to_center) - normalized_distances = distances_to_center / max_dist - - # Create HSV color space (green to red gradient) - hsv_colors = np.zeros((len(points), 3)) - hsv_colors[:, 0] = 0.3333 * ( - 1 - normalized_distances - ) # Hue: green (0.3333) to red (0) - hsv_colors[:, 1] = 1.0 # Saturation: max saturation - hsv_colors[:, 2] = 0.8 # Value: medium brightness - - # Convert HSV to RGB - colors = np.zeros_like(points) - for i in range(len(points)): - h, s, v = hsv_colors[i] - - # HSV to RGB conversion - c = v * s - x = c * (1 - abs((h * 6) % 2 - 1)) - m = v - c - - if h < 1 / 6: - rgb = [c, x, 0] - elif h < 2 / 6: - rgb = [x, c, 0] - elif h < 3 / 6: - rgb = [0, c, x] - elif h < 4 / 6: - rgb = [0, x, c] - elif h < 5 / 6: - rgb = [x, 0, c] - else: - rgb = [c, 0, x] - - colors[i] = [r + m for r in rgb] - - pcd_clean.colors = o3d.utility.Vector3dVector(colors) - - logger.log_info(f"- Original points: {len(positions)}") - logger.log_info(f"- Processed points: {len(pcd_clean.points)}") - logger.log_info( - f"- Distance range: {np.min(distances):.3f}m ~ {np.max(distances):.3f}m" - ) - - return pcd_clean - - def _merge_batch_files(self, save_dir: str, total_batches: int) -> List[np.ndarray]: - r"""Merge all sampled points from batch files - - Args: - save_dir: Directory to save data - total_batches: Total number of batches - - Returns: - List[np.ndarray]: List of all sampled positions - """ - # Get current date for subdirectory name - # current_date = time.strftime("%Y%m%d") - batch_dir = os.path.join(save_dir, "batches") - - logger.log_info("Starting to merge batch files...") - all_xpos = [] - - # Load and process batches - for batch_id in tqdm(range(total_batches), desc="Merging progress"): - batch_path = os.path.join(batch_dir, f"batch_{batch_id:04d}.npy") - - try: - # Load batch data - batch_data = np.load(batch_path) - all_xpos.extend(batch_data) - # Delete processed batch file - # os.remove(batch_path) - except Exception as e: - logger.log_warning(f"Error processing batch {batch_id}: {str(e)}") - - # Remove empty batch directory - if os.path.exists(batch_dir) and not os.listdir(batch_dir): - os.rmdir(batch_dir) - - logger.log_info(f"Merging complete: {len(all_xpos)} sampled points") - return all_xpos - - def sample_qpos_workspace( - self, - resolution: float = None, - cache_mode: str = "memory", # Cache mode "memory" or "disk" - save_dir: str = None, # Save directory - batch_size: int = 100000, # Batch processing size - save_threshold: int = 10000000, # Save threshold - use_cached: bool = True, # Use cached results if available - ) -> List[np.ndarray]: - r"""Sample joint space and calculate corresponding workspace poses - - Args: - resolution: Sampling resolution - cache_mode: Cache mode ("memory" - in-memory list, "disk" - disk storage) - save_dir: Save directory path (must be provided when cache_mode="disk") - batch_size: Number of samples per batch - save_threshold: Number of samples to accumulate before saving in disk mode - use_cached: Whether to use cached results if available (only in disk mode) - - Returns: - List[np.ndarray]: List of valid end effector poses of poses (in memory mode) or empty list (in disk mode) - """ - if resolution is not None: - self.resolution = resolution - self._sampling_configs = self._init_sampling_configs() - - # Validate parameters - self._validate_params(cache_mode, save_dir) - - # Initialize sampling configuration - joint_values = self._init_joint_values(self._sampling_configs) - total_samples = np.prod([len(values) for values in joint_values]) - - logger.log_info( - f"Sampling joint space with resolution {np.degrees(self.resolution):.1f}°..." - ) - logger.log_info(f"Total sample points: {total_samples}") - logger.log_info(f"Cache mode: {cache_mode}") - logger.log_info(f"Save directory: {save_dir if save_dir else 'N/A'}") - logger.log_info(f"Sampling using: {self.device}") - - if cache_mode == "memory": - return self._sample_memory_mode(joint_values, total_samples, batch_size) - else: - return self._sample_disk_mode( - joint_values, - total_samples, - save_dir, - batch_size, - save_threshold, - use_cached, - ) - - def _sample_memory_mode( - self, joint_values: List[np.ndarray], total_samples: int, batch_size: int - ) -> List[np.ndarray]: - r"""Memory mode sampling""" - if not self.robot.pk_serial_chain: - all_xpos = [] - for qpos in tqdm( - product(*joint_values), - total=total_samples, - desc="Memory mode serial sampling", - ): - q = np.array(qpos, dtype=np.float32) - res, xpos = self._get_fk_result(qpos=q) - if res: - all_xpos.append(xpos) - if len(all_xpos) % 1000 == 0: - gc.collect() - return all_xpos - self.chain = self.robot.pk_serial_chain[self.control_part].to( - dtype=torch.float32, device=self.device - ) - sampled_xpos = [] - joint_combinations = product(*joint_values) - - T_tcp = torch.as_tensor(self.solver.get_tcp(), dtype=torch.float32).to( - self.device - ) - - with tqdm( - total=total_samples, - desc=f"Sampling {total_samples} points (batch={batch_size})", - ) as pbar: - for qpos_batch in batched(joint_combinations, batch_size): - # compute and collect - batch_mats = self._compute_batch_xpos(qpos_batch, T_tcp) - sampled_xpos.extend(batch_mats) - - # advance progress bar and cleanup - pbar.update(len(batch_mats)) - gc.collect() - - return sampled_xpos - - def _sample_disk_mode( - self, - joint_values: List[np.ndarray], - total_samples: int, - save_dir: str, - batch_size: int, - save_threshold: int, - use_cached: bool = True, - ) -> List[np.ndarray]: - r"""Disk mode sampling, with serial fallback if no pk_serial_chain.""" - # 1) If batches already exist, just merge & return - batches_dir = os.path.join(save_dir, "batches") - if os.path.exists(batches_dir) and use_cached: - npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] - if npy_files: - return self._merge_batch_files(save_dir, len(npy_files)) - - sampled_xpos = [] - current_batch = [] - total_processed = 0 - batch_count = 0 - - # 2) Choose serial vs. GPU path - if not self.robot.pk_serial_chain: - # serial, one qpos at a time - current_batch = [] - with tqdm(total=total_samples, desc="Disk mode serial sampling") as pbar: - for qpos in product(*joint_values): - q = np.array(qpos, dtype=np.float32) - res, xpos = self._get_fk_result(qpos=q) - if res: - current_batch.append(xpos) - # flush by batch_size - if len(current_batch) >= batch_size: - sampled_xpos.extend(current_batch) - total_processed += len(current_batch) - current_batch = [] - # flush to disk by save_threshold - if len(sampled_xpos) >= save_threshold: - self._save_batch_results( - sampled_xpos, save_dir, batch_count - ) - batch_count += 1 - sampled_xpos = [] - gc.collect() - pbar.update(1) - - else: - self.chain = self.robot.pk_serial_chain[self.control_part].to( - dtype=torch.float32, device=self.device - ) - # GPU‐batched path - T_tcp = torch.as_tensor( - self.robot.get_tcp(self.control_part), - dtype=torch.float32, - device=self.device, - ) - with tqdm( - total=total_samples, desc=f"Sampling in {batch_size}-sized batches" - ) as pbar: - for qpos_batch in batched(product(*joint_values), batch_size): - batch_mats = self._compute_batch_xpos(qpos_batch, T_tcp) - sampled_xpos.extend(batch_mats) - total_processed += len(batch_mats) - # flush to disk by save_threshold - if len(sampled_xpos) >= save_threshold: - self._save_batch_results(sampled_xpos, save_dir, batch_count) - batch_count += 1 - sampled_xpos = [] - gc.collect() - pbar.update(len(batch_mats)) - - # Process remaining samples - if sampled_xpos: - self._save_batch_results(sampled_xpos, save_dir, batch_count) - batch_count += 1 - - logger.log_info( - f"Sampling complete: {total_processed} samples, {batch_count} batches" - ) - - # If there are saved batches, read and merge them to process point cloud - if batch_count > 0: - all_xpos = self._merge_batch_files(save_dir, batch_count) - return all_xpos - - return None - - def sample_xpos_workspace( - self, - ref_xpos: np.ndarray, - xpos_resolution: float = 0.2, - qpos_resolution: float = np.radians(60), - cache_mode: str = "memory", - save_dir: str = None, - batch_size: int = 5000, - save_threshold: int = 10000000, - pos_eps: float = 5e-4, - rot_eps: float = 5e-4, - max_iterations: int = 1500, - num_samples: int = 5, - use_cached: bool = True, - ) -> List[np.ndarray]: - r"""Sample Cartesian space and calculate corresponding joints - - Args: - ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the - orientation for IK solutions. Translation components will - be overridden during sampling. - xpos_resolution (float, optional): Cartesian space sampling resolution in meters. - Smaller values provide finer sampling but increase - computation time. Defaults to 0.2 meters. - qpos_resolution (float, optional): Angular resolution for initial joint space - sampling in radians. Used to determine workspace - bounds. Defaults to 60 degrees. - cache_mode (str, optional): Caching strategy, either: - - "memory": Store samples in memory (faster but memory-intensive) - - "disk": Save samples to disk (slower but memory-efficient) - Defaults to "memory". - save_dir (str, optional): Directory path for saving results when using disk cache. - Must be provided if cache_mode is "disk". Defaults to None. - batch_size (int, optional): Number of samples to process in each batch. - Larger values may improve performance but increase - memory usage. Defaults to 5000. - save_threshold (int, optional): Number of samples to accumulate before saving - to disk in disk mode. Defaults to 10,000,000. - pos_eps (float, optional): Position tolerance for IK solutions in meters. - Defaults to 5e-4. - rot_eps (float, optional): Rotation tolerance for IK solutions in radians. - Defaults to 5e-4. - max_iterations (int, optional): Maximum iterations for IK solver. - Defaults to 1500. - num_samples (int, optional): Number of IK samples to generate for each position. - Defaults to 5. - use_cached (bool, optional): Whether to use cached results if available (only in disk mode) - - Returns: - List[np.ndarray]: List of valid end effector poses - """ - # logger.set_log_level(level="error") - - start_time = time.time() - try: - qpos_sampled_xpos = self.sample_qpos_workspace( - resolution=qpos_resolution, - cache_mode="memory", - batch_size=5000, - save_threshold=save_threshold, - ) - - qpos_all_positions = [xpos[:3, 3] for xpos in qpos_sampled_xpos] - qpos_pcd = self._process_point_cloud(positions=qpos_all_positions) - aabb = qpos_pcd.get_axis_aligned_bounding_box() - - sample_points = self._sample_in_aabb( - aabb.min_bound, aabb.max_bound, xpos_resolution - ) - - # Validate parameters - self._validate_params(cache_mode, save_dir) - - if cache_mode == "memory": - return self._sample_xpos_memory_mode( - positions=sample_points, - ref_xpos=ref_xpos, - batch_size=batch_size, - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - ) - else: - return self._sample_xpos_disk_mode( - positions=sample_points, - ref_xpos=ref_xpos, - save_dir=save_dir, - batch_size=batch_size, - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - save_threshold=save_threshold, - use_cached=use_cached, - ) - finally: - logger.set_log_level(level="info") - # Record the end time - end_time = time.time() - # Calculate the time cost - time_cost = end_time - start_time - logger.log_info(f"Time cost: {time_cost:.2f} seconds") - - def _compute_batch_xpos( - self, qpos_batch: Sequence[np.ndarray], T_tcp: torch.Tensor - ) -> List[np.ndarray]: - """Given a batch of q-poses, compute TCP-transformed FK matrices - and return them as numpy float16 arrays.""" - # 1) to NumPy (float32) → to torch.Tensor on correct device - np_qpos = np.array(qpos_batch, dtype=np.float32) - tensor_qpos = torch.as_tensor(np_qpos, dtype=torch.float32, device=self.device) - - # 2) batched forward kinematics → 4×4 matrices - ret_batch = self.chain.forward_kinematics( - tensor_qpos, end_only=True - ).get_matrix() - - # 3) apply TCP offset - T_final = torch.matmul(ret_batch, T_tcp) - - T_final = torch.bmm( - self.control_part_base_xpos.to(dtype=torch.float32).expand( - T_final.shape[0], -1, -1 - ), - T_final, - ) - - # 4) move to CPU, cast to float16 - T_cpu16 = T_final.cpu().to(dtype=torch.float16) - - # 5) return list of numpy arrays - return [mat.numpy() for mat in T_cpu16] - - def _sample_in_aabb( - self, min_bound: np.ndarray, max_bound: np.ndarray, resolution: float - ) -> np.ndarray: - r"""Uniformly sample within an axis-aligned bounding box (AABB) - - Args: - min_bound: AABB minimum bound [x_min, y_min, z_min] - max_bound: AABB maximum bound [x_max, y_max, z_max] - resolution: Sampling resolution (m) - - Returns: - np.ndarray: Array of sampled points with shape (N, 3) - """ - # Calculate number of samples per axis - num_samples = np.ceil((max_bound - min_bound) / resolution).astype(int) - - # Ensure at least 2 samples per dimension - num_samples = np.maximum(num_samples, 2) - - # Generate sample points for each axis - x = np.linspace(min_bound[0], max_bound[0], num_samples[0]) - y = np.linspace(min_bound[1], max_bound[1], num_samples[1]) - z = np.linspace(min_bound[2], max_bound[2], num_samples[2]) - - # Create a grid of points - X, Y, Z = np.meshgrid(x, y, z) - - # Convert grid to N×3 array - points = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T - - logger.log_info(f"Sampling space range:") - logger.log_info(f"- X: [{min_bound[0]:.3f}, {max_bound[0]:.3f}] m") - logger.log_info(f"- Y: [{min_bound[1]:.3f}, {max_bound[1]:.3f}] m") - logger.log_info(f"- Z: [{min_bound[2]:.3f}, {max_bound[2]:.3f}] m") - logger.log_info(f"Sampling resolution: {resolution:.3f} m") - logger.log_info(f"Number of samples: {len(points)}") - - return points - - def _sample_xpos_memory_mode( - self, - positions: List[np.ndarray], - ref_xpos: np.ndarray, - batch_size: int, - pos_eps: float, - rot_eps: float, - max_iterations: int, - num_samples: int, - ) -> List[np.ndarray]: - r"""Memory mode sampling with batch processing and progress bar - - Args: - positions: List of positions to validate. - ref_xpos: Reference end effector pose. - batch_size (int): Number of positions to process in each batch. - - Returns: - List[np.ndarray]: List of valid end effector poses. - """ - valid_xpos = [] - - # Get the degree of freedom (DOF) of the robot to create joint seed - dof_number = self.robot.get_dof(self.control_part) - - # Total number of positions to process - total_positions = len(positions) - - # TODO: Optimize efficiency by using batch IK if available. - # If self.robot implements get_batch_ik_solution, prefer batch processing for IK to significantly accelerate sampling. - # Otherwise, fall back to single-point IK calls (slower). - # This check ensures the most efficient computation path is used automatically. - # (Batch IK can greatly improve performance for large-scale workspace sampling.) - # Example: - # if hasattr(self.robot, "get_batch_ik_solution"): - if False: - # If the robot has get_batch_ik_solution, use it for batch processing - num_batches = (total_positions // batch_size) + ( - 1 if total_positions % batch_size != 0 else 0 - ) - - # Create progress bar with total samples and batch size - with tqdm( - total=total_positions, desc=f"Sampling in {batch_size}-sized batches" - ) as pbar: - # Iterate through positions in batches - for batch_idx in range(num_batches): - # Select the current batch of positions - batch_positions = positions[ - batch_idx * batch_size : (batch_idx + 1) * batch_size - ] - - # Create a batch of target poses (batch_size, 4, 4) - target_xpos_batch = [] - for point in batch_positions: - target_xpos = ref_xpos.copy() - target_xpos[:3, 3] = point - target_xpos_batch.append(target_xpos) - - # Convert to numpy array (batch_size, 4, 4) - target_xpos_batch = np.array(target_xpos_batch) - # Create joint seed batch of zeros (batch_size, dof) - joint_seed_batch = np.zeros((len(batch_positions), dof_number)) - # Use get_batch_ik_solution for batch processing - res, _ = self.robot.get_batch_ik_solution( - target_xpos_list=target_xpos_batch, # Batch of target poses - joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) - uid=self.control_part, - is_world_coordinates=False, # Set based on your use case - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - ) - - # Append valid target poses to valid_xpos - for j, is_valid in enumerate(res): - if is_valid: - valid_xpos.append(target_xpos_batch[j]) - - # Update the progress bar after processing the batch - pbar.update( - len(batch_positions) - ) # Update progress bar with batch size - - # Perform garbage collection after every batch - if len(valid_xpos) % 1000 == 0: - gc.collect() - - else: - # Fallback to the previous method if get_batch_ik_solution is not available - with tqdm( - total=total_positions, desc="Sampling in single IK calls" - ) as pbar: - for point in positions: - # Construct target pose - target_xpos = ref_xpos.copy() - target_xpos[:3, 3] = point - - # Calculate IK using the old method (get_ik) - res, _ = self.robot.get_ik(uid=self.control_part, xpos=target_xpos) - if res: - valid_xpos.append(target_xpos) - - # Update the progress bar after each point is processed - pbar.update(1) # Update progress bar with 1 point - - # Perform garbage collection after every 1000 valid points - if len(valid_xpos) % 1000 == 0: - gc.collect() - - return valid_xpos if valid_xpos else None - - def _sample_xpos_disk_mode( - self, - positions: List[np.ndarray], - ref_xpos: np.ndarray, - save_dir: str, - batch_size: int, - pos_eps: float, - rot_eps: float, - max_iterations: int, - num_samples: int, - save_threshold: int, - use_cached: bool = True, - ) -> List[np.ndarray]: - r"""Disk mode sampling with batch processing - - Args: - positions: List of positions to validate. - ref_xpos: Reference end effector pose. - save_dir: Directory to save results. - batch_size: Number of samples per batch. - save_threshold: Number of samples to accumulate before saving. - - Returns: - List[np.ndarray]: List of valid end effector poses. - """ - valid_positions = [] - current_batch = [] - total_processed = 0 - batch_count = 0 - # Record the start time - logger.log_info(f"Starting disk mode sampling...") - logger.log_info(f"Save directory: {save_dir}") - - # If there are saved batches, read and return without calculation - batches_dir = os.path.join(save_dir, "batches") - if os.path.exists(batches_dir) and use_cached: - npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] - batch_count = len(npy_files) - - if batch_count > 0: - all_xpos = self._merge_batch_files(save_dir, batch_count) - return all_xpos - - # Check if self.robot has the method get_batch_ik_solution - if hasattr(self.robot, "get_batch_ik_solution"): - # If get_batch_ik_solution is available, use batch processing - with tqdm(total=len(positions), desc="Disk mode sampling") as pbar: - for i in range(0, len(positions), batch_size): - # Select the current batch of positions - batch_positions = positions[i : i + batch_size] - - # Create a batch of target poses (batch_size, 4, 4) - target_xpos_batch = [] - for point in batch_positions: - target_xpos = ref_xpos.copy() - target_xpos[:3, 3] = point - target_xpos_batch.append(target_xpos) - - # Convert to numpy array (batch_size, 4, 4) - target_xpos_batch = np.array(target_xpos_batch) - - # Create the joint seed batch (batch_size, dof) - dof_number = self.robot.get_dof(self.control_part) - joint_seed_batch = np.zeros((len(batch_positions), dof_number)) - - # Use get_batch_ik_solution for batch processing - res, _ = self.robot.get_batch_ik_solution( - target_xpos_list=target_xpos_batch, # Batch of target poses - joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) - uid=self.control_part, - is_world_coordinates=False, # Set based on your use case - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - ) - - # Append valid target poses to valid_positions - for j, is_valid in enumerate(res): - if is_valid: - current_batch.append(target_xpos_batch[j]) - - # Process batch when it reaches batch_size - if len(current_batch) >= batch_size: - valid_positions.extend(current_batch) - total_processed += len(current_batch) - - current_batch = [] - - # Save when reaching the threshold - if len(valid_positions) >= save_threshold: - self._save_batch_results( - valid_positions, save_dir, batch_count - ) - batch_count += 1 - valid_positions = [] - gc.collect() - - # Update the progress bar - pbar.update(len(batch_positions)) # Update with batch size - - else: - # Fallback to the previous method if get_batch_ik_solution is not available - with tqdm(total=len(positions), desc="Disk mode sampling") as pbar: - for point in positions: - # Construct target pose - target_xpos = ref_xpos.copy() - target_xpos[:3, 3] = point - - # Calculate IK using the old method (get_ik) - res, _ = self.robot.compute_ik( - name=self.control_part, pose=target_xpos - ) - if res: - current_batch.append(target_xpos) - - # Process batch when it reaches batch_size - if len(current_batch) >= batch_size: - valid_positions.extend(current_batch) - total_processed += len(current_batch) - - current_batch = [] - - # Save when reaching the threshold - if len(valid_positions) >= save_threshold: - self._save_batch_results( - valid_positions, save_dir, batch_count - ) - batch_count += 1 - valid_positions = [] - gc.collect() - - # Update the progress bar - pbar.update(1) # Update with 1 point per iteration - - # Process remaining data - if current_batch: - valid_positions.extend(current_batch) - total_processed += len(current_batch) - - if valid_positions: - self._save_batch_results(valid_positions, save_dir, batch_count) - batch_count += 1 - - logger.log_info( - f"Sampling complete: {total_processed} samples, {batch_count} batches" - ) - - # If there are saved batches, read and merge them to process point cloud - if batch_count > 0: - all_xpos = self._merge_batch_files(save_dir, batch_count) - return all_xpos - - return None - - def sample_voxel_workspace( - self, - voxel_size: float = 0.04, - num_directions: int = 50, - num_yaws: int = 6, - pos_eps: float = 2e-4, - rot_eps: float = 2e-4, - max_iterations: int = 1500, - num_samples: int = 5, - cache_mode: str = "memory", - save_dir: str = None, - batch_size: int = 5000, - save_threshold: int = 10000000, - use_cached: bool = True, - ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray]]: - r"""Sample Cartesian space using voxel‐based IK reachability. - - Divides the workspace into a grid of voxels around the arm base, then for - each voxel center sweeps through a set of directions and yaw rotations, - calling the IK solver to test reachability. - - Args: - voxel_size (float, optional): - Edge length of each cubic voxel in meters. - Smaller voxels give finer resolution but increase computation. - Defaults to 0.04. - num_directions (int, optional): - Number of unit‐vector directions to sample on the sphere for each - voxel. More directions improve angular coverage at the cost of - additional IK calls. Defaults to 50. - num_yaws (int, optional): - Number of discrete yaw rotations **around the local Z‐axis** to - attempt for each direction when solving IK. Higher values increase - rotational sampling but incur more IK calls. Defaults to 6. - pos_eps (float, optional): - Position tolerance for IK solutions in meters. - Defaults to 5e-4. - rot_eps (float, optional): - Rotation tolerance for IK solutions in radians. - Defaults to 5e-4. - max_iterations (int, optional): - Maximum iterations for IK solver. - Defaults to 1500. - num_samples (int, optional): - Number of IK samples to generate for each position. - Defaults to 5. - cache_mode (str, optional): - Caching strategy for IK results: - - `"memory"`: keep all samples in RAM (fast, memory‐intensive) - - `"disk"`: stream to disk in batches (slower, memory‐efficient) - Defaults to `"memory"`. - save_dir (str, optional): - Directory path for saving/loading cached batches when using - `cache_mode="disk"`. Required in disk mode. Defaults to None. - batch_size (int, optional): - Number of successful IK poses to accumulate before adding them to - the in‐memory pool. Larger values may improve throughput but - increase temporary memory usage. Defaults to 5000. - save_threshold (int, optional): - Number of poses in the in‐memory pool at which point they are - written out to disk as a batch file. Helps limit peak RAM use. - Defaults to 10,000,000. - use_cached: Whether to use cached results if available (only in disk mode) - - Returns: - Tuple[ - np.ndarray, # (M,3) array of voxel‐center coordinates - np.ndarray, # (M,) array of success counts per center - List[np.ndarray] # flat list of all valid 4×4 IK pose matrices - ] - """ - logger.set_log_level(level="error") - - try: - self._validate_params(cache_mode, save_dir) - - logger.log_info(f"Sampling robot workspace with voxel size {voxel_size}...") - logger.log_info(f"Cache mode: {cache_mode}") - logger.log_info(f"Sampling using: {self.device}") - - arm_base_pos = self.robot.get_base_xpos(name=self.control_part)[:3, 3] - arm_ee_pos = self.robot.get_current_xpos(name=self.control_part)[:3, 3] - arm_length = float(np.linalg.norm(arm_ee_pos - arm_base_pos)) - - if cache_mode == "memory": - return self._sample_voxels_memory_mode( - voxel_size, num_directions, num_yaws, arm_base_pos, arm_length - ) - else: - return self._sample_voxels_disk_mode( - voxel_size, - num_directions, - num_yaws, - arm_base_pos, - arm_length, - save_dir=save_dir, - save_threshold=save_threshold, - batch_size=batch_size, - use_cached=use_cached, - ) - finally: - logger.set_log_level(level="info") - - def _voxel_centers_in_sphere(self, arm_base, arm_length, voxel_size): - """ - Compute centers of all voxels of size `voxel_size` whose centers lie - within a sphere of radius `arm_length` around `arm_base`, using the - exact range definitions you provided for x, y, and z. - - Args: - arm_base (sequence of 3 floats): (x, y, z) origin. - arm_length (float): radius of the sphere. - voxel_size (float): edge length of each cubic voxel. - - Returns: - numpy.ndarray of shape (M, 3): each row is a valid (x, y, z) center. - """ - x, y, z = arm_base - r = float(arm_length) - half = voxel_size / 2.0 - - # follow your exact ranges - x_range = np.arange(x - half, x + r + half, voxel_size) - y_range = np.arange(y - half, y + r + half, voxel_size) - z_range = np.arange(z - r / 2 - half, z + r / 2 + half, voxel_size) - - # build full grid of candidate centers - xx, yy, zz = np.meshgrid(x_range, y_range, z_range, indexing="ij") - pts = np.stack((xx, yy, zz), axis=-1).reshape(-1, 3) - - # keep only those inside the sphere of radius r - d2 = np.sum((pts - np.array(arm_base)) ** 2, axis=1) - return pts[d2 <= r**2] - - def _generate_uniform_directions(self, num_directions: int = 50): - """ - Generate vectors in evenly distributed n directions - """ - phi = np.pi * (3.0 - np.sqrt(5.0)) - directions = [] - for i in range(num_directions): - z = 1 - 2 * i / float(num_directions - 1) - theta = phi * i - x = np.sqrt(1 - z * z) * np.cos(theta) - y = np.sqrt(1 - z * z) * np.sin(theta) - directions.append(np.array([x, y, z])) - - return directions - - # Helper function - def normalize(self, v: np.ndarray) -> np.ndarray: - """Normalize a vector to unit length.""" - norm = np.linalg.norm(v) - if norm == 0: - return v # Avoid division by zero - return v / norm - - def _compute_ik_solutions( - self, - centers: List[np.ndarray], - directions: List[np.ndarray], - voxel_size: float, - num_yaws: int, - pos_eps: float = 2e-4, - rot_eps: float = 2e-4, - max_iterations: int = 1500, - num_samples: int = 5, - ) -> List[np.ndarray]: - """ - Compute IK solutions for a set of centers and directions. - This function will process the centers and directions in batches if `get_batch_ik_solution` is available. - - Args: - centers: List of center positions to compute IK for. - directions: List of direction vectors to compute IK for. - voxel_size: Size of the voxel to offset the centers. - num_yaws: Number of yaw sweeps to attempt. - robot_base: Transformation matrix of the robot base. - yaw_rot: Rotation matrix for yaw rotation. - - Returns: - List[np.ndarray]: List of valid IK poses. - """ - valid_poses = [] - success_counts = [0] * len(centers) - - # Create progress bar - pbar = tqdm(total=len(centers), ncols=100, desc="Computing IK (per-center)") - - yaw_angle = 360.0 / num_yaws - yaw_rot = R.from_euler("z", yaw_angle, degrees=True).as_matrix() - robot_base = self.robot.get_base_xpos(name=self.control_part) - - # Check if self.robot has the method get_batch_ik_solution - if hasattr(self.robot, "get_batch_ik_solution"): - # If get_batch_ik_solution is available, we process in batches - for i, center in enumerate(centers): - batch_positions = [] - batch_xpos = [] - - for d in directions: - # Build local frame so that its Z-axis = -d - z_axis = -d - up = ( - np.array([0, 1, 0]) - if abs(z_axis[1]) < 0.9 - else np.array([1, 0, 0]) - ) - x_axis = self.normalize(np.cross(up, z_axis)) - y_axis = np.cross(z_axis, x_axis) - frame = np.stack([x_axis, y_axis, z_axis], axis=1) - - # Shift out to the surface of the voxel - pos = center + d * (voxel_size * 0.5) - - # Try yaw sweeps - for _ in range(num_yaws): - frame = frame @ yaw_rot - xpos = np.eye(4) - xpos[:3, :3] = frame - xpos[:3, 3] = pos - xpos_robot = np.linalg.inv(robot_base) @ xpos - - # Prepare batch for IK computation - batch_positions.append(pos) - batch_xpos.append(xpos_robot) - - # Convert lists to numpy arrays (batch_size, 4, 4) - batch_xpos_array = np.array(batch_xpos) - - # Create the joint seed batch (batch_size, dof) - dof_number = self.robot.get_dof(self.control_part) - joint_seed_batch = np.zeros((len(batch_xpos), dof_number)) - - # Use get_batch_ik_solution for batch processing - res, _ = self.robot.get_batch_ik_solution( - target_xpos_list=batch_xpos_array, # Batch of target poses - joint_seed_list=joint_seed_batch, # Batch of joint seeds (zeros) - uid=self.control_part, - is_world_coordinates=False, # Set based on your use case - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - ) - - # Append valid target poses to valid_poses - for j, is_valid in enumerate(res): - if is_valid: - success_counts[i] += 1 - valid_poses.append(batch_xpos_array[j]) - - # Update the progress bar after processing the batch - pbar.update(1) - - else: - # Fallback to the previous method (get_ik) if get_batch_ik_solution is not available - for i, center in enumerate(centers): - for d in directions: - # Build local frame so that its Z-axis = -d - z_axis = -d - up = ( - np.array([0, 1, 0]) - if abs(z_axis[1]) < 0.9 - else np.array([1, 0, 0]) - ) - x_axis = self.normalize(np.cross(up, z_axis)) - y_axis = np.cross(z_axis, x_axis) - frame = np.stack([x_axis, y_axis, z_axis], axis=1) - - # Shift out to the surface of the voxel - pos = center + d * (voxel_size * 0.5) - - # Try yaw sweeps - for _ in range(num_yaws): - frame = frame @ yaw_rot - xpos = np.eye(4) - xpos[:3, :3] = frame - xpos[:3, 3] = pos - xpos_robot = np.linalg.inv(robot_base) @ xpos - - # Calculate IK using the old method (get_ik) - is_success, _ = self.robot.get_ik( - xpos=xpos_robot, uid=self.control_part - ) - if is_success: - success_counts[i] += 1 - valid_poses.append(xpos_robot.copy()) - break # stop yaw for this direction - - pbar.update(1) - - logger.log_info(f"Sampling complete: {sum(success_counts)} valid positions.") - - return success_counts, valid_poses - - def _sample_voxels_memory_mode( - self, - voxel_size: float, - num_directions: int, - num_yaws: int, - arm_base: np.ndarray, - arm_length: float, - pos_eps: float, - rot_eps: float, - max_iterations: int, - num_samples: int, - ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray]]: - - dirs = self._generate_uniform_directions(num_directions) - centers = self._voxel_centers_in_sphere(arm_base, arm_length, voxel_size) - - success_counts, ik_matrices = self._compute_ik_solutions( - centers, - dirs, - voxel_size, - num_yaws, - pos_eps, - rot_eps, - max_iterations, - num_samples, - ) - - return centers, success_counts, ik_matrices - - def _sample_voxels_disk_mode( - self, - voxel_size: float, - num_directions: int, - num_yaws: int, - arm_base: np.ndarray, - arm_length: float, - pos_eps: float, - rot_eps: float, - max_iterations: int, - num_samples: int, - save_dir: str, - batch_size: int, - save_threshold: int, - use_cached: bool = True, - ) -> tuple[np.ndarray, np.ndarray, list[np.ndarray]]: - """ - Returns: - centers: (M,3) np.ndarray of voxel centers - success_counts: (M,) np.ndarray of ints - valid_poses: list of 4x4 np.ndarrays - """ - counts_file = os.path.join(save_dir, "success_counts.npy") - batches_dir = os.path.join(save_dir, "batches") - - # 1) generate dirs & centers - dirs = self._generate_uniform_directions(num_directions) - centers = self._voxel_centers_in_sphere(arm_base, arm_length, voxel_size) - - # 2) if already computed, load & return - if os.path.isdir(batches_dir) and os.path.exists(counts_file) and use_cached: - npy_files = [f for f in os.listdir(batches_dir) if f.endswith(".npy")] - if npy_files: - success_counts = np.load(counts_file) - valid_poses = self._merge_batch_files(save_dir, len(npy_files)) - return centers, success_counts, valid_poses - - os.makedirs(batches_dir, exist_ok=True) - - # 3) run IK sweep - success_counts, valid_poses = self._compute_ik_solutions( - centers, - dirs, - voxel_size, - num_yaws, - pos_eps, - rot_eps, - max_iterations, - num_samples, - ) - if success_counts.sum() == 0: - return centers, success_counts, [] - - # 4) save counts - np.save(counts_file, success_counts) - - # 5) batch & save using a local temp buffer - temp_valid = [] - valid_block = [] - batch_count = 0 - - for pose in valid_poses: - # collect into small blocks of batch_size - valid_block.append(pose) - if len(valid_block) >= batch_size: - # move into temp_valid - temp_valid.extend(valid_block) - valid_block = [] - - # once buffer reaches save_threshold, flush to disk - if len(temp_valid) >= save_threshold: - self._save_batch_results(temp_valid, save_dir, batch_count) - batch_count += 1 - temp_valid = [] - gc.collect() - - # move any remaining block into temp_valid - if valid_block: - temp_valid.extend(valid_block) - - # final flush of anything left in temp_valid - if temp_valid: - self._save_batch_results(temp_valid, save_dir, batch_count) - batch_count += 1 - - # 6) merge all batch files and return - all_poses = self._merge_batch_files(save_dir, batch_count) - return centers, success_counts, all_poses - - -def compute_xpos_reachability( - robot: Robot, - name: str, - ref_xpos: np.ndarray, - xpos_resolution: float = 0.2, - qpos_resolution: float = np.radians(60), - pos_eps: float = 5e-4, - rot_eps: float = 5e-4, - max_iterations: int = 1500, - num_samples: int = 5, - batch_size: int = 100000, - save_threshold: int = 10000000, - qpos_limits: np.ndarray = None, - cache_mode: str = "disk", - visualize: bool = True, - use_cached: bool = True, - **kwargs, -) -> Tuple[ - Optional[list[np.ndarray]], # First return: list of sampled 4x4 poses - Optional[ - dexsim.models.PointCloud - ], # Second return: point cloud handle if visualization is enabled -]: - """Compute the robot's reachable workspace by Cartesian space sampling. - - Samples points in Cartesian space and checks reachability using inverse kinematics. - If `visualize` is True, visualizes reachable positions as a colored point cloud; - Otherwise, only performs the sampling result as open3d PointCloud. - - - Args: - name (str): Identifier of the robot drive controller to analyze - ref_xpos (np.ndarray): Reference end-effector pose matrix (4x4) defining the - orientation for IK solutions - xpos_resolution (float, optional): Cartesian space sampling resolution in meters. - Smaller values provide finer sampling but increase - computation time. Defaults to 0.2 meters. - qpos_resolution (float, optional): Angular resolution for initial joint space - sampling in radians. Used to determine workspace - bounds. Defaults to 60 degrees. - pos_eps (float, optional): Position tolerance for IK solutions in meters. - Defaults to 2e-4 meters. - rot_eps (float, optional): Rotation tolerance for IK solutions in radians. - Defaults to 2e-4 radians. - max_iterations (int, optional): Maximum number of IK iterations per sample. - Defaults to 2000. - num_samples (int, optional): Number of samples to generate in Cartesian space. - Defaults to 10. - qpos_limits (np.ndarray, optional): Custom joint limits array of shape (n_joints, 2). - If None, uses limits from drive controller or - articulation. Defaults to None - cache_mode (str, optional): Cache mode for workspace analysis. Options include "memory" and "disk". - Defaults to "memory". - visualize (bool, optional): If set to True, returns an extra Dexsim PointCloud handle for visualization. - Defaults to True. - use_cached (bool, optional): If True and `cache_mode` is "disk", attempts to load precomputed results. - Ignored for "memory" mode. Defaults to True. - - Returns: - Tuple[Optional[list[np.ndarray]], Optional[dexsim.models.PointCloud]]: - The first element is a list of sampled end-effector poses (4×4 transformation matrices) if sampling succeeds, otherwise None. - The second element is a point cloud handle if visualization is enabled and successful, otherwise None. - """ - from embodichain.lab.sim import REACHABLE_XPOS_DIR - from dexsim.utility.env_utils import create_point_cloud_from_o3d_pcd - from dexsim.utility import inv_transform - - if name not in robot.control_parts: - logger.log_warning(f"Drive controller '{name}' not found") - return None, None - - # try: - # Get robot configuration - # base_xpos = robot.get_control_part_base_pose(name=name, to_matrix=True).squeeze(0).cpu().numpy() - # ref_xpos_robot = inv_transform(base_xpos) @ ref_xpos - ref_xpos_robot = ref_xpos - - if qpos_limits is None: - joint_ranges = ( - robot.body_data.qpos_limits[0].cpu().numpy()[robot.get_joint_ids(name=name)] - ) - else: - joint_ranges = qpos_limits - - urdf_path = robot.cfg.fpath - robot_name = os.path.splitext(os.path.basename(urdf_path))[0] - - qpos_resolution_str = f"{qpos_resolution:.2f}".replace(".", "_") - xpos_resolution_str = f"{xpos_resolution:.2f}".replace(".", "_") - # Join into one directory name - save_dir = ( - REACHABLE_XPOS_DIR - / f"{robot_name}_{name}_{qpos_resolution_str}_{xpos_resolution_str}" - ) - - # Initialize workspace analyzer - analyzer = WorkspaceAnalyzer( - robot=robot, - name=name, - resolution=qpos_resolution, - joint_ranges=joint_ranges, - ) - # Sample workspace points - sampled_xpos = analyzer.sample_xpos_workspace( - ref_xpos=ref_xpos_robot, - xpos_resolution=xpos_resolution, - qpos_resolution=qpos_resolution, - cache_mode=cache_mode, - batch_size=batch_size, - save_dir=save_dir, - save_threshold=save_threshold, - pos_eps=pos_eps, - rot_eps=rot_eps, - max_iterations=max_iterations, - num_samples=num_samples, - use_cached=use_cached, - ) - - if visualize: - if sampled_xpos is None: - logger.log_warning("No reachable positions found.") - return None, None - all_positions = [xpos[:3, 3] for xpos in sampled_xpos] - pcd = analyzer._process_point_cloud( - positions=all_positions, is_voxel_down=False - ) - # Transfer to World Coordinate - # pcd.transform(base_xpos) - # Create and configure point cloud visualization - from embodichain.lab.sim.utility.sim_utils import get_dexsim_arenas - - pcd_handle = create_point_cloud_from_o3d_pcd( - pcd=pcd, env=get_dexsim_arenas()[0] - ) - else: - return sampled_xpos, None - - return sampled_xpos, pcd_handle diff --git a/embodichain/toolkits/interfaces.py b/embodichain/toolkits/interfaces.py index cba8e55..630493f 100644 --- a/embodichain/toolkits/interfaces.py +++ b/embodichain/toolkits/interfaces.py @@ -12,7 +12,6 @@ from embodichain.lab.gym.motion_generation.action.arm_action import ArmAction from embodichain.data.enum import ControlParts, EndEffector, JointType from scipy.spatial.transform import Rotation as R -from embodichain.lab.sim.utility.workspace_analyzer_new import compute_xpos_reachability from embodichain.utils.utility import encode_image import ast ''' @@ -54,7 +53,7 @@ def find_nearest_valid_pose(env, select_arm, pose, xpos_resolution=0.1): # delete the cache every time if isinstance(pose, torch.Tensor): pose = pose.detach().cpu().numpy() - ret, _ = compute_xpos_reachability(env.robot, select_arm, pose, xpos_resolution=xpos_resolution, + ret, _ = env.robot.compute_xpos_reachability(select_arm, pose, xpos_resolution=xpos_resolution, qpos_resolution=np.radians(60), cache_mode="disk", use_cached=False, visualize=False) ret = np.stack(ret, axis=0) From 6ded7071798da6f952b4f99582dcd01e571b96e6 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 20 Jan 2026 11:50:12 +0800 Subject: [PATCH 43/49] Fix: use environment variable to add api key --- embodichain/agents/hierarchy/llm.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py index 06f3c6f..67c3c35 100644 --- a/embodichain/agents/hierarchy/llm.py +++ b/embodichain/agents/hierarchy/llm.py @@ -5,13 +5,22 @@ # Environment configuration # ------------------------------------------------------------------------------ +# Clear proxy if not needed (optional, can be set via environment variables) + os.environ["ALL_PROXY"] = "" os.environ["all_proxy"] = "" -#os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" -#os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" -os.environ["OPENAI_API_VERSION"] = "2024-10-21" -os.environ["AZURE_OPENAI_ENDPOINT"] = "YOUR_ENDPOINT_HERE" -os.environ["AZURE_OPENAI_API_KEY"] = "YOUR_API_KEY_HERE" + +# Proxy configuration (optional, uncomment if needed) +# os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" +# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" + +# API version (optional, defaults to "2024-10-21" if not set) +# os.environ["OPENAI_API_VERSION"] = "2024-10-21" + +# Note: AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set via environment variables +# Example in bash: +# export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/" +# export AZURE_OPENAI_API_KEY="your-api-key" # ------------------------------------------------------------------------------ # LLM factory From a25bf61e82b4b04084d44d26289a0eb992817265 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 20 Jan 2026 16:44:30 +0800 Subject: [PATCH 44/49] Add Readme for EmbodiAgent --- embodichain/agents/README.md | 97 ++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 embodichain/agents/README.md diff --git a/embodichain/agents/README.md b/embodichain/agents/README.md new file mode 100644 index 0000000..b6ca144 --- /dev/null +++ b/embodichain/agents/README.md @@ -0,0 +1,97 @@ +# EmbodiAgent System + + +## Quick Start + +### 1. Prerequisites +Ensure you have access to Azure OpenAI or a compatible LLM endpoint. + +```bash +# Set environment variables +export AZURE_OPENAI_ENDPOINT="[https://your-endpoint.openai.azure.com/](https://your-endpoint.openai.azure.com/)" +export AZURE_OPENAI_API_KEY="your-api-key" +``` + + +### 2. Run the System + +```bash +python embodichain/lab/scripts/run_agent.py \ + --task_name YourTask \ + --gym_config configs/gym/your_task/gym_config.json \ + --agent_config configs/gym/agent/your_agent/agent_config.json \ + --regenerate False +``` + + + +## System Architecture + +The system operates on a closed-loop control cycle: + +1. **Observe**: The `TaskAgent` perceives the environment via multi-view camera inputs. +2. **Plan**: It decomposes the goal into natural language steps. +3. **Code**: The `CodeAgent` translates steps into executable Python code using atomic actions. +4. **Execute**: The code runs in the environment; runtime errors are caught immediately. +5. **Validate**: The `ValidationAgent` analyzes the result images, selects the best camera angle, and judges success. +6. **Refine**: If validation fails, feedback is sent back to the agents to regenerate the plan or code. + + +--- + +## Core Components + +### 1. TaskAgent ("The Planner") +*Located in:* `embodichain/agents/hierarchy/task_agent.py` + +Responsible for high-level reasoning. It parses visual observations and outputs a structured plan. + +* For every step, it generates a specific condition (e.g., "The cup must be held by the gripper") which is used later by the ValidationAgent. +* Prompt Strategies: + * `one_stage_prompt`: Direct VLM-to-Plan generation. + * `two_stage_prompt`: Separates visual analysis from planning logic. + +### 2. CodeAgent ("The Coder") +*Located in:* `embodichain/agents/hierarchy/code_agent.py` + +Translates natural language plans into executable Python code. + + +### 3. ValidationAgent ("The Judger") +*Located in:* `embodichain/agents/hierarchy/validation_agent.py` + +Closes the loop by verifying if the robot actually achieved what it planned. + +* Uses a specialized LLM call (`select_best_view_dir`) to analyze images from all cameras and pick the single best angle that proves the action's outcome, ignoring irrelevant views. +* If an error occurs (runtime or logic), it generates a detailed explanation which is fed back to the `TaskAgent` or `CodeAgent` for the next attempt. + +--- + +## Configuration Guide + +The `Agent` configuration block controls the context provided to the LLMs. All files are resolved relative to `embodichain/database/agent_prompt/`. + +| Parameter | Description | Typical Use | +| :--- | :--- | :--- | +| `task_prompt` | Task-specific goal description | "Pour water from the red cup to the blue cup." | +| `basic_background` | Physical rules & constraints | World coordinate system definitions, safety rules. | +| `atom_actions` | API Documentation | List of available functions (e.g., `drive(action='pick', ...)`). | +| `code_prompt` | Coding guidelines | "Use provided APIs only. Do not use loops." | +| `code_example` | Few-shot examples | Previous successful code snippets to guide style. | + +--- + +## File Structure + +```text +embodichain/agents/ +├── hierarchy/ +│ ├── agent_base.py # Abstract base handling prompts & images +│ ├── task_agent.py # Plan generation logic +│ ├── code_agent.py # Code generation & AST execution engine +│ ├── validation_agent.py # Visual analysis & view selection +│ └── llm.py # LLM configuration and instances +├── mllm/ +│ └── prompt/ # Prompt templates (LangChain) +└── README.md # This file +``` From 0d85e007e886099b1b4f0843d0ec5fea517639dc Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 21 Jan 2026 09:57:06 +0800 Subject: [PATCH 45/49] Remove generated prompt --- .../agents/hierarchy/validation_agent.py | 6 +- .../PourWaterAgent-v3/agent_generated_code.py | 59 ------------------- .../agent_generated_plan.txt | 27 --------- 3 files changed, 2 insertions(+), 90 deletions(-) delete mode 100644 embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py delete mode 100644 embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt diff --git a/embodichain/agents/hierarchy/validation_agent.py b/embodichain/agents/hierarchy/validation_agent.py index 4cb1b99..910988a 100644 --- a/embodichain/agents/hierarchy/validation_agent.py +++ b/embodichain/agents/hierarchy/validation_agent.py @@ -73,10 +73,8 @@ def validate(self, step_names, problematic_code, error_message, image_files): {', '.join(step_names)} Provide the following analysis: - 1. Determine whether each step was executed correctly. - 2. If a step failed, identify which one and explain the cause. - 3. Decide whether the full task succeeded or failed. - 4. If the task failed, provide a precise and detailed explanation. + 1. Decide whether the full task succeeded or failed. + 2. If the task failed, provide a precise and detailed explanation. Below is a potentially problematic piece of code and the corresponding execution error: diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py deleted file mode 100644 index 3399a23..0000000 --- a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_code.py +++ /dev/null @@ -1,59 +0,0 @@ -# Step 1: Grasp the bottle -drive( - right_arm_action=grasp( - robot_name='right_arm', - obj_name='bottle', - pre_grasp_dis=0.10 - ), - left_arm_action=None -) - -# Step 2: Move the bottle to the pouring position relative to the cup -drive( - right_arm_action=move_relative_to_object( - robot_name='right_arm', - obj_name='cup', - x_offset=0.05, - y_offset=-0.10, - z_offset=0.125 - ), - left_arm_action=None -) - -# Step 3: Pour water into the cup -drive( - right_arm_action=rotate_eef( - robot_name='right_arm', - degree=-90 - ), - left_arm_action=None -) - -# Step 4: Return the bottle to its upright position -drive( - right_arm_action=rotate_eef( - robot_name='right_arm', - degree=90 - ), - left_arm_action=None -) - -# Step 5: Place the bottle at the specified location -drive( - right_arm_action=place_on_table( - robot_name='right_arm', - obj_name='bottle', - x=0.7, - y=-0.1, - pre_place_dis=0.08 - ), - left_arm_action=None -) - -# Step 6: Return the right arm to its initial pose -drive( - right_arm_action=back_to_initial_pose( - robot_name='right_arm' - ), - left_arm_action=None -) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt b/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt deleted file mode 100644 index 1eeaef8..0000000 --- a/embodichain/database/agent_prompt/PourWaterAgent-v3/agent_generated_plan.txt +++ /dev/null @@ -1,27 +0,0 @@ -**[PLANS]:** - -Step 1: Grasp the bottle — `grasp(robot_name='right_arm', obj_name='bottle', pre_grasp_dis=0.10)` - -Step 2: Move the bottle to the pouring position relative to the cup — `move_relative_to_object(robot_name='right_arm', obj_name='cup', x_offset=0.05, y_offset=-0.10, z_offset=0.125)` - -Step 3: Pour water into the cup — `rotate_eef(robot_name='right_arm', degree=-90)` - -Step 4: Return the bottle to its upright position — `rotate_eef(robot_name='right_arm', degree=90)` - -Step 5: Place the bottle at the specified location — `place_on_table(robot_name='right_arm', obj_name='bottle', x=0.7, y=-0.1, pre_place_dis=0.08)` - -Step 6: Return the right arm to its initial pose — `back_to_initial_pose(robot_name='right_arm')` - -**[VALIDATION_CONDITIONS]:** - -Step 1: The right arm should be holding the bottle securely. - -Step 2: The bottle should be positioned at an offset of [0.05, -0.10, 0.125] relative to the cup. - -Step 3: The bottle should be tilted, pouring water into the cup. - -Step 4: The bottle should be returned to an upright position, held by the right arm. - -Step 5: The bottle should be placed at the location [0.7, -0.1] on the table, and the right arm should release it. - -Step 6: The right arm should be in its initial pose, not holding any object. \ No newline at end of file From 7542965770f2160859588a5f408472e704e7e7c6 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 21 Jan 2026 10:12:49 +0800 Subject: [PATCH 46/49] Remove scripts with feedback --- embodichain/lab/scripts/run_agent_feedback.py | 340 --------------- .../lab/scripts/run_agent_visual_feedback.py | 395 ------------------ 2 files changed, 735 deletions(-) delete mode 100644 embodichain/lab/scripts/run_agent_feedback.py delete mode 100644 embodichain/lab/scripts/run_agent_visual_feedback.py diff --git a/embodichain/lab/scripts/run_agent_feedback.py b/embodichain/lab/scripts/run_agent_feedback.py deleted file mode 100644 index ba0f5a2..0000000 --- a/embodichain/lab/scripts/run_agent_feedback.py +++ /dev/null @@ -1,340 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import gymnasium -import numpy as np -import argparse -import os -import torch -import json - -from threading import Thread -from tqdm import tqdm -from embodichain.utils.utility import load_json -from embodichain.lab.sim import SimulationManagerCfg -from embodichain.lab.gym.envs import EmbodiedEnvCfg -from embodichain.lab.gym.utils.gym_utils import ( - config_to_cfg, -) -from embodichain.lab.scripts.generate_video import visualize_data_dict -from embodichain.data.data_engine.online.online_generator import ( - OnlineGenerator, -) -from embodichain.utils.logger import log_warning, log_info, log_error -from embodichain.data import database_agent_prompt_dir -from pathlib import Path -import traceback - -def test_code(env, code_file_path, check_num=10, kwargs=None): - """Test the generated code multiple times and evaluate task success rate. - - Uses env.code_agent.act() to execute the code, which handles all the - necessary imports and execution logic. - """ - # ====== Read code content for display ====== - with open(code_file_path, "r", encoding="utf-8") as f: - code_content = f.read() - - # ====== Initialize kwargs ====== - if kwargs is None: - kwargs = {} - if "env" not in kwargs: - kwargs["env"] = env - - # ====== Initialize counters ====== - epid, suc_num, fail_num = 0, 0, 0 - run_records = [] - - # Error categories (same style as previous run() function) - error_list = [ - "Code can not run", # 0 - "Task executed but failed", # 1 - "No error occurred" # 2 - ] - error_num = [0, 0, 0] - - print("\033[93m" + "[Start Testing Task Success Rate]" + "\033[0m") - - # ====== Print generated source ====== - print("\n\033[92m=== generated source code ===\033[0m") - print(code_content) - print("\033[92m=== End ===\033[0m\n") - - # ====== Main loop ====== - for epid in range(check_num): - env.reset() - kwargs['current_check_num'] = epid - error_id = None - - try: - # Use code_agent.act() to execute the code - # This method handles all imports and execution logic - env.get_wrapper_attr("code_agent").act(code_file_path, **kwargs) - - # Check result - if env.get_wrapper_attr("is_task_success")().item(): - print(f"simulate data episode {suc_num} success! (seed = {epid})") - suc_num += 1 - run_records.append("Success!") - else: - print(f"simulate data episode {suc_num} fail! (seed = {epid})") - fail_num += 1 - error_id = 1 - run_records.append(error_list[1]) - - except Exception as e: - # Execution error - exec_trace = traceback.format_exc() - error_list[0] = exec_trace # store full traceback for summary - error_id = 0 - fail_num += 1 - - run_records.append(f"Code can not run, error: {exec_trace}") - - print("-------------") - print(f"simulate data episode {suc_num} fail! (seed = {epid})") - print("Error:", exec_trace) - print("-------------") - - # Count error category - if error_id is not None: - error_num[error_id] += 1 - - # ====== Find most frequent error ====== - if sum(error_num) == 0: - max_error_index = 2 # no errors, fallback to "NO error" - max_error_count = 0 - else: - max_error_index = error_num.index(max(error_num)) - max_error_count = error_num[max_error_index] - - # ====== Summary ====== - print(f'\nComplete test, success rate: {suc_num}/{check_num}') - print(f'Error message list: {error_list}') - print(f'Error count: {error_num}') - print(f'Run records: {run_records}') - - return suc_num / check_num, error_list[max_error_index], max_error_count, run_records - - -def generate_function( - env, - generated_codes, - error_messages, - log_dir=None, -): - # Initialize env - env.reset() - - # First attempt case - create initial code file - if len(error_messages) == 0: - code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")(regenerate=True, log_dir=log_dir) - # Generate code based on error status - else: - code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")( - regenerate=True, log_dir=log_dir, generated_codes=generated_codes, error_messages=error_messages) - - try: - # Update this section to match the new return values of the run function - success_rate, error_message, error_count, run_records = test_code(env, code_file_path, check_num=5, kwargs=kwargs) - generated_codes.append(code) - error_messages.append(error_message) - return code, success_rate, error_message, error_count, run_records - except KeyboardInterrupt: - print("Test interrupted by user") - return code, 0, "Test interrupted by user", 10, None - except Exception as e: - import traceback - error_trace = traceback.format_exc() - print(f"Error occurred during testing: {e}\n{error_trace}") - return code, 0, f"Error occurred during testing: {e}", 10, None - - -def main(args, env, gym_config): - - log_info("Start agent data generation with feedback.", color="green") - - # Initialize variables - generate_num = 5 - success_threshold = 0.6 - suc_list = [] - - # Store each round's code and error - error_messages = [] - generated_codes = [] - - # Store the best code and its success rate - best_code = None - best_success_rate = 0 - best_run_records = None - - # Create log file name with timestamp - import datetime - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - log_dir = Path(database_agent_prompt_dir) / args.task_name / "feedback_logs" / timestamp - os.makedirs(log_dir, exist_ok=True) - log_filename = f"{log_dir}/{args.task_name}.log" - - # Store all attempt records - all_attempts = [] - - # Try multiple generations until success or limit reached - for id in range(generate_num): - log_info(f"Generate code for task: {args.task_name} ({id + 1}/{generate_num})", color='green') - - # Generate and test code - code, success_rate, error_message, error_count, run_records = generate_function( - env, generated_codes, error_messages, log_dir) - - # Track success rates - suc_list.append(success_rate) - - # Record this attempt - attempt_record = { - "attempt_id": id + 1, - "success_rate": success_rate, - "error_message": error_message, - "error_count": error_count, - "code": code, - "run_records": run_records - } - all_attempts.append(attempt_record) - - # Save best code - if success_rate > best_success_rate: - best_success_rate = success_rate - best_code = code - best_run_records = run_records - print(f"New best code found, success rate: {best_success_rate}") - - # Check if generation was successful - if success_rate >= success_threshold: - print(f"Successfully generated code for task: {args.task_name}") - break - - # Handle failure case - log_warning(f"The generated code fail for task: {args.task_name} (attempt {id+1}) with succuss rate {success_rate}\nError message: \n{error_message}") - - # Ensure the final saved code is the best one - if best_code is not None: - file_name = log_dir / "agent_generated_code.py" - print(f"Saving best code, success rate: {best_success_rate}") - with open(file_name, 'w') as file: - file.write(best_code) - - print(f"Best success rate: {best_success_rate}") - print(f"All success rates: {suc_list}") - - # Save log data to file - with open(log_filename, 'w') as log_file: - log_data = { - "task_name": args.task_name, - "best_success_rate": best_success_rate, - "success_rates": suc_list, - "best_code": best_code, - "best_run_records": best_run_records, - "all_attempts": all_attempts - } - json.dump(log_data, log_file, indent=2) - - print(f"Log has been saved to: {log_filename}") - - if args.headless: - env.reset(options={"final": True}) - - -if __name__ == "__main__": - np.set_printoptions(5, suppress=True) - torch.set_printoptions(precision=5, sci_mode=False) - - parser = argparse.ArgumentParser() - parser.add_argument( - "--num_envs", - help="The number of environments to run in parallel.", - default=1, - type=int, - ) - parser.add_argument( - "--device", - type=str, - default="cpu", - help="device to run the environment on, e.g., 'cpu' or 'cuda'", - ) - parser.add_argument( - "--headless", - help="Whether to perform the simulation in headless mode.", - default=False, - action="store_true", - ) - parser.add_argument( - "--enable_rt", - help="Whether to use RTX rendering backend for the simulation.", - default=False, - action="store_true", - ) - parser.add_argument( - "--render_backend", - help="The rendering backend to use for the simulation.", - default="egl", - type=str, - ) - parser.add_argument( - "--gpu_id", - help="The GPU ID to use for the simulation.", - default=0, - type=int, - ) - parser.add_argument( - "--save_video", - help="Whether to save data as video.", - default=False, - action="store_true", - ) - parser.add_argument( - "--save_path", help="path", default="./outputs/thirdviewvideo", type=str - ) - parser.add_argument( - "--debug_mode", - help="Enable debug mode.", - default=False, - action="store_true", - ) - parser.add_argument( - "--filter_visual_rand", - help="Whether to filter out visual randomization.", - default=False, - action="store_true", - ) - - parser.add_argument("--online_config", type=str, help="online_config", default="") - parser.add_argument("--gym_config", type=str, help="gym_config", default="") - parser.add_argument("--task_name", type=str, help="Name of the task.", required=True) - - # Agent related configs - parser.add_argument("--agent_config", type=str, help="agent_config", default=None, required=True) - parser.add_argument("--regenerate", type=bool, help="Whether regenerate code if already existed.", default=False) - - args = parser.parse_args() - - if args.num_envs != 1: - log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") - - gym_config = load_json(args.gym_config) - cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) - cfg.filter_visual_rand = args.filter_visual_rand - - agent_config = load_json(args.agent_config) - - cfg.num_envs = args.num_envs - cfg.sim_cfg = SimulationManagerCfg( - headless=args.headless, - sim_device=args.device, - enable_rt=args.enable_rt, - gpu_id=args.gpu_id, - ) - - env = gymnasium.make(id=gym_config["id"], cfg=cfg, agent_config=agent_config, task_name=args.task_name) - main(args, env, gym_config) diff --git a/embodichain/lab/scripts/run_agent_visual_feedback.py b/embodichain/lab/scripts/run_agent_visual_feedback.py deleted file mode 100644 index 50be9e7..0000000 --- a/embodichain/lab/scripts/run_agent_visual_feedback.py +++ /dev/null @@ -1,395 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. -# -# All rights reserved. -# ---------------------------------------------------------------------------- - -import gymnasium -import numpy as np -import argparse -import os -import torch -import json - -from threading import Thread -from tqdm import tqdm -from embodichain.utils.utility import load_json -from embodichain.lab.sim import SimulationManagerCfg -from embodichain.lab.gym.envs import EmbodiedEnvCfg -from embodichain.lab.gym.utils.gym_utils import ( - config_to_cfg, -) -from embodichain.lab.scripts.generate_video import visualize_data_dict -from embodichain.data.data_engine.online.online_generator import ( - OnlineGenerator, -) -from embodichain.utils.logger import log_warning, log_info, log_error -from embodichain.data import database_agent_prompt_dir -from pathlib import Path -import traceback -import glob - -def test_code(env, code_file_path, check_num=10, kwargs=None): - """Test the generated code multiple times and evaluate task success rate. - - Uses env.code_agent.act() to execute the code, which handles all the - necessary imports and execution logic. - """ - # ====== Read code content for display ====== - with open(code_file_path, "r", encoding="utf-8") as f: - code_content = f.read() - - # ====== Initialize kwargs ====== - if kwargs is None: - kwargs = {} - if "env" not in kwargs: - kwargs["env"] = env - - # ====== Initialize counters ====== - epid, suc_num, fail_num = 0, 0, 0 - run_records = [] - - # Error categories (same style as previous run() function) - error_list = [ - "Code can not run", # 0 - "Task executed but failed", # 1 - "No error occurred" # 2 - ] - error_num = [0, 0, 0] - - print("\033[93m" + "[Start Testing Task Success Rate]" + "\033[0m") - - # ====== Print generated source ====== - print("\n\033[92m=== generated source code ===\033[0m") - print(code_content) - print("\033[92m=== End ===\033[0m\n") - - # ====== Main loop ====== - for epid in range(check_num): - env.reset() - kwargs['current_check_num'] = epid - error_id = None - - try: - # Use code_agent.act() to execute the code - # This method handles all imports and execution logic - env.get_wrapper_attr("code_agent").act(code_file_path, **kwargs) - - # Check result - if env.get_wrapper_attr("is_task_success")().item(): - print(f"simulate data episode {epid} success!") - suc_num += 1 - run_records.append("Success!") - else: - print(f"simulate data episode {epid} fail!") - fail_num += 1 - error_id = 1 - run_records.append(error_list[1]) - - except Exception as e: - # Execution error - exec_trace = traceback.format_exc() - error_list[0] = exec_trace # store full traceback for summary - error_id = 0 - fail_num += 1 - - run_records.append(f"Code can not run, error: {exec_trace}") - - print("-------------") - print(f"simulate data episode {epid} fail!") - print("Error:", exec_trace) - print("-------------") - - # Count error category - if error_id is not None: - error_num[error_id] += 1 - - # ====== Find most frequent error ====== - if sum(error_num) == 0: - max_error_index = 2 # no errors, fallback to "NO error" - max_error_count = 0 - else: - max_error_index = error_num.index(max(error_num)) - max_error_count = error_num[max_error_index] - - # ====== Observe at the most frequently occurred error ====== - observation_feedback = None - if max_error_count > 0: - observe_index = 0 - highest_priority = len(error_list) - - for i, record in enumerate(run_records): - if record == "Success!": - continue - - current_priority = len(error_list) - for p, error_pattern in enumerate(error_list): - if error_pattern in record: - current_priority = p - break - - if current_priority < highest_priority: - highest_priority = current_priority - observe_index = i - - if highest_priority == len(error_list) and len(run_records) > 0: - observe_index = 0 - - print(f"Selected observation index observe_index={observe_index}, corresponding error: {run_records[observe_index]}") - - log_dir = kwargs["log_dir"] # require log_dir - gen_id = kwargs.get("id", "unknown") # fallback to a safe string - episode_id = observe_index - save_dir = log_dir / "camera_images" / f"{gen_id}_generate_num" / f"episode{episode_id}" - print(f"Looking for images in: {save_dir}") - - image_files = sorted(glob.glob(os.path.join(save_dir, f"*.png"))) - - # Extract step names from image filenames - step_names = [] - for f in image_files: - filename = os.path.basename(f) - first_underscore_pos = filename.find('_') - if first_underscore_pos != -1: - step_name = filename[first_underscore_pos + 1:].rsplit('.', 1)[0] - step_names.append(step_name) - else: - step_names.append(filename.rsplit('.', 1)[0]) - print(f"Image search pattern: episode{episode_id}_*.png, number of files found: {len(image_files)}") - - observation_feedback = env.get_wrapper_attr("validation_agent").validate(step_names, code_content, error_list[observe_index], image_files) - log_info(f"Observation feedback: {observation_feedback}") - - # ====== Summary ====== - print(f'\nComplete test, success rate: {suc_num}/{check_num}') - print(f'Error message list: {error_list}') - print(f'Error count: {error_num}') - print(f'Run records: {run_records}') - - return suc_num / check_num, error_list[max_error_index], observation_feedback, max_error_count, run_records - - -def generate_function( - env, - generated_codes, - error_messages, - observation_feedbacks, - log_dir=None, - id=0, -): - # Initialize env - env.reset() - - # First attempt case - create initial code file - if len(error_messages) == 0: - code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")(regenerate=True, log_dir=log_dir, id=id) - # Generate code based on error status - else: - code_file_path, kwargs, code = env.get_wrapper_attr("generate_code_for_actions")( - regenerate=True, log_dir=log_dir, generated_codes=generated_codes, error_messages=error_messages, - observation_feedbacks=observation_feedbacks, id=id) - - try: - # Update this section to match the new return values of the run function - success_rate, error_message, observation_feedback, error_count, run_records = test_code(env, code_file_path, check_num=5, kwargs=kwargs) - generated_codes.append(code) - error_messages.append(error_message) - observation_feedbacks.append(observation_feedback) - return code, success_rate, error_message, observation_feedback, error_count, run_records - except KeyboardInterrupt: - print("Test interrupted by user") - return code, 0, "Test interrupted by user", None, 10, None - except Exception as e: - import traceback - error_trace = traceback.format_exc() - print(f"Error occurred during testing: {e}\n{error_trace}") - return code, 0, f"Error occurred during testing: {e}", None, 10, None - - -def main(args, env, gym_config): - - log_info("Start agent data generation with visual feedback.", color="green") - - # Initialize variables - generate_num = 5 - success_threshold = 0.6 - suc_list = [] - - # Store each round's code and error - generated_codes = [] - error_messages = [] - observation_feedbacks = [] - - # Store the best code and its success rate - best_code = None - best_success_rate = 0 - best_run_records = None - - # Create log file name with timestamp - import datetime - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - log_dir = Path(database_agent_prompt_dir) / args.task_name / "visual_feedback_logs" / timestamp - os.makedirs(log_dir, exist_ok=True) - log_filename = f"{log_dir}/{args.task_name}.log" - - # Store all attempt records - all_attempts = [] - - # Try multiple generations until success or limit reached - for id in range(generate_num): - log_info(f"Generate code for task: {args.task_name} ({id + 1}/{generate_num})", color='green') - - # Generate and test code - code, success_rate, error_message, observation_feedback, error_count, run_records = generate_function( - env, generated_codes, error_messages, observation_feedbacks, log_dir, id=id+1) - - # Track success rates - suc_list.append(success_rate) - - # Record this attempt - attempt_record = { - "attempt_id": id + 1, - "success_rate": success_rate, - "error_message": error_message, - "observation_feedback": observation_feedback, - "error_count": error_count, - "code": code, - "run_records": run_records - } - all_attempts.append(attempt_record) - - # Save best code - if success_rate > best_success_rate: - best_success_rate = success_rate - best_code = code - best_run_records = run_records - print(f"New best code found, success rate: {best_success_rate}") - - # Check if generation was successful - if success_rate >= success_threshold: - print(f"Successfully generated code for task: {args.task_name}") - break - - # Handle failure case - log_warning(f"The generated code fail for task: {args.task_name} (attempt {id+1}) with succuss rate {success_rate}\nError message: \n{error_message}") - - # Ensure the final saved code is the best one - if best_code is not None: - file_name = log_dir / "agent_generated_code.py" - print(f"Saving best code, success rate: {best_success_rate}") - with open(file_name, 'w') as file: - file.write(best_code) - - print(f"Best success rate: {best_success_rate}") - print(f"All success rates: {suc_list}") - - # Save log data to file - with open(log_filename, 'w') as log_file: - log_data = { - "task_name": args.task_name, - "best_success_rate": best_success_rate, - "success_rates": suc_list, - "best_code": best_code, - "best_run_records": best_run_records, - "all_attempts": all_attempts - } - json.dump(log_data, log_file, indent=2) - - print(f"Log has been saved to: {log_filename}") - - if args.headless: - env.reset(options={"final": True}) - - -if __name__ == "__main__": - np.set_printoptions(5, suppress=True) - torch.set_printoptions(precision=5, sci_mode=False) - - parser = argparse.ArgumentParser() - parser.add_argument( - "--num_envs", - help="The number of environments to run in parallel.", - default=1, - type=int, - ) - parser.add_argument( - "--device", - type=str, - default="cpu", - help="device to run the environment on, e.g., 'cpu' or 'cuda'", - ) - parser.add_argument( - "--headless", - help="Whether to perform the simulation in headless mode.", - default=False, - action="store_true", - ) - parser.add_argument( - "--enable_rt", - help="Whether to use RTX rendering backend for the simulation.", - default=False, - action="store_true", - ) - parser.add_argument( - "--render_backend", - help="The rendering backend to use for the simulation.", - default="egl", - type=str, - ) - parser.add_argument( - "--gpu_id", - help="The GPU ID to use for the simulation.", - default=0, - type=int, - ) - parser.add_argument( - "--save_video", - help="Whether to save data as video.", - default=False, - action="store_true", - ) - parser.add_argument( - "--save_path", help="path", default="./outputs/thirdviewvideo", type=str - ) - parser.add_argument( - "--debug_mode", - help="Enable debug mode.", - default=False, - action="store_true", - ) - parser.add_argument( - "--filter_visual_rand", - help="Whether to filter out visual randomization.", - default=False, - action="store_true", - ) - - parser.add_argument("--online_config", type=str, help="online_config", default="") - parser.add_argument("--gym_config", type=str, help="gym_config", default="") - parser.add_argument("--task_name", type=str, help="Name of the task.", required=True) - - # Agent related configs - parser.add_argument("--agent_config", type=str, help="agent_config", default=None, required=True) - parser.add_argument("--regenerate", type=bool, help="Whether regenerate code if already existed.", default=False) - - args = parser.parse_args() - - if args.num_envs != 1: - log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") - - gym_config = load_json(args.gym_config) - cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) - cfg.filter_visual_rand = args.filter_visual_rand - - agent_config = load_json(args.agent_config) - - cfg.num_envs = args.num_envs - cfg.sim_cfg = SimulationManagerCfg( - headless=args.headless, - sim_device=args.device, - enable_rt=args.enable_rt, - gpu_id=args.gpu_id, - ) - - env = gymnasium.make(id=gym_config["id"], cfg=cfg, agent_config=agent_config, task_name=args.task_name) - main(args, env, gym_config) From e0fac7192cb85ca91a1cefc4d5035b8191c1c226 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 21 Jan 2026 10:13:21 +0800 Subject: [PATCH 47/49] Reformat files --- embodichain/agents/hierarchy/code_agent.py | 39 +- embodichain/agents/hierarchy/task_agent.py | 1 - embodichain/agents/mllm/prompt/task_prompt.py | 12 +- .../datasets/sim_real_unified_dict_dataset.py | 46 +- embodichain/data/enum.py | 12 +- .../agent_generated_code.py | 46 +- .../20260120_105727/agent_generated_code.py | 54 +- .../20260120_110033/agent_generated_code.py | 50 +- .../20260120_110212/agent_generated_code.py | 46 +- embodichain/lab/gym/envs/action_bank/utils.py | 4 +- .../lab/gym/envs/managers/dataset_manager.py | 5 +- .../envs/tasks/tableware/base_agent_env.py | 17 +- .../gym/envs/tasks/tableware/pour_water_v3.py | 2 +- .../envs/tasks/tableware/rearrangement_v3.py | 6 +- .../motion_generation/action/arm_action.py | 6 +- .../gym/motion_generation/planner/utils.py | 1 + embodichain/lab/gym/utils/misc.py | 3 + embodichain/lab/scripts/run_agent.py | 4 +- embodichain/toolkits/interfaces.py | 585 +++++++++++++----- 19 files changed, 568 insertions(+), 371 deletions(-) diff --git a/embodichain/agents/hierarchy/code_agent.py b/embodichain/agents/hierarchy/code_agent.py index b528ba1..310c032 100644 --- a/embodichain/agents/hierarchy/code_agent.py +++ b/embodichain/agents/hierarchy/code_agent.py @@ -175,17 +175,17 @@ def generate(self, **kwargs): def act(self, code_file_path, **kwargs): """Execute generated code with proper execution environment. - + Supports two modes: 1. If code defines 'create_agent_action_list' function, call it 2. If code contains module-level drive() calls, execute them directly """ import ast - + # Read the generated code file with open(code_file_path, "r") as f: code_content = f.read() - + # Build execution namespace with necessary imports ns = { "__builtins__": __builtins__, @@ -193,7 +193,7 @@ def act(self, code_file_path, **kwargs): "__file__": str(code_file_path), "kwargs": kwargs, # Make kwargs available for injection } - + # Import toolkit functions into namespace try: exec( @@ -205,20 +205,21 @@ def act(self, code_file_path, **kwargs): raise RuntimeError( "Failed to import embodichain.toolkits.interfaces" ) from e - + # Parse code to check if it defines a function or contains module-level calls tree = ast.parse(code_content) - + # Check if code defines create_agent_action_list function has_function = any( - isinstance(node, ast.FunctionDef) and node.name == "create_agent_action_list" + isinstance(node, ast.FunctionDef) + and node.name == "create_agent_action_list" for node in tree.body ) - + if has_function: # Execute code (function will be defined in namespace) exec(code_content, ns, ns) - + # Call the function if it exists if "create_agent_action_list" in ns: result = ns["create_agent_action_list"](**kwargs) @@ -236,30 +237,36 @@ def visit_Call(self, node): self.generic_visit(node) # Inject **kwargs if not present has_kwargs = any( - kw.arg is None and isinstance(kw.value, ast.Name) and kw.value.id == "kwargs" + kw.arg is None + and isinstance(kw.value, ast.Name) + and kw.value.id == "kwargs" for kw in node.keywords ) if not has_kwargs: node.keywords.append( - ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load())) + ast.keyword( + arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()) + ) ) return node - + # Transform AST to inject kwargs tree = InjectKwargs().visit(tree) ast.fix_missing_locations(tree) - + # Compile and execute transformed code compiled_code = compile(tree, filename=str(code_file_path), mode="exec") exec(compiled_code, ns, ns) - + # Collect actions from drive() calls if they were executed # drive() function stores actions in env._episode_action_list if "env" in kwargs: env = kwargs["env"] if hasattr(env, "_episode_action_list") and env._episode_action_list: - print(f"Collected {len(env._episode_action_list)} actions from module-level drive() calls.") + print( + f"Collected {len(env._episode_action_list)} actions from module-level drive() calls." + ) return env._episode_action_list - + print("Code executed successfully, but no actions were collected.") return [] diff --git a/embodichain/agents/hierarchy/task_agent.py b/embodichain/agents/hierarchy/task_agent.py index 8e2e81b..11db647 100644 --- a/embodichain/agents/hierarchy/task_agent.py +++ b/embodichain/agents/hierarchy/task_agent.py @@ -137,4 +137,3 @@ def generate(self, **kwargs) -> str: def act(self, *args, **kwargs): return super().act(*args, **kwargs) - diff --git a/embodichain/agents/mllm/prompt/task_prompt.py b/embodichain/agents/mllm/prompt/task_prompt.py index de85b8d..3fa8ca9 100644 --- a/embodichain/agents/mllm/prompt/task_prompt.py +++ b/embodichain/agents/mllm/prompt/task_prompt.py @@ -22,7 +22,11 @@ def one_stage_prompt(observations, **kwargs): Step 2: LLM generates task instructions using only those IDs. """ # Encode image - observation = observations["rgb"].cpu().numpy() if isinstance(observations["rgb"], torch.Tensor) else observations["rgb"] + observation = ( + observations["rgb"].cpu().numpy() + if isinstance(observations["rgb"], torch.Tensor) + else observations["rgb"] + ) kwargs.update({"observation": encode_image(observation)}) # Build hybrid prompt @@ -89,7 +93,11 @@ def two_stage_prompt(observations, **kwargs): ] ) - observation = observations["rgb"].cpu().numpy() if isinstance(observations["rgb"], torch.Tensor) else observations["rgb"] + observation = ( + observations["rgb"].cpu().numpy() + if isinstance(observations["rgb"], torch.Tensor) + else observations["rgb"] + ) kwargs.update({"observation": encode_image(observation)}) # for LLM generate task descriptions prompt_query = ChatPromptTemplate.from_messages( diff --git a/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py index d4b3477..fdf3ad7 100644 --- a/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py +++ b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py @@ -433,9 +433,9 @@ def parse_sim_dict( "step_id": step_id, "instruction": "", "camera_used": camera_used, - "instruction": f["language_prompt"] - if f.get("language_prompt", None) - else "", + "instruction": ( + f["language_prompt"] if f.get("language_prompt", None) else "" + ), } assert ( @@ -465,26 +465,26 @@ def parse_sim_dict( camera_used=camera_used, ) if PrivilegeType.MASK.value in self.data_meta.get("privileges", []): - parse_dict[ - cam + "_{}".format(PrivilegeType.MASK.value) - ] = SimRealUnifiedDictDataset.parse_img( - f, - step_id, - first_idx, - cam, - self.img_history_size, - PrivilegeType.MASK.value, - camera_used=camera_used, + parse_dict[cam + "_{}".format(PrivilegeType.MASK.value)] = ( + SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + PrivilegeType.MASK.value, + camera_used=camera_used, + ) ) if PrivilegeType.EXTEROCEPTION.value in self.data_meta.get("privileges", []): if obs[PrivilegeType.EXTEROCEPTION.value][camera_used[0]].shape[0] != 0: - parse_dict[ - PrivilegeType.EXTEROCEPTION.value - ] = SimRealUnifiedDictDataset.parse_exteroception( - f, - step_id, - chunk_size, - camera_used=camera_used, + parse_dict[PrivilegeType.EXTEROCEPTION.value] = ( + SimRealUnifiedDictDataset.parse_exteroception( + f, + step_id, + chunk_size, + camera_used=camera_used, + ) ) if Modality.GEOMAP.value in self.data_meta.get("additional_modality", []): @@ -675,9 +675,9 @@ def realdata2simdata( "camera_used": [ cam_name for cam_name in given_camera_used if cam_name in camera_used ], - "instruction": f["language_prompt"] - if f.get("language_prompt", None) - else "", + "instruction": ( + f["language_prompt"] if f.get("language_prompt", None) else "" + ), } # save all supported proprio and action types. robot_meta_config = {"arm_dofs": self.arm_dofs, "observation": {}} diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index 920bb3a..b629a2f 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -33,6 +33,7 @@ class SemanticMask(IntEnum): FOREGROUND = 1 ROBOT = 2 + class Modality(Enum): STATES = "states" STATE_INDICATOR = "state_indicator" @@ -44,6 +45,7 @@ class Modality(Enum): GEOMAP = "geomap" # e.g., depth, point cloud, etc. VISION_LANGUAGE = "vision_language" # e.g., image + lang + class EndEffector(Enum): GRIPPER = "gripper" DEXTROUSHAND = "hand" @@ -62,6 +64,7 @@ class ControlParts(Enum): HEAD = "head" WAIST = "waist" + class ControlPartsMappingW1(Enum): ANKLE_IN_TORSO = 0 KNEE_IN_TORSO = 1 @@ -71,6 +74,7 @@ class ControlPartsMappingW1(Enum): NECK1_IN_HEAD = 0 NECK2_IN_HEAD = 1 + class Hints(Enum): EEF = ( ControlParts.LEFT_EEF.value, @@ -93,12 +97,14 @@ class ActionMode(Enum): ABSOLUTE = "" RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. + class PrivilegeType(Enum): EXTEROCEPTION = "exteroception" MASK = "mask" STATE = "state" PROGRESS = "progress" + SUPPORTED_PROPRIO_TYPES = [ ControlParts.LEFT_ARM.value + EefType.POSE.value, ControlParts.RIGHT_ARM.value + EefType.POSE.value, @@ -121,18 +127,22 @@ class PrivilegeType(Enum): PrivilegeType.MASK.value, ] + class ArmEnum(IntEnum): LEFT_ARM_ONLY = 1 RIGHT_ARM_ONLY = 2 DUAL_ARM = 3 + class ArmName(Enum): LEFT_ARM_ONLY = "left_arm" RIGHT_ARM_ONLY = "right_arm" + def is_dual_arms(dofs: int) -> bool: return dofs > 10 + class HandQposNormalizer: """ A class for normalizing and denormalizing dexterous hand qpos data. @@ -200,4 +210,4 @@ def normalize_hand_qpos( # Step 2: Normalize to [0, 1] qpos_normalized = (qpos_clipped - qpos_min) / (qpos_max - qpos_min + 1e-8) - return qpos_normalized \ No newline at end of file + return qpos_normalized diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py index c8599a4..36de8c1 100644 --- a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py @@ -1,53 +1,29 @@ # Step 1: Grasp the fork with the left arm and the spoon with the right arm drive( - left_arm_action=grasp( - robot_name='left_arm', - obj_name='fork', - pre_grasp_dis=0.10 - ), + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), right_arm_action=grasp( - robot_name='right_arm', - obj_name='spoon', - pre_grasp_dis=0.10 - ) + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), ) # Step 2: Reorient both end-effectors to a downward-facing pose drive( - left_arm_action=orient_eef( - robot_name='left_arm', - direction='down' - ), - right_arm_action=orient_eef( - robot_name='right_arm', - direction='down' - ) + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), ) # Step 3: Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center drive( left_arm_action=place_on_table( - robot_name='left_arm', - obj_name='fork', - x=0.0, - y=0.16, - pre_place_dis=0.08 + robot_name="left_arm", obj_name="fork", x=0.0, y=0.16, pre_place_dis=0.08 ), right_arm_action=place_on_table( - robot_name='right_arm', - obj_name='spoon', - x=0.0, - y=-0.16, - pre_place_dis=0.08 - ) + robot_name="right_arm", obj_name="spoon", x=0.0, y=-0.16, pre_place_dis=0.08 + ), ) # Step 4: Return both arms to their initial poses drive( - left_arm_action=back_to_initial_pose( - robot_name='left_arm' - ), - right_arm_action=back_to_initial_pose( - robot_name='right_arm' - ) -) \ No newline at end of file + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py index c4fd2fd..c94c82c 100644 --- a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py @@ -1,62 +1,34 @@ # Step 1 — Grasp the Fork and Spoon Simultaneously drive( - left_arm_action=grasp( - robot_name='left_arm', - obj_name='fork', - pre_grasp_dis=0.10 - ), + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), right_arm_action=grasp( - robot_name='right_arm', - obj_name='spoon', - pre_grasp_dis=0.10 - ) + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), ) # Step 2 — Reorient End-Effectors to Downward-Facing Pose drive( - left_arm_action=orient_eef( - robot_name='left_arm', - direction='down' - ), - right_arm_action=orient_eef( - robot_name='right_arm', - direction='down' - ) + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), ) # Step 3 — Place the Fork and Spoon on Opposite Sides of the Plate drive( left_arm_action=move_relative_to_object( - robot_name='left_arm', - obj_name='plate', - x_offset=0, - y_offset=0.16, - z_offset=0 + robot_name="left_arm", obj_name="plate", x_offset=0, y_offset=0.16, z_offset=0 ), right_arm_action=move_relative_to_object( - robot_name='right_arm', - obj_name='plate', - x_offset=0, - y_offset=-0.16, - z_offset=0 - ) + robot_name="right_arm", obj_name="plate", x_offset=0, y_offset=-0.16, z_offset=0 + ), ) drive( - left_arm_action=open_gripper( - robot_name='left_arm' - ), - right_arm_action=open_gripper( - robot_name='right_arm' - ) + left_arm_action=open_gripper(robot_name="left_arm"), + right_arm_action=open_gripper(robot_name="right_arm"), ) # Step 4 — Return Both Arms to Initial Pose drive( - left_arm_action=back_to_initial_pose( - robot_name='left_arm' - ), - right_arm_action=back_to_initial_pose( - robot_name='right_arm' - ) -) \ No newline at end of file + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py index c099786..661b61b 100644 --- a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py @@ -1,58 +1,34 @@ # Step 1 — Grasp the fork and spoon drive( - left_arm_action=grasp( - robot_name='left_arm', - obj_name='fork', - pre_grasp_dis=0.10 - ), + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), right_arm_action=grasp( - robot_name='right_arm', - obj_name='spoon', - pre_grasp_dis=0.10 - ) + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), ) # Step 2 — Reorient end-effectors to downward-facing pose drive( - left_arm_action=orient_eef( - robot_name='left_arm', - direction='down' - ), - right_arm_action=orient_eef( - robot_name='right_arm', - direction='down' - ) + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), ) # Step 3 — Place fork and spoon on opposite sides of the plate drive( left_arm_action=move_relative_to_object( - robot_name='left_arm', - obj_name='plate', - y_offset=0.16 + robot_name="left_arm", obj_name="plate", y_offset=0.16 ), right_arm_action=move_relative_to_object( - robot_name='right_arm', - obj_name='plate', - y_offset=-0.16 - ) + robot_name="right_arm", obj_name="plate", y_offset=-0.16 + ), ) drive( - left_arm_action=open_gripper( - robot_name='left_arm' - ), - right_arm_action=open_gripper( - robot_name='right_arm' - ) + left_arm_action=open_gripper(robot_name="left_arm"), + right_arm_action=open_gripper(robot_name="right_arm"), ) # Step 4 — Return arms to initial pose drive( - left_arm_action=back_to_initial_pose( - robot_name='left_arm' - ), - right_arm_action=back_to_initial_pose( - robot_name='right_arm' - ) -) \ No newline at end of file + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py index 65dc8b5..2664860 100644 --- a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py @@ -1,53 +1,29 @@ # Step 1 — Grasp the Fork and Spoon Simultaneously drive( - left_arm_action=grasp( - robot_name='left_arm', - obj_name='fork', - pre_grasp_dis=0.10 - ), + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), right_arm_action=grasp( - robot_name='right_arm', - obj_name='spoon', - pre_grasp_dis=0.10 - ) + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), ) # Step 2 — Reorient Both End-Effectors to a Downward-Facing Pose drive( - left_arm_action=orient_eef( - robot_name='left_arm', - direction='down' - ), - right_arm_action=orient_eef( - robot_name='right_arm', - direction='down' - ) + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), ) # Step 3 — Place the Fork and Spoon at Specified Positions drive( left_arm_action=place_on_table( - robot_name='left_arm', - obj_name='fork', - x=None, - y=0.16, - pre_place_dis=0.08 + robot_name="left_arm", obj_name="fork", x=None, y=0.16, pre_place_dis=0.08 ), right_arm_action=place_on_table( - robot_name='right_arm', - obj_name='spoon', - x=None, - y=-0.16, - pre_place_dis=0.08 - ) + robot_name="right_arm", obj_name="spoon", x=None, y=-0.16, pre_place_dis=0.08 + ), ) # Step 4 — Return Both Arms to Initial Pose drive( - left_arm_action=back_to_initial_pose( - robot_name='left_arm' - ), - right_arm_action=back_to_initial_pose( - robot_name='right_arm' - ) -) \ No newline at end of file + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/lab/gym/envs/action_bank/utils.py b/embodichain/lab/gym/envs/action_bank/utils.py index 8347dd0..1ca395f 100644 --- a/embodichain/lab/gym/envs/action_bank/utils.py +++ b/embodichain/lab/gym/envs/action_bank/utils.py @@ -28,6 +28,7 @@ def get_init_affordance(scope: str, tag: str = "init") -> str: return "{}_{}_qpos".format(scope, tag) + def get_control_part(env, agent_uid): from embodichain.lab.gym.utils.misc import _data_key_to_control_part @@ -42,7 +43,8 @@ def get_control_part(env, agent_uid): control_parts=control_parts, data_key=agent_uid, ) - + + def get_control_part_joint_ids(env, key: str) -> List[int]: from embodichain.data.enum import ( ControlParts, diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py index e950297..d1ea304 100644 --- a/embodichain/lab/gym/envs/managers/dataset_manager.py +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -91,7 +91,10 @@ def __init__(self, cfg: object, env: EmbodiedEnv): # This allows legacy code (like action_bank) to access robot_meta via env.metadata["dataset"]["robot_meta"] for mode_cfgs in self._mode_functor_cfgs.values(): for functor_cfg in mode_cfgs: - if "robot_meta" in functor_cfg.params or "instruction" in functor_cfg.params: + if ( + "robot_meta" in functor_cfg.params + or "instruction" in functor_cfg.params + ): if not hasattr(env, "metadata"): env.metadata = {} if "dataset" not in env.metadata: diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index e901f50..e9e1232 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -184,7 +184,7 @@ def generate_code_for_actions(self, regenerate=False, **kwargs): # Task planning print(f"\033[92m\nStart task planning.\n\033[0m") - + task_agent_input = self.task_agent.get_composed_observations( env=self, regenerate=regenerate, **kwargs ) @@ -195,8 +195,8 @@ def generate_code_for_actions(self, regenerate=False, **kwargs): code_agent_input = self.code_agent.get_composed_observations( env=self, regenerate=regenerate, **kwargs ) - code_agent_input['task_plan'] = task_plan - + code_agent_input["task_plan"] = task_plan + code_file_path, kwargs, code = self.code_agent.generate(**code_agent_input) return code_file_path, kwargs, code @@ -207,7 +207,7 @@ def create_demo_action_list(self, regenerate=False): ) action_list = self.code_agent.act(code_file_path, **kwargs) return action_list - + def to_dataset( self, id: str = None, @@ -225,7 +225,7 @@ def to_dataset( obs_list = getattr(self, "_episode_obs_list", []) if action_list is None: action_list = getattr(self, "_episode_action_list", []) - + if len(obs_list) == 0 or len(action_list) == 0: logger.log_warning("No episode data found. Returning empty dataset.") return { @@ -246,7 +246,11 @@ def to_dataset( from embodichain.lab.gym.utils.misc import camel_to_snake if not hasattr(self, "folder_name") or self.curr_episode == 0: - robot_class_name = self.robot.__class__.__name__ if hasattr(self, "robot") and self.robot is not None else "Robot" + robot_class_name = ( + self.robot.__class__.__name__ + if hasattr(self, "robot") and self.robot is not None + else "Robot" + ) self.folder_name = f"{camel_to_snake(self.__class__.__name__)}_{camel_to_snake(robot_class_name)}" if os.path.exists(os.path.join(dataset_path, self.folder_name)): self.folder_name = f"{self.folder_name}_{random.randint(0, 1000)}" @@ -254,4 +258,3 @@ def to_dataset( return fetch_imitation_dataset( self, obs_list, action_list, id, self.folder_name ) - diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py index 20963f2..c4b469b 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py @@ -75,4 +75,4 @@ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): obs, info = super().reset(seed=seed, options=options) super().get_states() - return obs, info \ No newline at end of file + return obs, info diff --git a/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py index 2d28b6b..4f14656 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py +++ b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py @@ -61,6 +61,7 @@ def is_task_success(self) -> bool: or abs(fork_y - fork_place_target_y) > tolerance ) + @register_env("RearrangementAgent-v3", max_episode_steps=600) class RearrangementAgentEnv3(BaseAgentEnv, RearrangementEnv3): def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): @@ -90,4 +91,7 @@ def is_task_success(self): tolerance = self.metadata.get("success_params", {}).get("tolerance", 0.02) # spoon and fork should with the y range of tolerance related to plate. - return (abs(spoon_y - spoon_place_target_y) <= tolerance or abs(fork_y - fork_place_target_y) <= tolerance) \ No newline at end of file + return ( + abs(spoon_y - spoon_place_target_y) <= tolerance + or abs(fork_y - fork_place_target_y) <= tolerance + ) diff --git a/embodichain/lab/gym/motion_generation/action/arm_action.py b/embodichain/lab/gym/motion_generation/action/arm_action.py index fc85e4f..bacc98b 100644 --- a/embodichain/lab/gym/motion_generation/action/arm_action.py +++ b/embodichain/lab/gym/motion_generation/action/arm_action.py @@ -226,9 +226,9 @@ def calculate_point_allocations( def create_qpos_dict(position: np.ndarray, dof: int) -> Dict: """Create qpos dictionary with zero velocity and acceleration""" return { - "position": position.tolist() - if isinstance(position, np.ndarray) - else position, + "position": ( + position.tolist() if isinstance(position, np.ndarray) else position + ), "velocity": [0.0] * dof, "acceleration": [0.0] * dof, } diff --git a/embodichain/lab/gym/motion_generation/planner/utils.py b/embodichain/lab/gym/motion_generation/planner/utils.py index 8e67195..b68a9ae 100644 --- a/embodichain/lab/gym/motion_generation/planner/utils.py +++ b/embodichain/lab/gym/motion_generation/planner/utils.py @@ -14,6 +14,7 @@ class TrajectorySampleMethod(Enum): This enum defines various methods for sampling trajectories, providing meaningful names for different sampling strategies. """ + TIME = "time" """Sample based on time intervals.""" diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py index 8972b6b..9f83949 100644 --- a/embodichain/lab/gym/utils/misc.py +++ b/embodichain/lab/gym/utils/misc.py @@ -756,12 +756,14 @@ def resolve_env_attr(obj: Any, env: Any) -> Any: _EXPR = re.compile(r"\$\{([^}]+)\}") # For searching ${...} marker + def is_binocularcam(sensor): from dexsim.sensor import BinocularCam from embodichain.lab.sim.sensors import StereoCamera return isinstance(sensor, BinocularCam) or isinstance(sensor, StereoCamera) + def resolve_formatted_string(obj, local_vars=None, global_vars=None): """Given a dict carrys "${...}"-like strings , `eval` the "${...}$" values while keep the dict structure. @@ -1388,6 +1390,7 @@ def is_stereocam(sensor) -> bool: return isinstance(sensor, StereoCamera) + def _data_key_to_control_part(robot, control_parts, data_key: str) -> Optional[str]: # TODO: Temporary workaround, should be removed after refactoring data dict extractor. # @lru_cache(max_size=None) # NOTE: no way to pass a hashable parameter diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py index a32117f..94a8ce9 100644 --- a/embodichain/lab/scripts/run_agent.py +++ b/embodichain/lab/scripts/run_agent.py @@ -86,7 +86,9 @@ def wait_for_threads(threads): num_samples = kwargs.get("num_samples", 0) is_save_dataset = time_id < num_samples - data_dict = env.get_wrapper_attr("to_dataset")(id=dataset_id if is_save_dataset else None) + data_dict = env.get_wrapper_attr("to_dataset")( + id=dataset_id if is_save_dataset else None + ) ret.append(data_dict) else: data_dict = env.get_wrapper_attr("to_dataset")(id=dataset_id) diff --git a/embodichain/toolkits/interfaces.py b/embodichain/toolkits/interfaces.py index 630493f..b91a008 100644 --- a/embodichain/toolkits/interfaces.py +++ b/embodichain/toolkits/interfaces.py @@ -4,7 +4,11 @@ from embodichain.toolkits.toolkits import ToolkitsBase from embodichain.utils.logger import log_info, log_warning, log_error from copy import deepcopy -from embodichain.lab.gym.utils.misc import mul_linear_expand, is_qpos_flip, get_rotation_replaced_pose +from embodichain.lab.gym.utils.misc import ( + mul_linear_expand, + is_qpos_flip, + get_rotation_replaced_pose, +) from embodichain.toolkits.graspkit.pg_grasp.antipodal import GraspSelectMethod from matplotlib import pyplot as plt import torch @@ -14,15 +18,19 @@ from scipy.spatial.transform import Rotation as R from embodichain.utils.utility import encode_image import ast -''' + +""" --------------------------------------------Some useful functions---------------------------------------------------- --------------------------------------------Some useful functions---------------------------------------------------- --------------------------------------------Some useful functions---------------------------------------------------- -''' +""" + + def draw_axis(env, pose): from embodichain.lab.sim.cfg import MarkerCfg + marker_cfg = MarkerCfg( - name='test', + name="test", marker_type="axis", axis_xpos=pose, axis_size=0.01, @@ -32,30 +40,60 @@ def draw_axis(env, pose): env.sim.draw_marker(cfg=marker_cfg) env.sim.update() + def get_arm_states(env, robot_name): left_arm_current_qpos, right_arm_current_qpos = env.get_current_qpos_agent() left_arm_current_pose, right_arm_current_pose = env.get_current_xpos_agent() - left_arm_current_gripper_state, right_arm_current_gripper_state = env.get_current_gripper_state_agent() + left_arm_current_gripper_state, right_arm_current_gripper_state = ( + env.get_current_gripper_state_agent() + ) side = "right" if "right" in robot_name else "left" is_left = True if side == "left" else False select_arm = "left_arm" if is_left else "right_arm" - arms = {"left": (left_arm_current_qpos, left_arm_current_pose, left_arm_current_gripper_state), - "right": (right_arm_current_qpos, right_arm_current_pose, right_arm_current_gripper_state)} - select_arm_current_qpos, select_arm_current_pose, select_arm_current_gripper_state = arms[side] + arms = { + "left": ( + left_arm_current_qpos, + left_arm_current_pose, + left_arm_current_gripper_state, + ), + "right": ( + right_arm_current_qpos, + right_arm_current_pose, + right_arm_current_gripper_state, + ), + } + ( + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = arms[side] + + return ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) - return is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, select_arm_current_gripper_state def find_nearest_valid_pose(env, select_arm, pose, xpos_resolution=0.1): # use the validator to choose the nearest valid pose # delete the cache every time if isinstance(pose, torch.Tensor): pose = pose.detach().cpu().numpy() - ret, _ = env.robot.compute_xpos_reachability(select_arm, pose, xpos_resolution=xpos_resolution, - qpos_resolution=np.radians(60), cache_mode="disk", use_cached=False, - visualize=False) + ret, _ = env.robot.compute_xpos_reachability( + select_arm, + pose, + xpos_resolution=xpos_resolution, + qpos_resolution=np.radians(60), + cache_mode="disk", + use_cached=False, + visualize=False, + ) ret = np.stack(ret, axis=0) # find the nearest valid pose xyz = pose[:3, 3] @@ -65,35 +103,33 @@ def find_nearest_valid_pose(env, select_arm, pose, xpos_resolution=0.1): nearest_valid_pose = ret[best_idx] return torch.from_numpy(nearest_valid_pose) -def get_qpos(env, is_left, select_arm, pose, qpos_seed, force_valid=False, name=''): + +def get_qpos(env, is_left, select_arm, pose, qpos_seed, force_valid=False, name=""): if force_valid: try: - ret, qpos = env.get_arm_ik( - pose, is_left=is_left, qpos_seed=qpos_seed - ) + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) if not ret: log_error(f"Generate {name} qpos failed.\n") except Exception as e: - log_warning(f"Original {name} pose invalid, using nearest valid pose. ({e})\n") + log_warning( + f"Original {name} pose invalid, using nearest valid pose. ({e})\n" + ) pose = find_nearest_valid_pose(env, select_arm, pose) - ret, qpos = env.get_arm_ik( - pose, is_left=is_left, qpos_seed=qpos_seed - ) + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) else: - ret, qpos = env.get_arm_ik( - pose, is_left=is_left, qpos_seed=qpos_seed - ) + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) if not ret: log_error(f"Generate {name} qpos failed.\n") return qpos + def get_offset_pose( - pose_to_change: torch.Tensor, - offset_value: float, - direction: str = "z", - mode: str = "intrinsic", + pose_to_change: torch.Tensor, + offset_value: float, + direction: str = "z", + mode: str = "intrinsic", ) -> torch.Tensor: device = pose_to_change.device @@ -126,8 +162,16 @@ def get_offset_pose( return offset_pose -def plan_trajectory(env, select_arm, qpos_list, sample_num, - select_arm_current_gripper_state, select_qpos_traj, ee_state_list_select): + +def plan_trajectory( + env, + select_arm, + qpos_list, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, +): traj_list, _ = ArmAction.create_discrete_trajectory( agent=env.robot, uid=select_arm, @@ -141,8 +185,16 @@ def plan_trajectory(env, select_arm, qpos_list, sample_num, select_qpos_traj.extend(traj_list) ee_state_list_select.extend([select_arm_current_gripper_state] * len(traj_list)) -def plan_gripper_trajectory(env, is_left, sample_num, execute_open, - select_arm_current_qpos, select_qpos_traj, ee_state_list_select): + +def plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select, +): open_state = env.open_state close_state = env.close_state @@ -154,15 +206,24 @@ def plan_gripper_trajectory(env, is_left, sample_num, execute_open, env.set_current_gripper_state_agent(close_state, is_left=is_left) ee_state_expand_select = mul_linear_expand(ee_state_expand_select, [sample_num]) - + select_qpos_traj.extend([select_arm_current_qpos] * sample_num) ee_state_list_select.extend(ee_state_expand_select) + def finalize_actions(select_qpos_traj, ee_state_list_select): # mimic eef state - actions = np.concatenate([np.array(select_qpos_traj), np.array(ee_state_list_select), np.array(ee_state_list_select)], axis=-1) + actions = np.concatenate( + [ + np.array(select_qpos_traj), + np.array(ee_state_list_select), + np.array(ee_state_list_select), + ], + axis=-1, + ) return actions + def extract_drive_calls(code_str: str) -> list[str]: tree = ast.parse(code_str) lines = code_str.splitlines() @@ -185,11 +246,14 @@ def extract_drive_calls(code_str: str) -> list[str]: return drive_blocks -''' + +""" --------------------------------------------Atom action functions---------------------------------------------------- --------------------------------------------Atom action functions---------------------------------------------------- --------------------------------------------Atom action functions---------------------------------------------------- -''' +""" + + # TODO: write a move_to_pose atom action, the use this action to form other atom actions def grasp( robot_name: str, @@ -197,7 +261,7 @@ def grasp( pre_grasp_dis: float = 0.05, env=None, force_valid=False, - **kwargs + **kwargs, ): # Get target object obj_uids = env.sim.get_rigid_object_uid_list() @@ -209,36 +273,85 @@ def grasp( # Open the gripper if currently closed actions = None - select_arm_current_gripper_state = env.left_arm_current_gripper_state if 'left' in robot_name else env.right_arm_current_gripper_state + select_arm_current_gripper_state = ( + env.left_arm_current_gripper_state + if "left" in robot_name + else env.right_arm_current_gripper_state + ) if select_arm_current_gripper_state <= env.open_state - 0.01: actions = open_gripper(robot_name, env, **kwargs) # Retract the end-effector to avoid collision - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) - select_arm_base_pose = env.left_arm_base_pose if is_left else env.right_arm_base_pose - base_to_eef_xy_dis = torch.norm(select_arm_base_pose[:2, 3] - select_arm_current_pose[:2, 3]) - base_to_obj_xy_dis = torch.norm(select_arm_base_pose[:2, 3] - target_obj_pose[:2, 3]) - dis_eps = kwargs.get('dis_eps', 0.05) - select_arm_init_pose = env.left_arm_init_xpos if is_left else env.right_arm_init_xpos - if base_to_eef_xy_dis > base_to_obj_xy_dis and not torch.allclose(select_arm_current_pose, select_arm_init_pose, rtol=1e-5, atol=1e-8): + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + select_arm_base_pose = ( + env.left_arm_base_pose if is_left else env.right_arm_base_pose + ) + base_to_eef_xy_dis = torch.norm( + select_arm_base_pose[:2, 3] - select_arm_current_pose[:2, 3] + ) + base_to_obj_xy_dis = torch.norm( + select_arm_base_pose[:2, 3] - target_obj_pose[:2, 3] + ) + dis_eps = kwargs.get("dis_eps", 0.05) + select_arm_init_pose = ( + env.left_arm_init_xpos if is_left else env.right_arm_init_xpos + ) + if base_to_eef_xy_dis > base_to_obj_xy_dis and not torch.allclose( + select_arm_current_pose, select_arm_init_pose, rtol=1e-5, atol=1e-8 + ): delta = base_to_eef_xy_dis - (base_to_obj_xy_dis - dis_eps) - back_actions = move_by_relative_offset(robot_name=robot_name, dx=0.0, dy=0.0, dz=-delta, env=env, force_valid=force_valid, mode="intrinsic", sample_num=15, **kwargs) - actions = np.concatenate([actions, back_actions], axis=0) if actions is not None else back_actions + back_actions = move_by_relative_offset( + robot_name=robot_name, + dx=0.0, + dy=0.0, + dz=-delta, + env=env, + force_valid=force_valid, + mode="intrinsic", + sample_num=15, + **kwargs, + ) + actions = ( + np.concatenate([actions, back_actions], axis=0) + if actions is not None + else back_actions + ) # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- # Move the end-effector to a good place for starting grasping to avoid bad poses - select_arm_retract_pose = deepcopy(env.left_arm_init_xpos if is_left else env.right_arm_init_xpos) - select_arm_retract_pose = get_offset_pose(select_arm_retract_pose, 0.15, "z", "intrinsic") - select_arm_retract_qpos = get_qpos(env, is_left, select_arm, select_arm_retract_pose, env.left_arm_init_qpos if is_left else env.right_arm_init_qpos, - force_valid=force_valid, name='retract_to_good_pose') + select_arm_retract_pose = deepcopy( + env.left_arm_init_xpos if is_left else env.right_arm_init_xpos + ) + select_arm_retract_pose = get_offset_pose( + select_arm_retract_pose, 0.15, "z", "intrinsic" + ) + select_arm_retract_qpos = get_qpos( + env, + is_left, + select_arm, + select_arm_retract_pose, + env.left_arm_init_qpos if is_left else env.right_arm_init_qpos, + force_valid=force_valid, + name="retract_to_good_pose", + ) qpos_list_back_to_retract = [select_arm_current_qpos, select_arm_retract_qpos] sample_num = 30 @@ -249,7 +362,7 @@ def grasp( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) select_arm_current_qpos = select_arm_retract_qpos @@ -264,18 +377,43 @@ def grasp( select_arm_aim_qpos[0] = aim_horizontal_angle # Get best grasp pose from affordance data - grasp_pose_object = env.init_obj_info.get(obj_name)['grasp_pose_obj'] - if grasp_pose_object[0, 2] > 0.5: # whether towards x direction TODO: make it robust + grasp_pose_object = env.init_obj_info.get(obj_name)["grasp_pose_obj"] + if ( + grasp_pose_object[0, 2] > 0.5 + ): # whether towards x direction TODO: make it robust # Align the object pose's z-axis with the arm's aiming direction - target_obj_pose = torch.tensor(get_rotation_replaced_pose(np.array(target_obj_pose), float(select_arm_aim_qpos[0]), "z","intrinsic")) + target_obj_pose = torch.tensor( + get_rotation_replaced_pose( + np.array(target_obj_pose), + float(select_arm_aim_qpos[0]), + "z", + "intrinsic", + ) + ) best_pickpose = target_obj_pose @ grasp_pose_object grasp_pose = deepcopy(best_pickpose) grasp_pose_pre1 = deepcopy(grasp_pose) - grasp_pose_pre1 = get_offset_pose(grasp_pose_pre1, -pre_grasp_dis, "z", 'intrinsic') + grasp_pose_pre1 = get_offset_pose(grasp_pose_pre1, -pre_grasp_dis, "z", "intrinsic") # Solve IK for pre-grasp and grasp poses - grasp_qpos_pre1 = get_qpos(env, is_left, select_arm, grasp_pose_pre1, select_arm_aim_qpos, force_valid=force_valid, name='grasp pre1') - grasp_qpos = get_qpos(env, is_left, select_arm, grasp_pose, grasp_qpos_pre1, force_valid=force_valid, name='grasp') + grasp_qpos_pre1 = get_qpos( + env, + is_left, + select_arm, + grasp_pose_pre1, + select_arm_aim_qpos, + force_valid=force_valid, + name="grasp pre1", + ) + grasp_qpos = get_qpos( + env, + is_left, + select_arm, + grasp_pose, + grasp_qpos_pre1, + force_valid=force_valid, + name="grasp", + ) # Update env state to final grasp pose env.set_current_qpos_agent(grasp_qpos, is_left=is_left) @@ -296,7 +434,7 @@ def grasp( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ------------------------------------ Traj 1: aim → pre-grasp ------------------------------------ @@ -310,7 +448,7 @@ def grasp( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ------------------------------------ Traj 2: pre-grasp → grasp ------------------------------------ @@ -324,21 +462,28 @@ def grasp( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- traj_actions = finalize_actions(select_qpos_traj, ee_state_list_select) - actions = traj_actions if actions is None else np.concatenate([actions, traj_actions], axis=0) + actions = ( + traj_actions + if actions is None + else np.concatenate([actions, traj_actions], axis=0) + ) # ------------------------------------ Close gripper ------------------------------------ close_gripper_actions = close_gripper(robot_name, env, **kwargs) actions = np.concatenate([actions, close_gripper_actions], axis=0) - log_info(f"Total generated trajectory number for grasp: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for grasp: {len(actions)}.", color="green" + ) return actions + # def place_on_table( # robot_name: str, # obj_name: str, @@ -422,6 +567,7 @@ def grasp( # # return actions + def place_on_table( robot_name: str, obj_name: str, @@ -430,21 +576,27 @@ def place_on_table( pre_place_dis: float = 0.08, env=None, force_valid=False, - **kwargs + **kwargs, ): - init_obj_height = env.init_obj_info.get(obj_name).get('height') - height = init_obj_height + kwargs.get('eps', 0.03) + init_obj_height = env.init_obj_info.get(obj_name).get("height") + height = init_obj_height + kwargs.get("eps", 0.03) - traj_actions = move_to_absolute_position(robot_name, x=x, y=y, z=height, env=env, force_valid=force_valid, **kwargs) + traj_actions = move_to_absolute_position( + robot_name, x=x, y=y, z=height, env=env, force_valid=force_valid, **kwargs + ) open_actions = open_gripper(robot_name, env, **kwargs) actions = np.concatenate([traj_actions, open_actions], axis=0) - log_info(f"Total generated trajectory number for place on table: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for place on table: {len(actions)}.", + color="green", + ) return actions + def move_relative_to_object( robot_name: str, obj_name: str, @@ -453,15 +605,20 @@ def move_relative_to_object( z_offset: float = 0, env=None, force_valid=False, - **kwargs + **kwargs, ): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- # Resolve target object @@ -482,7 +639,15 @@ def move_relative_to_object( move_target_pose[2, 3] += z_offset # Solve IK for target pose - move_target_qpos = get_qpos(env, is_left, select_arm, move_target_pose, select_arm_current_qpos, force_valid=force_valid, name='move relative to object') + move_target_qpos = get_qpos( + env, + is_left, + select_arm, + move_target_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move relative to object", + ) # Update env states env.set_current_qpos_agent(move_target_qpos, is_left=is_left) @@ -499,16 +664,20 @@ def move_relative_to_object( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for move relative to object: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for move relative to object: {len(actions)}.", + color="green", + ) return actions + def move_to_absolute_position( robot_name: str, x: float = None, @@ -516,15 +685,20 @@ def move_to_absolute_position( z: float = None, env=None, force_valid=False, - **kwargs + **kwargs, ): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- # Start from current pose, then selectively update xyz @@ -543,7 +717,15 @@ def move_to_absolute_position( move_pose[:3, 3] = target_xyz # Try IK on target pose - move_qpos = get_qpos(env, is_left, select_arm, move_pose, select_arm_current_qpos, force_valid=force_valid, name='move to absolute position') + move_qpos = get_qpos( + env, + is_left, + select_arm, + move_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move to absolute position", + ) # Update env states env.set_current_qpos_agent(move_qpos, is_left=is_left) @@ -560,33 +742,42 @@ def move_to_absolute_position( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for move to absolute position: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for move to absolute position: {len(actions)}.", + color="green", + ) return actions + def move_by_relative_offset( robot_name: str, dx: float = 0.0, dy: float = 0.0, dz: float = 0.0, - mode: str = 'extrinsic', + mode: str = "extrinsic", env=None, force_valid=False, - **kwargs + **kwargs, ): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- move_pose = deepcopy(select_arm_current_pose) @@ -597,7 +788,15 @@ def move_by_relative_offset( move_pose = get_offset_pose(move_pose, dz, "z", mode) # Solve IK - move_qpos = get_qpos(env, is_left, select_arm, move_pose, select_arm_current_qpos, force_valid=force_valid, name='move by relative offset') + move_qpos = get_qpos( + env, + is_left, + select_arm, + move_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move by relative offset", + ) # Update environment states env.set_current_qpos_agent(move_qpos, is_left=is_left) @@ -614,29 +813,34 @@ def move_by_relative_offset( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for move by relative offset: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for move by relative offset: {len(actions)}.", + color="green", + ) return actions -def back_to_initial_pose( - robot_name: str, - env=None, - **kwargs -): + +def back_to_initial_pose(robot_name: str, env=None, **kwargs): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] # Get arm states - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # Retrieve the initial joint configuration of this arm target_qpos = env.left_arm_init_qpos if is_left else env.right_arm_init_qpos @@ -648,7 +852,15 @@ def back_to_initial_pose( pre_back_pose = get_offset_pose(pre_back_pose, -0.08, "z", "intrinsic") # IK for pre-back - pre_back_qpos = get_qpos(env, is_left, select_arm, pre_back_pose, select_arm_current_qpos, force_valid=kwargs.get("force_valid", False), name="pre back pose") + pre_back_qpos = get_qpos( + env, + is_left, + select_arm, + pre_back_pose, + select_arm_current_qpos, + force_valid=kwargs.get("force_valid", False), + name="pre back pose", + ) # Update env states (move to target pose) target_pose = env.get_arm_fk(qpos=target_qpos, is_left=is_left) @@ -666,7 +878,7 @@ def back_to_initial_pose( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ------------------------------------ Traj: init → initial_pose ------------------------------------ @@ -680,32 +892,33 @@ def back_to_initial_pose( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- - actions = finalize_actions( - select_qpos_traj, - ee_state_list_select - ) + actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for back to initial pose: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for back to initial pose: {len(actions)}.", + color="green", + ) return actions -def rotate_eef( - robot_name: str, - degree: float = 0, - env=None, - **kwargs -): + +def rotate_eef(robot_name: str, degree: float = 0, env=None, **kwargs): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- # Compute new joint positions @@ -738,22 +951,26 @@ def rotate_eef( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for rotate eef: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for rotate eef: {len(actions)}.", + color="green", + ) return actions + def orient_eef( robot_name: str, - direction: str = 'front', # 'front' or 'down' + direction: str = "front", # 'front' or 'down' env=None, force_valid=False, - **kwargs + **kwargs, ): # ---------------------------------------- Prepare ---------------------------------------- @@ -761,18 +978,27 @@ def orient_eef( ee_state_list_select = [] # Get arm state - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Pose ---------------------------------------- # Generate replacement rotation matrix replaced_rotation_matrix = np.eye(4) - if direction == 'front': + if direction == "front": rotation_matrix = R.from_euler("xyz", [180, -90, 0], degrees=True).as_matrix() - replaced_rotation_matrix[:3, :3] = rotation_matrix @ replaced_rotation_matrix[:3, :3] - elif direction == 'down': + replaced_rotation_matrix[:3, :3] = ( + rotation_matrix @ replaced_rotation_matrix[:3, :3] + ) + elif direction == "down": rotation_matrix = R.from_euler("x", 180, degrees=True).as_matrix() - replaced_rotation_matrix[:3, :3] = rotation_matrix @ replaced_rotation_matrix[:3, :3] + replaced_rotation_matrix[:3, :3] = ( + rotation_matrix @ replaced_rotation_matrix[:3, :3] + ) else: log_error("Rotation direction must be 'front' or 'down'.") @@ -780,12 +1006,20 @@ def orient_eef( rot_torch = torch.as_tensor( replaced_rotation_matrix[:3, :3], dtype=rotation_replaced_pose.dtype, - device=rotation_replaced_pose.device + device=rotation_replaced_pose.device, ) rotation_replaced_pose[:3, :3] = rot_torch # Solve IK for the new pose - replace_target_qpos = get_qpos(env, is_left, select_arm, rotation_replaced_pose, select_arm_current_qpos, force_valid=force_valid, name='replaced-rotation') + replace_target_qpos = get_qpos( + env, + is_left, + select_arm, + rotation_replaced_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="replaced-rotation", + ) # ---------------------------------------- Update env ---------------------------------------- env.set_current_qpos_agent(replace_target_qpos, is_left=is_left) @@ -802,31 +1036,33 @@ def orient_eef( sample_num, select_arm_current_gripper_state, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- - actions = finalize_actions( - select_qpos_traj, - ee_state_list_select - ) + actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for orient eef: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for orient eef: {len(actions)}.", + color="green", + ) return actions -def close_gripper( - robot_name: str, - env=None, - **kwargs -): + +def close_gripper(robot_name: str, env=None, **kwargs): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Traj ---------------------------------------- sample_num = kwargs.get("sample_num", 15) @@ -839,31 +1075,33 @@ def close_gripper( execute_open, select_arm_current_qpos, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- - actions = finalize_actions( - select_qpos_traj, - ee_state_list_select - ) + actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for close gripper: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for close gripper: {len(actions)}.", + color="green", + ) return actions -def open_gripper( - robot_name: str, - env=None, - **kwargs -): + +def open_gripper(robot_name: str, env=None, **kwargs): # ---------------------------------------- Prepare ---------------------------------------- select_qpos_traj = [] ee_state_list_select = [] - is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ - select_arm_current_gripper_state = get_arm_states(env, robot_name) + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) # ---------------------------------------- Traj ---------------------------------------- sample_num = kwargs.get("sample_num", 15) @@ -876,22 +1114,23 @@ def open_gripper( execute_open, select_arm_current_qpos, select_qpos_traj, - ee_state_list_select + ee_state_list_select, ) # ---------------------------------------- Final ---------------------------------------- - actions = finalize_actions( - select_qpos_traj, - ee_state_list_select - ) + actions = finalize_actions(select_qpos_traj, ee_state_list_select) - log_info(f"Total generated trajectory number for open gripper: {len(actions)}.", color="green") + log_info( + f"Total generated trajectory number for open gripper: {len(actions)}.", + color="green", + ) return actions + def drive( - left_arm_action = None, - right_arm_action = None, + left_arm_action=None, + right_arm_action=None, env=None, **kwargs, ): @@ -918,20 +1157,34 @@ def drive( elif left_arm_action is None and right_arm_action is not None: left_arm_index = env.left_arm_joints + env.left_eef_joints right_arm_index = env.right_arm_joints + env.right_eef_joints - left_arm_action = finalize_actions(env.left_arm_current_qpos, env.left_arm_current_gripper_state) - left_arm_action = np.repeat(left_arm_action[None, :], len(right_arm_action), axis=0) + left_arm_action = finalize_actions( + env.left_arm_current_qpos, env.left_arm_current_gripper_state + ) + left_arm_action = np.repeat( + left_arm_action[None, :], len(right_arm_action), axis=0 + ) - actions = np.zeros((len(right_arm_action), len(env.robot.get_qpos().squeeze(0))), dtype=np.float32) + actions = np.zeros( + (len(right_arm_action), len(env.robot.get_qpos().squeeze(0))), + dtype=np.float32, + ) actions[:, left_arm_index] = left_arm_action actions[:, right_arm_index] = right_arm_action elif right_arm_action is None and left_arm_action is not None: left_arm_index = env.left_arm_joints + env.left_eef_joints right_arm_index = env.right_arm_joints + env.right_eef_joints - right_arm_action = finalize_actions(env.right_arm_current_qpos, env.right_arm_current_gripper_state) - right_arm_action = np.repeat(right_arm_action[None, :], len(left_arm_action), axis=0) + right_arm_action = finalize_actions( + env.right_arm_current_qpos, env.right_arm_current_gripper_state + ) + right_arm_action = np.repeat( + right_arm_action[None, :], len(left_arm_action), axis=0 + ) - actions = np.zeros((len(left_arm_action), len(env.robot.get_qpos().squeeze(0))), dtype=np.float32) + actions = np.zeros( + (len(left_arm_action), len(env.robot.get_qpos().squeeze(0))), + dtype=np.float32, + ) actions[:, left_arm_index] = left_arm_action actions[:, right_arm_index] = right_arm_action @@ -945,6 +1198,7 @@ def drive( obs, reward, terminated, truncated, info = env.step(action) return actions + def save_observations( step_id: int = 0, step_name: str = None, @@ -968,6 +1222,7 @@ def save_observations( # Decode Base64 back to raw image bytes import base64 + img_bytes = base64.b64decode(base64_image) # Ensure step_name is not None @@ -983,4 +1238,4 @@ def save_observations( # When only running the script (no feedback script) else: - pass \ No newline at end of file + pass From 037c97bf95ff7cb34771d933ac6c003d60cf1d3e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 21 Jan 2026 11:30:24 +0800 Subject: [PATCH 48/49] Fix: docs build without API keys --- docs/source/conf.py | 2 ++ embodichain/agents/hierarchy/llm.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3aac3af..b260b6d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,6 +9,8 @@ import os import sys +os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://mock-endpoint.openai.azure.com/") +os.environ.setdefault("AZURE_OPENAI_API_KEY", "mock-api-key-for-docs-build") sys.path.insert(0, os.path.abspath("../..")) diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py index 67c3c35..f8e2262 100644 --- a/embodichain/agents/hierarchy/llm.py +++ b/embodichain/agents/hierarchy/llm.py @@ -41,7 +41,14 @@ def create_llm(*, temperature=0.0, model="gpt-4o"): # LLM instances # ------------------------------------------------------------------------------ -task_llm = create_llm(temperature=0.0, model="gpt-4o") -code_llm = create_llm(temperature=0.0, model="gpt-4o") -validation_llm = create_llm(temperature=0.0, model="gpt-4o") -view_selection_llm = create_llm(temperature=0.0, model="gpt-4o") +# Initialize LLM instances, but handle errors gracefully for documentation builds +def _create_llm_safe(*, temperature=0.0, model="gpt-4o"): + try: + return create_llm(temperature=temperature, model=model) + except Exception: + return None + +task_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +code_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +validation_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +view_selection_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") From ba889e340c1ce4c5308aeef795f9c8859c873e2f Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 21 Jan 2026 11:36:58 +0800 Subject: [PATCH 49/49] Reformat files --- docs/source/conf.py | 4 +++- embodichain/agents/hierarchy/llm.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b260b6d..6e392b5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,9 @@ import os import sys -os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://mock-endpoint.openai.azure.com/") +os.environ.setdefault( + "AZURE_OPENAI_ENDPOINT", "https://mock-endpoint.openai.azure.com/" +) os.environ.setdefault("AZURE_OPENAI_API_KEY", "mock-api-key-for-docs-build") sys.path.insert(0, os.path.abspath("../..")) diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py index f8e2262..b86169d 100644 --- a/embodichain/agents/hierarchy/llm.py +++ b/embodichain/agents/hierarchy/llm.py @@ -41,6 +41,7 @@ def create_llm(*, temperature=0.0, model="gpt-4o"): # LLM instances # ------------------------------------------------------------------------------ + # Initialize LLM instances, but handle errors gracefully for documentation builds def _create_llm_safe(*, temperature=0.0, model="gpt-4o"): try: @@ -48,6 +49,7 @@ def _create_llm_safe(*, temperature=0.0, model="gpt-4o"): except Exception: return None + task_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") code_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") validation_llm = _create_llm_safe(temperature=0.0, model="gpt-4o")