diff --git a/.gitignore b/.gitignore index cbe06491..040955d9 100644 --- a/.gitignore +++ b/.gitignore @@ -179,7 +179,8 @@ materials embodichain/toolkits/outputs/* embodichain/toolkits/outputs/* -embodichain/database/* +embodichain/database/tmp/* +embodichain/database/train/* 3rdparty/ Log/ diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config.json b/configs/gym/agent/pour_water_agent_v3/agent_config.json new file mode 100644 index 00000000..db38818a --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/agent_config.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "PourWaterAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} \ No newline at end of file diff --git a/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json new file mode 100644 index 00000000..415b7ac7 --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/agent_config_dual.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "DualPourWaterAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} \ No newline at end of file diff --git a/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json new file mode 100644 index 00000000..10a3d9cf --- /dev/null +++ b/configs/gym/agent/pour_water_agent_v3/fast_gym_config.json @@ -0,0 +1,392 @@ +{ + "id": "PourWaterAgent-v3", + "max_episodes": 5, + "env": { + "events": { + "init_table_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "table"}, + "position_range": [[0.0, 0.0, -0.04], [0.0, 0.0, 0.04]], + "relative_position": true + } + }, + "init_bottle_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "bottle"}, + "position_range": [[-0.08, -0.12, 0.0], [0.08, 0.04, 0.0]], + "relative_position": true + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cup"}, + "position_range": [[-0.08, -0.04, 0.0], [0.04, 0.06, 0.0]], + "relative_position": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "bottle" + }, + "value": [[ + [0.32243, 0.03245, 0.94604, 0.025], + [0.00706, -0.99947, 0.03188, -0.0 ], + [0.94657, -0.0036 , -0.32249, 0.0 ], + [0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "cup" + }, + "value": [[ + [ 0.32039, -0.03227, 0.94673, 0.0 ], + [ 0.00675, -0.99932, -0.03635, 0.0 ], + [ 0.94726, 0.01803, -0.31996, 0.0 ], + [ 0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "left_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "left_arm_base", + "to_matrix": true + } + }, + { + "name": "right_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "right_arm_base", + "to_matrix": true + } + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "bottle" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "cup" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "cup", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "bottle", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_table_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/WoodTable" + } + }, + "random_bottle_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "bottle"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Bottle" + } + }, + "random_cup_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "cup"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Cup" + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + } + }, + "dataset": { + "lerobot": { + "func": "embodichain.lab.gym.envs.managers.datasets:LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 5 + }, + "instruction": { + "lang": "Pour water from the bottle into the mug." + } + } + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": false, + "enable_depth": false, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.7186357573287919, -0.054534732904795505, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "right_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "left_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_1", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [1.7, 0, 2.3], + "target": [0.6, 0, 0.8] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_2", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.8], + "target": [0.7, 0, 0.9] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_3", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.3], + "target": [0.7, 0, 0.9] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 100.0, + "init_pos": [2, 0, 2], + "radius": 20.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid":"cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, 0.1, 0.9], + "body_scale":[0.75, 0.75, 1.0], + "max_convex_hull_num": 8 + }, + { + "uid":"bottle", + "shape": { + "shape_type": "Mesh", + "fpath": "ScannedBottle/kashijia_processed.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, -0.1, 0.932], + "body_scale":[1, 1, 1], + "max_convex_hull_num": 8 + } + ] +} \ No newline at end of file diff --git a/configs/gym/agent/rearrangement_agent_v3/agent_config.json b/configs/gym/agent/rearrangement_agent_v3/agent_config.json new file mode 100644 index 00000000..1907e987 --- /dev/null +++ b/configs/gym/agent/rearrangement_agent_v3/agent_config.json @@ -0,0 +1,31 @@ +{ "TaskAgent": { + "prompt_name": "one_stage_prompt" + }, + "CodeAgent": { + "prompt_name": "one_stage_prompt_according_to_task_plan" + }, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "RearrangementAgent-v3/task_prompt.txt" + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt" + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt" + }, + "code_prompt": { + "type": "text", + "name": "code_prompt.txt" + }, + "code_example": { + "type": "text", + "name": "code_example.txt" + } + } + } +} diff --git a/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json new file mode 100644 index 00000000..5cc6daad --- /dev/null +++ b/configs/gym/agent/rearrangement_agent_v3/fast_gym_config.json @@ -0,0 +1,384 @@ +{ + "id": "RearrangementAgent-v3", + "max_episodes": 10, + "env": { + "events": { + "init_table_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "table"}, + "position_range": [[0.0, 0.0, -0.04], [0.0, 0.0, 0.04]], + "relative_position": true + } + }, + "init_fork_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "fork"}, + "position_range": [[0.0, -0.05, 0.0], [0.1, 0.05, 0.0]], + "rotation_range": [[0, 0, -45], [0, 0, 45]], + "relative_position": true, + "relative_rotation": true + } + }, + "init_spoon_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "spoon"}, + "position_range": [[0.0, -0.05, 0.0], [0.1, 0.05, 0.0]], + "rotation_range": [[0, 0, -45], [0, 0, 45]], + "relative_position": true, + "relative_rotation": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_uids": ["fork", "spoon"], + "value": [[ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ]] + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "plate" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "fork" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "spoon" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "spoon", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "fork", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_table_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/WoodTable" + } + }, + "random_plate_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "plate"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Plate" + } + }, + "random_fork_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "fork"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Spoon" + } + }, + "random_spoon_material": { + "func": "randomize_visual_material", + "mode": "reset", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "spoon"}, + "random_texture_prob": 1.0, + "texture_path": "DexsimMaterials/Spoon" + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + } + }, + "dataset": { + "lerobot": { + "func": "embodichain.lab.gym.envs.managers.datasets:LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": {}, + "states": ["qpos"] + }, + "min_len_steps": 125 + }, + "instruction": { + "lang": "Place the spoon and fork neatly into the plate on the table." + } + } + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": false, + "enable_depth": false, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.7186357573287919, -0.054534732904795505, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "right_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 480, + "enable_mask": false, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "left_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_1", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [1.7, 0, 2.3], + "target": [0.6, 0, 0.8] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_2", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.8], + "target": [0.7, 0, 0.9] + } + }, + { + "sensor_type": "Camera", + "uid": "valid_cam_3", + "width": 1280, + "height": 960, + "enable_mask": false, + "intrinsics": [1400, 1400, 640, 480], + "extrinsics": { + "eye": [2.0, 0, 1.3], + "target": [0.7, 0, 0.9] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 100.0, + "init_pos": [2, 0, 2], + "radius": 20.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply", + "compute_uv": true + }, + "attrs" : { + "mass": 10.0, + "static_friction": 2.0, + "dynamic_friction": 1.5, + "restitution": 0.1 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.691], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid": "plate", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/plate/3_center.ply", + "compute_uv": true + }, + "body_scale": [0.001, 0.001, 0.001], + "init_pos": [0.5, 0.0, 1.0], + "init_rot": [180, 0, 0] + }, + { + "uid": "fork", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/fork/standard_fork_scale.ply", + "compute_uv": true + }, + "body_scale": [1.0, 1.0, 1.0], + "init_pos": [0.5, 0.21, 1.0], + "attrs" : { + "mass": 0.01, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.05, + "min_position_iters": 16, + "min_velocity_iters": 4 + } + }, + { + "uid": "spoon", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/spoon/standard_spoon_a_rescale.ply", + "compute_uv": true + }, + "body_scale": [1.0, 1.0, 1.0], + "init_pos": [0.5, -0.21, 1.0], + "attrs" : { + "mass": 0.01, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.05, + "min_position_iters": 16, + "min_velocity_iters": 4 + } + } + ] +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 3aac3af3..6e392b52 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,6 +9,10 @@ import os import sys +os.environ.setdefault( + "AZURE_OPENAI_ENDPOINT", "https://mock-endpoint.openai.azure.com/" +) +os.environ.setdefault("AZURE_OPENAI_API_KEY", "mock-api-key-for-docs-build") sys.path.insert(0, os.path.abspath("../..")) diff --git a/embodichain/agents/README.md b/embodichain/agents/README.md new file mode 100644 index 00000000..b6ca144e --- /dev/null +++ b/embodichain/agents/README.md @@ -0,0 +1,97 @@ +# EmbodiAgent System + + +## Quick Start + +### 1. Prerequisites +Ensure you have access to Azure OpenAI or a compatible LLM endpoint. + +```bash +# Set environment variables +export AZURE_OPENAI_ENDPOINT="[https://your-endpoint.openai.azure.com/](https://your-endpoint.openai.azure.com/)" +export AZURE_OPENAI_API_KEY="your-api-key" +``` + + +### 2. Run the System + +```bash +python embodichain/lab/scripts/run_agent.py \ + --task_name YourTask \ + --gym_config configs/gym/your_task/gym_config.json \ + --agent_config configs/gym/agent/your_agent/agent_config.json \ + --regenerate False +``` + + + +## System Architecture + +The system operates on a closed-loop control cycle: + +1. **Observe**: The `TaskAgent` perceives the environment via multi-view camera inputs. +2. **Plan**: It decomposes the goal into natural language steps. +3. **Code**: The `CodeAgent` translates steps into executable Python code using atomic actions. +4. **Execute**: The code runs in the environment; runtime errors are caught immediately. +5. **Validate**: The `ValidationAgent` analyzes the result images, selects the best camera angle, and judges success. +6. **Refine**: If validation fails, feedback is sent back to the agents to regenerate the plan or code. + + +--- + +## Core Components + +### 1. TaskAgent ("The Planner") +*Located in:* `embodichain/agents/hierarchy/task_agent.py` + +Responsible for high-level reasoning. It parses visual observations and outputs a structured plan. + +* For every step, it generates a specific condition (e.g., "The cup must be held by the gripper") which is used later by the ValidationAgent. +* Prompt Strategies: + * `one_stage_prompt`: Direct VLM-to-Plan generation. + * `two_stage_prompt`: Separates visual analysis from planning logic. + +### 2. CodeAgent ("The Coder") +*Located in:* `embodichain/agents/hierarchy/code_agent.py` + +Translates natural language plans into executable Python code. + + +### 3. ValidationAgent ("The Judger") +*Located in:* `embodichain/agents/hierarchy/validation_agent.py` + +Closes the loop by verifying if the robot actually achieved what it planned. + +* Uses a specialized LLM call (`select_best_view_dir`) to analyze images from all cameras and pick the single best angle that proves the action's outcome, ignoring irrelevant views. +* If an error occurs (runtime or logic), it generates a detailed explanation which is fed back to the `TaskAgent` or `CodeAgent` for the next attempt. + +--- + +## Configuration Guide + +The `Agent` configuration block controls the context provided to the LLMs. All files are resolved relative to `embodichain/database/agent_prompt/`. + +| Parameter | Description | Typical Use | +| :--- | :--- | :--- | +| `task_prompt` | Task-specific goal description | "Pour water from the red cup to the blue cup." | +| `basic_background` | Physical rules & constraints | World coordinate system definitions, safety rules. | +| `atom_actions` | API Documentation | List of available functions (e.g., `drive(action='pick', ...)`). | +| `code_prompt` | Coding guidelines | "Use provided APIs only. Do not use loops." | +| `code_example` | Few-shot examples | Previous successful code snippets to guide style. | + +--- + +## File Structure + +```text +embodichain/agents/ +├── hierarchy/ +│ ├── agent_base.py # Abstract base handling prompts & images +│ ├── task_agent.py # Plan generation logic +│ ├── code_agent.py # Code generation & AST execution engine +│ ├── validation_agent.py # Visual analysis & view selection +│ └── llm.py # LLM configuration and instances +├── mllm/ +│ └── prompt/ # Prompt templates (LangChain) +└── README.md # This file +``` diff --git a/embodichain/agents/hierarchy/__init__.py b/embodichain/agents/hierarchy/__init__.py new file mode 100644 index 00000000..d56fc165 --- /dev/null +++ b/embodichain/agents/hierarchy/__init__.py @@ -0,0 +1,3 @@ +from langchain_openai import AzureChatOpenAI +from langchain_openai import ChatOpenAI +import os diff --git a/embodichain/agents/hierarchy/agent_base.py b/embodichain/agents/hierarchy/agent_base.py new file mode 100644 index 00000000..75b4dae6 --- /dev/null +++ b/embodichain/agents/hierarchy/agent_base.py @@ -0,0 +1,41 @@ +from abc import ABCMeta, abstractmethod +import os +import cv2 +from embodichain.utils.utility import load_json, load_txt +from embodichain.agents.mllm.prompt import * +from embodichain.data import database_agent_prompt_dir, database_2d_dir +from embodichain.utils.utility import encode_image + + +class AgentBase(metaclass=ABCMeta): + def __init__(self, **kwargs) -> None: + + assert ( + "prompt_kwargs" in kwargs.keys() + ), "Key prompt_kwargs must exist in config." + + for key, value in kwargs.items(): + setattr(self, key, value) + + # Preload and store prompt contents inside self.prompt_kwargs + for key, val in self.prompt_kwargs.items(): + if val["type"] == "text": + file_path = os.path.join(database_agent_prompt_dir, val["name"]) + val["content"] = load_txt(file_path) # ← store content here + else: + raise ValueError( + f"Now only support `text` type but {val['type']} is given." + ) + + def generate(self, *args, **kwargs): + pass + + def act(self, *args, **kwargs): + pass + + def get_composed_observations(self, **kwargs): + ret = {"observations": kwargs.get("env").get_obs_for_agent()} + for key, val in self.prompt_kwargs.items(): + ret[key] = val["content"] + ret.update(kwargs) + return ret diff --git a/embodichain/agents/hierarchy/code_agent.py b/embodichain/agents/hierarchy/code_agent.py new file mode 100644 index 00000000..310c0322 --- /dev/null +++ b/embodichain/agents/hierarchy/code_agent.py @@ -0,0 +1,272 @@ +from embodichain.agents.hierarchy.agent_base import AgentBase +from langchain_core.prompts import ChatPromptTemplate +import os +import numpy as np +import functools +from typing import Dict, Tuple, Any +from embodichain.toolkits.code_generation import ( + ExecutableOutputParser, + OutputFormatting, +) +from embodichain.toolkits.toolkits import ToolkitsBase +from embodichain.agents.mllm.prompt import CodePrompt +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import re +import importlib.util +from langchain_core.messages import HumanMessage +from datetime import datetime +from embodichain.utils.utility import encode_image +import base64 + + +def format_execution_history(execution_history): + if not execution_history or len(execution_history) == 0: + return "None." + + return "\n\n".join(f"{i}. {entry}" for i, entry in enumerate(execution_history, 1)) + + +def extract_python_code_and_text(llm_response: str) -> Tuple[str, str]: + """ + Extract exactly ONE python code block from the LLM response, + and return: + - code: the content inside the python block + - text: all remaining explanation text (outside the code block) + + Raises ValueError if zero or multiple python blocks are found. + """ + + pattern = r"```python\s*(.*?)\s*```" + matches = list(re.finditer(pattern, llm_response, re.DOTALL)) + + if len(matches) == 0: + raise ValueError("No python code block found in LLM response.") + if len(matches) > 1: + raise ValueError("Multiple python code blocks found in LLM response.") + + match = matches[0] + code = match.group(1).strip() + + # Optional sanity check + if not code.startswith("#") and not code.startswith("drive("): + raise ValueError( + f"Invalid code block content. Expected `drive(...)` or `# TASK_COMPLETE`, got:\n{code}" + ) + + # Extract remaining text (before + after the code block) + text_before = llm_response[: match.start()].strip() + text_after = llm_response[match.end() :].strip() + + explanation_text = "\n\n".join(part for part in [text_before, text_after] if part) + + return code, explanation_text + + +def format_llm_response_md( + llm_analysis: str, # plain-text explanation (NO code) + extracted_code: str, # validated executable code + step_id: int = None, + execution_history: str = None, + obs_image_path: Path = None, + md_file_path: Path = None, +) -> str: + ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + header = f"## Step: {step_id if step_id is not None else '-'} | {ts}\n\n" + + history_block = "" + if execution_history: + history_block = ( + "### Execution History (Input to LLM)\n\n" + "```\n" + f"{execution_history}\n" + "```\n\n" + ) + + image_block = "" + if obs_image_path is not None and md_file_path is not None: + try: + rel_path = obs_image_path.relative_to(md_file_path.parent) + except ValueError: + # Fallback: just use filename + rel_path = obs_image_path.name + + image_block = ( + "### Observation Image\n\n" f"![]({Path(rel_path).as_posix()})\n\n" + ) + + body = ( + image_block + history_block + "### LLM Analysis\n\n" + f"{llm_analysis.strip()}\n\n" + "### Executed Code\n\n" + "```python\n" + f"{extracted_code.strip()}\n" + "```\n\n" + "---\n\n" + ) + + return header + body + + +class CodeAgent(AgentBase): + query_prefix = "# " + query_suffix = "." + prompt: ChatPromptTemplate + prompt_kwargs: Dict[str, Dict] + + def __init__(self, llm, **kwargs) -> None: + super().__init__(**kwargs) + self.llm = llm + + def generate(self, **kwargs): + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = log_dir / "agent_generated_code.py" + + # Check if the file already exists + if not kwargs.get("regenerate", False): + if file_path.exists(): + print(f"Code file already exists at {file_path}, skipping writing.") + return file_path, kwargs, None + + # Generate code via LLM + prompt = getattr(CodePrompt, self.prompt_name)( + **kwargs, + ) + + # insert feedback if exists + if len(kwargs.get("error_messages", [])) != 0: + # just use the last one + last_code = kwargs["generated_codes"][-1] + last_error = kwargs["error_messages"][-1] + last_observation = ( + kwargs.get("observation_feedbacks")[-1] + if kwargs.get("observation_feedbacks") + else None + ) + + # Add extra human message with feedback + feedback_msg = self.build_feedback_message( + last_code, last_error, last_observation + ) + prompt.messages.append(feedback_msg) + + llm_code = self.llm.invoke(prompt) + + # Normalize content + llm_code = getattr(llm_code, "content", str(llm_code)) + + print(f"\033[92m\nCode agent output:\n{llm_code}\n\033[0m") + + # Write the code to the file if it does not exist + match = re.search(r"```python\n(.*?)\n```", llm_code, re.DOTALL) + if match: + code_to_save = match.group(1).strip() + else: + code_to_save = llm_code.strip() + + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, "w") as f: + f.write(code_to_save) + print(f"Generated function code saved to {file_path}") + + return file_path, kwargs, code_to_save + + def act(self, code_file_path, **kwargs): + """Execute generated code with proper execution environment. + + Supports two modes: + 1. If code defines 'create_agent_action_list' function, call it + 2. If code contains module-level drive() calls, execute them directly + """ + import ast + + # Read the generated code file + with open(code_file_path, "r") as f: + code_content = f.read() + + # Build execution namespace with necessary imports + ns = { + "__builtins__": __builtins__, + "__name__": "__main__", + "__file__": str(code_file_path), + "kwargs": kwargs, # Make kwargs available for injection + } + + # Import toolkit functions into namespace + try: + exec( + "from embodichain.toolkits.interfaces import *", + ns, + ns, + ) + except Exception as e: + raise RuntimeError( + "Failed to import embodichain.toolkits.interfaces" + ) from e + + # Parse code to check if it defines a function or contains module-level calls + tree = ast.parse(code_content) + + # Check if code defines create_agent_action_list function + has_function = any( + isinstance(node, ast.FunctionDef) + and node.name == "create_agent_action_list" + for node in tree.body + ) + + if has_function: + # Execute code (function will be defined in namespace) + exec(code_content, ns, ns) + + # Call the function if it exists + if "create_agent_action_list" in ns: + result = ns["create_agent_action_list"](**kwargs) + print("Function executed successfully.") + return result + else: + raise AttributeError( + "The function 'create_agent_action_list' was not found after execution." + ) + else: + # Code contains module-level drive() calls + # AST transformer to inject **kwargs into function calls + class InjectKwargs(ast.NodeTransformer): + def visit_Call(self, node): + self.generic_visit(node) + # Inject **kwargs if not present + has_kwargs = any( + kw.arg is None + and isinstance(kw.value, ast.Name) + and kw.value.id == "kwargs" + for kw in node.keywords + ) + if not has_kwargs: + node.keywords.append( + ast.keyword( + arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()) + ) + ) + return node + + # Transform AST to inject kwargs + tree = InjectKwargs().visit(tree) + ast.fix_missing_locations(tree) + + # Compile and execute transformed code + compiled_code = compile(tree, filename=str(code_file_path), mode="exec") + exec(compiled_code, ns, ns) + + # Collect actions from drive() calls if they were executed + # drive() function stores actions in env._episode_action_list + if "env" in kwargs: + env = kwargs["env"] + if hasattr(env, "_episode_action_list") and env._episode_action_list: + print( + f"Collected {len(env._episode_action_list)} actions from module-level drive() calls." + ) + return env._episode_action_list + + print("Code executed successfully, but no actions were collected.") + return [] diff --git a/embodichain/agents/hierarchy/llm.py b/embodichain/agents/hierarchy/llm.py new file mode 100644 index 00000000..b86169d5 --- /dev/null +++ b/embodichain/agents/hierarchy/llm.py @@ -0,0 +1,56 @@ +import os +from langchain_openai import AzureChatOpenAI + +# ------------------------------------------------------------------------------ +# Environment configuration +# ------------------------------------------------------------------------------ + +# Clear proxy if not needed (optional, can be set via environment variables) + +os.environ["ALL_PROXY"] = "" +os.environ["all_proxy"] = "" + +# Proxy configuration (optional, uncomment if needed) +# os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" +# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890" + +# API version (optional, defaults to "2024-10-21" if not set) +# os.environ["OPENAI_API_VERSION"] = "2024-10-21" + +# Note: AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set via environment variables +# Example in bash: +# export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/" +# export AZURE_OPENAI_API_KEY="your-api-key" + +# ------------------------------------------------------------------------------ +# LLM factory +# ------------------------------------------------------------------------------ + + +def create_llm(*, temperature=0.0, model="gpt-4o"): + return AzureChatOpenAI( + temperature=temperature, + model=model, + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("OPENAI_API_VERSION", "2024-10-21"), + ) + + +# ------------------------------------------------------------------------------ +# LLM instances +# ------------------------------------------------------------------------------ + + +# Initialize LLM instances, but handle errors gracefully for documentation builds +def _create_llm_safe(*, temperature=0.0, model="gpt-4o"): + try: + return create_llm(temperature=temperature, model=model) + except Exception: + return None + + +task_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +code_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +validation_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") +view_selection_llm = _create_llm_safe(temperature=0.0, model="gpt-4o") diff --git a/embodichain/agents/hierarchy/task_agent.py b/embodichain/agents/hierarchy/task_agent.py new file mode 100644 index 00000000..11db647b --- /dev/null +++ b/embodichain/agents/hierarchy/task_agent.py @@ -0,0 +1,139 @@ +from typing import List, Dict, Tuple +from embodichain.agents.hierarchy.agent_base import AgentBase +from langchain_core.prompts import ChatPromptTemplate +from embodichain.data import database_2d_dir +from embodichain.utils.utility import load_txt, encode_image +from embodichain.agents.mllm.prompt import TaskPrompt +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +from langchain_core.messages import HumanMessage +import numpy as np + +# from openai import OpenAI +import os +import time +import cv2 +import glob +import json +import re + +USEFUL_INFO = """The error may be caused by: +1. You did not follow the basic background information, especially the world coordinate system with its xyz directions. +2. You did not take into account the NOTE given in the atom actions or in the example functions. +3. You did not follow the steps of the task descriptions.\n +""" + + +def extract_plan_and_validation(text: str) -> Tuple[str, List[str], List[str]]: + def get_section(src: str, name: str, next_name) -> str: + if next_name: + pat = re.compile( + rf"\[{name}\]\s*:\s*(.*?)\s*(?=\[{next_name}\]\s*:|\Z)", + re.DOTALL | re.IGNORECASE, + ) + else: + pat = re.compile( + rf"\[{name}\]\s*:\s*(.*?)\s*\Z", + re.DOTALL | re.IGNORECASE, + ) + m = pat.search(src) + return m.group(1).strip() if m else "" + + step_re = re.compile( + r"Step\s*\d+\s*:.*?(?=Step\s*\d+\s*:|\Z)", + re.DOTALL | re.IGNORECASE, + ) + + # ---- plans ---- + plans_raw = get_section(text, "PLANS", "VALIDATION_CONDITIONS") + plan_steps = [m.group(0).rstrip() for m in step_re.finditer(plans_raw)] + plan_str = "\n".join(plan_steps) + + # normalized plan list (strip "Step k:") + plan_list = [] + for step in plan_steps: + content = re.sub(r"^Step\s*\d+\s*:\s*", "", step, flags=re.IGNORECASE).strip() + if content: + plan_list.append(content) + + # ---- validations ---- + vals_raw = get_section(text, "VALIDATION_CONDITIONS", None) + validation_list = [] + for m in step_re.finditer(vals_raw): + content = re.sub( + r"^Step\s*\d+\s*:\s*", "", m.group(0), flags=re.IGNORECASE + ).strip() + if content: + validation_list.append(content) + + return plan_str, plan_list, validation_list + + +class TaskAgent(AgentBase): + prompt: ChatPromptTemplate + object_list: List[str] + target: np.ndarray + prompt_name: str + prompt_kwargs: Dict[str, Dict] + + def __init__(self, llm, **kwargs) -> None: + super().__init__(**kwargs) + self.llm = llm + + def generate(self, **kwargs) -> str: + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = log_dir / "agent_generated_plan.txt" + + # Check if the file already exists + if not kwargs.get("regenerate", False): + if file_path.exists(): + print(f"Plan file already exists at {file_path}, skipping writing.") + return load_txt(file_path) + + # Generate query via LLM + prompts_ = getattr(TaskPrompt, self.prompt_name)(**kwargs) + if isinstance(prompts_, list): + # TODO: support two-stage prompts with feedback + start_time = time.time() + response = self.llm.invoke(prompts_[0]) + query = response.content + print( + f"\033[92m\nSystem tasks output ({np.round(time.time()-start_time, 4)}s):\n{query}\n\033[0m" + ) + for prompt in prompts_[1:]: + temp = prompt["kwargs"] + temp.update({"query": query}) + start_time = time.time() + response = self.llm.invoke(prompt["prompt"].invoke(temp)) + query = response.content + print( + f"\033[92m\nSystem tasks output({np.round(time.time()-start_time, 4)}s):\n{query}\n\033[0m" + ) + else: + # insert feedback if exists + if len(kwargs.get("error_messages", [])) != 0: + # just use the last one + last_plan = kwargs["generated_plans"][-1] + last_code = kwargs["generated_codes"][-1] + last_error = kwargs["error_messages"][-1] + + # Add extra human message with feedback + feedback_msg = self.build_feedback_message( + last_plan, last_code, last_error + ) + prompts_.messages.append(feedback_msg) + + response = self.llm.invoke(prompts_) + print(f"\033[92m\nTask agent output:\n{response.content}\n\033[0m") + + file_path.parent.mkdir(parents=True, exist_ok=True) + with open(file_path, "w") as f: + f.write(response.content) + print(f"Generated task plan saved to {file_path}") + + return response.content + + def act(self, *args, **kwargs): + return super().act(*args, **kwargs) diff --git a/embodichain/agents/hierarchy/validation_agent.py b/embodichain/agents/hierarchy/validation_agent.py new file mode 100644 index 00000000..910988a6 --- /dev/null +++ b/embodichain/agents/hierarchy/validation_agent.py @@ -0,0 +1,217 @@ +from langchain_core.prompts import ChatPromptTemplate +import os +from langchain_core.messages import SystemMessage, HumanMessage +from abc import ABCMeta +from embodichain.utils.utility import encode_image_from_path +import glob +from embodichain.agents.hierarchy.llm import view_selection_llm + + +def save_obs_image(obs_image, save_dir, step_id=None): + """ + Save observation image using encode_image() and return its file path. + """ + import base64 + from embodichain.utils.utility import encode_image + + if obs_image is None: + return None + + if isinstance(save_dir, str): + from pathlib import Path + + save_dir = Path(save_dir) + + save_dir.mkdir(parents=True, exist_ok=True) + + name = f"obs_step_{step_id}.png" if step_id is not None else "obs.png" + img_path = save_dir / name + + # Encode to base64 + base64_image = encode_image(obs_image) + + # Decode base64 → bytes + img_bytes = base64.b64decode(base64_image) + + # Write to file + with open(img_path, "wb") as f: + f.write(img_bytes) + + return img_path + + +def get_obj_position_info(env): + import json + + position_info = {} + obj_uids = env.sim.get_rigid_object_uid_list() + for obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0)[:3, 3] + position_info[obj_name] = target_obj_pose.tolist() + return json.dumps(position_info, indent=4) + + +class ValidationAgent(metaclass=ABCMeta): + + def __init__(self, llm, **kwargs) -> None: + super().__init__() + for key, value in kwargs.items(): + setattr(self, key, value) + self.llm = llm + + def validate(self, step_names, problematic_code, error_message, image_files): + # Construct the prompt + prompt = f""" + Analyze the execution of the following robot task: + + Task name: {self.task_name} + Task description: {self.task_description} + Basic background knowledge: {self.basic_background} + + You will be given images showing each step of the execution. For the step sequence: + {', '.join(step_names)} + + Provide the following analysis: + 1. Decide whether the full task succeeded or failed. + 2. If the task failed, provide a precise and detailed explanation. + + Below is a potentially problematic piece of code and the corresponding execution error: + + ```python + {problematic_code} + # Execution error: + {error_message} + Explain whether (and how) this code contributed to the observed failure. + """ + + # Prepare message content for API call + user_content = [] + + # Add textual prompt + user_content.append({"type": "text", "text": prompt}) + + # Add images and step names + for img_path in image_files: + filename = os.path.basename(img_path) + first_underscore_pos = filename.find("_") + if first_underscore_pos != -1: + step_name = filename[first_underscore_pos + 1 :].rsplit(".", 1)[0] + else: + step_name = filename.rsplit(".", 1)[0] + + # Add step name + user_content.append({"type": "text", "text": f"Step: {step_name}"}) + + # Add image as base64 + base64_image = encode_image_from_path(img_path) + user_content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + } + ) + + messages = [ + SystemMessage( + content="You are a robot task execution analysis expert. Please analyze the provided image sequence." + ), + HumanMessage(content=user_content), + ] + + response = self.llm.invoke(messages) + return response.content + + def select_best_view_dir( + self, img_dirs: dict, action_description: str, valid_condition: str + ): + """ + img_dirs: { + "cam_1": Path, + "cam_2": Path, + "cam_3": Path + } + """ + + # --- collect final images --- + last_images = {} + for cam_id, cam_dir in img_dirs.items(): + imgs = sorted( + glob.glob(os.path.join(cam_dir, "obs_step_*.png")), + key=lambda p: int(os.path.basename(p).split("_")[-1].split(".")[0]), + ) + if imgs: + last_images[cam_id] = imgs[-1] + + if not last_images: + raise ValueError("No images found in any camera directory.") + + # --- system prompt --- + system_prompt = ( + "You are a robot perception assistant specialized in VIEW SELECTION.\n\n" + "TASK:\n" + "- You are given ONE final observation image from EACH camera view.\n" + "- Your job is NOT to judge success or failure.\n" + "- Your job is ONLY to select the SINGLE camera view that is MOST SUITABLE\n" + " for OBJECT-LEVEL validation of the action result.\n\n" + "ACTION CONTEXT:\n" + "- The robot has just executed ONE atomic action.\n" + "- You are given the action intention and the expected object-level outcome\n" + " ONLY to help you decide which view best reveals that outcome.\n\n" + "SELECTION CRITERIA (PRIORITY ORDER):\n" + "- Prefer views with:\n" + " * the clearest visibility of the relevant object(s)\n" + " * minimal occlusion by the arm or environment\n" + " * the clearest evidence related to the expected object-level result\n" + " (e.g., contact, separation, support, stability)\n\n" + "STRICT CONSTRAINTS:\n" + "- Do NOT judge robot motion quality or execution accuracy.\n" + "- Do NOT reason about numeric values (distance, angle, offset).\n" + "- Do NOT decide whether the action succeeded or failed.\n" + "- If multiple views are acceptable, choose the clearest overall view.\n\n" + "OUTPUT FORMAT (STRICT):\n" + "Output EXACTLY ONE of the following tokens:\n" + "- cam_1\n" + "- cam_2\n" + "- cam_3\n" + ) + + # --- human content --- + human_content = [ + { + "type": "text", + "text": ( + "Select the best camera view for object-level validation.\n\n" + "--------------------------------------------------\n" + "ACTION DESCRIPTION (INTENT ONLY):\n" + f"{action_description}\n\n" + "EXPECTED OBJECT-LEVEL RESULT (REFERENCE ONLY):\n" + f"{valid_condition}\n" + "--------------------------------------------------" + ), + } + ] + + for cam_id, img_path in last_images.items(): + img_b64 = encode_image_from_path(img_path) + human_content.extend( + [ + {"type": "text", "text": f"View candidate: {cam_id}"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + }, + ] + ) + + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=human_content), + ] + + response = view_selection_llm.invoke(messages).content.strip() + + if response not in img_dirs: + raise ValueError(f"Invalid camera selection from LLM: {response}") + + return response diff --git a/embodichain/agents/mllm/prompt/__init__.py b/embodichain/agents/mllm/prompt/__init__.py new file mode 100644 index 00000000..55bc408b --- /dev/null +++ b/embodichain/agents/mllm/prompt/__init__.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .task_prompt import TaskPrompt +from .code_prompt import CodePrompt diff --git a/embodichain/agents/mllm/prompt/code_prompt.py b/embodichain/agents/mllm/prompt/code_prompt.py new file mode 100644 index 00000000..efbaae75 --- /dev/null +++ b/embodichain/agents/mllm/prompt/code_prompt.py @@ -0,0 +1,138 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from embodichain.utils.utility import encode_image, encode_image_from_path + + +class CodePrompt: + @staticmethod + def one_stage_prompt(**kwargs) -> ChatPromptTemplate: + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are an AI assistant that can generate python code to execute robot arms." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "text", + "text": ( + "Generate a Python code snippet that accomplishes the following task:\n" + "{query}\n\n" + "You must strictly follow the rules and available functions described below:\n" + "{code_prompt}\n\n" + "Here are some reference examples of the expected output code:\n" + "{code_example}\n\n" + ), + } + ] + ), + ] + ) + return prompt.invoke(kwargs) + + @staticmethod + def unified_prompt(observations, **kwargs): + """ + Unified Vision→Code prompt: + - Model observes the image + - Understands the scene and the task goal + - Generates final executable Python code using atomic robot APIs + """ + + # Encode the image + observation = observations["rgb"] + kwargs.update({"observation": encode_image(observation)}) + + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a reliable Vision-Language-Code robot assistant. " + "You observe an image, understand the scene and the task goal, " + "and generate correct Python code using ONLY the allowed atomic robot actions. " + "Your final output must be a single Python code block." + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "### Task Goal\n" + "{task_prompt}\n\n" + "### Environment Background\n" + "{basic_background}\n\n" + "### Allowed Atomic Actions\n" + "{atom_actions}\n\n" + "### Code Rules\n" + "{code_prompt}\n\n" + "### Reference Code Examples\n" + "{code_example}\n\n" + "### Final Instructions\n" + "Understand the scene from the image and generate final executable Python code " + "that performs the task using ONLY the allowed atomic actions.\n\n" + "Your entire response must be EXACTLY one Python code block:\n" + "```python\n" + "# your solution code here\n" + "```\n" + ), + }, + ] + ), + ] + ) + + return prompt.invoke(kwargs) + + @staticmethod + def one_stage_prompt_according_to_task_plan(**kwargs) -> ChatPromptTemplate: + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a reliable robot control code generator.\n" + "Your task is to generate Python code that executes robot arm actions.\n\n" + "CRITICAL RULES:\n" + "- The TASK PLAN defines the available atomic actions, rules, and execution logic.\n" + "- You MUST strictly follow the TASK PLAN.\n" + "- The CONSTRAINTS section contains additional global constraints you must obey.\n" + "- Do NOT invent new actions, functions, parameters, or control flow.\n" + "- You MAY include Python comments (# ...) inside the code.\n" + "- Your ENTIRE response MUST be a single Python code block.\n" + "- The code block MUST be directly executable without modification.\n" + "- Do NOT include any text, explanation, or markdown outside the Python code block.\n" + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "text", + "text": ( + "TASK PLAN (atomic actions, rules, and intended behavior):\n" + "{task_plan}\n\n" + "GLOBAL CONSTRAINTS (must be satisfied):\n" + "{code_prompt}\n\n" + "REFERENCE CODE (style and structure only; do NOT copy logic):\n" + "{code_example}\n\n" + "Generate the corrected Python code now." + ), + } + ] + ), + ] + ) + return prompt.invoke(kwargs) diff --git a/embodichain/agents/mllm/prompt/task_prompt.py b/embodichain/agents/mllm/prompt/task_prompt.py new file mode 100644 index 00000000..3fa8ca9e --- /dev/null +++ b/embodichain/agents/mllm/prompt/task_prompt.py @@ -0,0 +1,134 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- +import torch +import numpy as np +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from embodichain.utils.utility import encode_image, encode_image_from_path + + +class TaskPrompt: + @staticmethod + def one_stage_prompt(observations, **kwargs): + """ + Hybrid one-pass prompt: + Step 1: VLM analyzes the image and extracts object IDs. + Step 2: LLM generates task instructions using only those IDs. + """ + # Encode image + observation = ( + observations["rgb"].cpu().numpy() + if isinstance(observations["rgb"], torch.Tensor) + else observations["rgb"] + ) + kwargs.update({"observation": encode_image(observation)}) + + # Build hybrid prompt + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a precise and reliable robotic manipulation planner. " + "Given a camera observation and a task description, you must generate " + "a clear, step-by-step task plan for a robotic arm. " + "All actions must strictly use the provided atomic API functions, " + "and the plan must be executable without ambiguity." + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "Here is the latest camera observation.\n" + "First, analyze the scene in the image.\n" + "Then, using the context below, produce an actionable task plan.\n\n" + "**Environment background:** \n{basic_background}\n\n" + '**Task goal:** \n"{task_prompt}"\n\n' + "**Available atomic actions:** \n{atom_actions}\n" + ), + }, + ] + ), + ] + ) + + # Return the prompt template and kwargs to be executed by the caller + return prompt.invoke(kwargs) + + @staticmethod + def two_stage_prompt(observations, **kwargs): + # for VLM generate image descriptions + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are a helpful assistant to operate a robotic arm with a camera to generate task plans according to descriptions." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/jpg;base64,{observation}", + }, + }, + { + "type": "text", + "text": "What is in the image? Return answer with their potential effects.", + }, + ] + ), + ] + ) + + observation = ( + observations["rgb"].cpu().numpy() + if isinstance(observations["rgb"], torch.Tensor) + else observations["rgb"] + ) + kwargs.update({"observation": encode_image(observation)}) + # for LLM generate task descriptions + prompt_query = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content="You are a helpful assistant to operate a robotic arm with a camera to generate task plans according to descriptions." + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/jpg;base64,{observation}", + }, + }, + { + "type": "text", + "text": "Here is analysis for this image: {query}.", + }, + { + "type": "text", + "text": ( + "Using the context below, produce an actionable task plan.\n\n" + "**Environment background:** \n{basic_background}\n\n" + '**Task goal:** \n"{task_prompt}"\n\n' + "**Available atomic actions:** \n{atom_actions}\n" + ), + }, + ] + ), + ] + ) + + return [prompt.invoke(kwargs), {"prompt": prompt_query, "kwargs": kwargs}] diff --git a/embodichain/data/__init__.py b/embodichain/data/__init__.py index 9e152ab9..fccbbc24 100644 --- a/embodichain/data/__init__.py +++ b/embodichain/data/__init__.py @@ -14,5 +14,13 @@ # limitations under the License. # ---------------------------------------------------------------------------- +import os + + +database_dir = os.path.dirname(os.path.abspath(__file__)).replace("data", "database") +database_2d_dir = os.path.join(database_dir, "2dasset") +database_agent_prompt_dir = os.path.join(database_dir, "agent_prompt") +database_demo_dir = os.path.join(database_dir, "demostration") + from . import assets from .dataset import * diff --git a/embodichain/data/data_engine/__init__.py b/embodichain/data/data_engine/__init__.py new file mode 100644 index 00000000..6488c766 --- /dev/null +++ b/embodichain/data/data_engine/__init__.py @@ -0,0 +1,5 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- diff --git a/embodichain/data/data_engine/compressed_hdf5.py b/embodichain/data/data_engine/compressed_hdf5.py new file mode 100644 index 00000000..c050b766 --- /dev/null +++ b/embodichain/data/data_engine/compressed_hdf5.py @@ -0,0 +1,545 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_warning, log_error + +try: + import h5ffmpeg as hf + + has_h5ffmpeg = True +except Exception as e: + has_h5ffmpeg = False + log_warning("Fail to import h5ffmpeg.") + +import h5py +import numpy as np + +from typing import Dict, Any, List, Union, Optional +from embodichain.data.enum import ( + Modality, + PrivilegeType, +) +from tqdm import tqdm + +SCALE_FACTOR = 4e3 # Scale factor for depth data +FAR_CLIP = 4.0 # m + + +class CompressedVideoHDF5: + def __init__(self, save_path: str, chunks: int = 20) -> None: + """ + Initializes the data dictionary extractor with the specified save path and number of chunks. + Attempts to configure video encoding settings based on the detected GPU model using the h5ffmpeg library. + Supported GPUs include NVIDIA A800 and NVIDIA GeForce RTX 3060, with specific encoding configurations for each. + If the GPU is unsupported or an error occurs during initialization, a warning is logged and default configuration is used. + + Args: + save_path (str): Path where extracted data will be saved. + chunks (int, optional): Number of chunks to split the data into. Defaults to 20. + + Raises: + ValueError: If the detected GPU is not supported. + """ + self.save_path = save_path + self.chunks = chunks + + try: + import h5ffmpeg as hf + import torch + + name = torch.cuda.get_device_name() + + if "A800" in name or name == "NVIDIA A800-SXM4-80GB": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "3060" in name or name == "NVIDIA GeForce RTX 3060": + self.conf = { + Modality.GEOMAP.value: hf.h264_nvenc(), + Modality.IMAGES.value: hf.h264_nvenc(), + PrivilegeType.MASK.value: hf.h264_nvenc(), + } + elif "3090" in name or name == "NVIDIA GeForce RTX 3090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "4090" in name or name == "NVIDIA GeForce RTX 4090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "Orin" in name: + # FIXME: temporary solution for Orin GPU. Need to test and adjust parameters later for nvenc encoder. + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + else: + raise ValueError("Unsupported GPU: {}".format(name)) + + except Exception as e: + log_warning( + "{}. Please make sure h5ffmpeg is successfully installed.".format(e) + ) + self.conf = {} + + @staticmethod + def is_compressed_hdf5(data: Dict) -> bool: + images_group = data.get("observations", {}).get(Modality.IMAGES.value, {}) + has_compressed_keys = any( + (isinstance(k, str) and ("index" in k or "start" in k)) + for k in images_group.keys() + ) + return has_compressed_keys + + @staticmethod + def get_chunk_name(name: str, id: Union[int, str]) -> str: + """ + Generates a chunk name by concatenating the given name with the provided id, separated by an underscore. + Args: + name (str): The base name for the chunk. + id (Union[int, str]): The identifier to append to the name. + Returns: + str: The resulting chunk name in the format 'name_id'. + """ + + return name + "_{}".format(id) + + @staticmethod + def video_save( + f, + chunks: int, + data: Dict[str, np.ndarray], + key: str, + dtype=np.uint8, + conf: Dict = None, + ): + """ + Saves video data from multiple cameras into an HDF5 file, splitting the data into chunks for efficient storage. + Args: + f: An open HDF5 file handle where the video data will be saved. + data (Dict[str, np.ndarray]): Dictionary mapping camera names to their corresponding video data arrays. + key (str): Key under "observations" group in the HDF5 file to store the video data. + dtype (type, optional): Data type to convert the video frames to before saving (default: np.uint8). + conf (Dict, optional): Additional configuration parameters for HDF5 dataset creation. + Notes: + - Video data for each camera is processed and split into the specified number of chunks. + - Index and start datasets are created for each camera to map frame indices to chunk IDs and chunk start indices. + - Uses CompressedVideoHDF5 utility functions for data formatting and conversion. + - Progress is displayed using tqdm for each chunk being saved. + """ + import h5ffmpeg as hf + + f_images = f["observations"].create_group(key) + + for cam_name in data.keys(): + data_ = data[cam_name] + if len(data_) != 0: + data_ = CompressedVideoHDF5.to_bhw(data_) + + if dtype == np.uint16: + data_ = CompressedVideoHDF5.uint16_depth(data_) + else: + data_ = data_.astype(dtype) + + data_chunks = np.array_split(data_, chunks, axis=0) + data_chunk_ids = np.arange(data_.shape[0]) + data_chunk_ids_ = np.array_split(data_chunk_ids, chunks) + idtochunkid = np.zeros((data_.shape[0])) + chunkid2startid = np.zeros((chunks,)) + for chunkid, temp in enumerate(data_chunk_ids_): + chunkid2startid[chunkid] = min(temp) + for tempi in temp: + idtochunkid[tempi] = chunkid + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "index"), + data=idtochunkid, + ) + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "start"), + data=chunkid2startid, + ) + + for t, data_chunk in enumerate(tqdm(data_chunks)): + _ = f_images.create_dataset( + "{}/{}".format(cam_name, t), + data=data_chunk, + chunks=data_chunk.shape, + **conf, + ) + + @staticmethod + def uint16_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts a depth data array to a uint16 format after applying scaling and clipping. + Args: + data (np.ndarray): The input depth data as a NumPy array. + scale_factor (float, optional): The factor by which to scale the depth data. + Defaults to SCALE_FACTOR. + far_clip (float, optional): The maximum depth value (far clipping plane) + before scaling. Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled and clipped depth data as a NumPy array of type uint16. + """ + return (np.clip(data * scale_factor, 0, far_clip * scale_factor)).astype( + np.uint16 + ) + + @staticmethod + def float32_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts depth data to float32 and scales it by the given scale factor. + Args: + data (np.ndarray): The input depth data array. + scale_factor (float, optional): The factor by which to scale the depth values. Defaults to SCALE_FACTOR. + far_clip (float, optional): The far clipping distance (unused in this function). Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled depth data as a float32 numpy array. + """ + + return data.astype(np.float32) / scale_factor + + @staticmethod + def to_bhw(data: np.ndarray) -> np.ndarray: + """ + Reshapes a 4D numpy array from (vdepth, height, width, channels) to (vdepth, height, width * channels). + If the input is already a 3D array, returns it unchanged. + Args: + data (np.ndarray): Input array of shape (vdepth, height, width, channels) or (vdepth, height, width). + Returns: + np.ndarray: Reshaped array of shape (vdepth, height, width * channels) or the original array if 3D. + Raises: + Logs an error if the input array does not have 3 or 4 dimensions. + """ + + if len(data.shape) == 4: + vdepth, h, w, channels = ( + data.shape[0], + data.shape[1], + data.shape[2], + data.shape[3], + ) + return data.reshape(vdepth, h, w * channels) + elif len(data.shape) == 3: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + @staticmethod + def to_bhwc(data: np.ndarray): + """ + Converts a numpy array to BHWC (Batch, Height, Width, Channels) format. + If the input array has 3 dimensions, it reshapes the array to have a channel dimension of size 3. + If the input array already has 4 dimensions, it returns the array unchanged. + Otherwise, logs an error for unsupported shapes. + Args: + data (np.ndarray): Input numpy array to be converted. + Returns: + np.ndarray: Array in BHWC format. + Raises: + Logs an error if the input array shape is not supported. + """ + + if len(data.shape) == 3: + vdepth, h, w = data.shape + return data.reshape(vdepth, h, -1, 3) + elif len(data.shape) == 4: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + def dump( + self, + ret: Dict, + video_names: List[str] = [ + Modality.IMAGES.value, + PrivilegeType.MASK.value, + Modality.GEOMAP.value, + ], + dtypes: List = [np.uint8, np.uint8, np.uint16], + ): + """ + Dumps the provided data into an HDF5 file, saving specific video data with + compression and specified data types. + Args: + ret (Dict): The data dictionary containing observations and other metadata. + video_names (List[str], optional): A list of video names to extract from + the observations. Defaults to [Modality.IMAGES.value, PrivilegeType.MASK.value, Modality.GEOMAP.value]. + dtypes (List, optional): A list of data types corresponding to each video + name. Defaults to [np.uint8, np.uint8, np.uint16]. + Raises: + AssertionError: If the lengths of `video_names` and `dtypes` are not equal. + RuntimeError: If the configuration (`self.conf`) is empty, indicating that + `h5ffmpeg` is not installed or configured properly. + Notes: + - The method modifies the `ret` dictionary by temporarily removing the + specified video data during the HDF5 file creation process and then + restoring it afterward. + - The `hdfdict.dump` function is used to save the remaining data in the + dictionary, while the `CompressedVideoHDF5.video_save` function handles + the saving of video data with compression. + """ + + assert len(video_names) == len( + dtypes + ), "Inequal length of video names {} and dtypes {}.".format(video_names, dtypes) + import hdfdict + + if self.conf == {}: + raise RuntimeError( + "Please make sure h5ffmpeg is successfully installed before using `dump`." + ) + + pop_ret = {} + for video_name, dtype in zip(video_names, dtypes): + video_data = ret["observations"].pop(video_name) + pop_ret[video_name] = video_data + + # Open the file once and pass the open file object to hdfdict.dump so + # h5py doesn't try to truncate the same path while it is already open. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + for video_name, dtype in zip(video_names, dtypes): + CompressedVideoHDF5.video_save( + f, + self.chunks, + pop_ret[video_name], + video_name, + dtype=dtype, + conf=self.conf[video_name], + ) + + ret["observations"].update(pop_ret) + + @staticmethod + def decode_resources( + f: Dict, + ret: Dict, + name: str, + slice_id: int, + condition: callable, + function: callable, + padding: bool = True, + chunk_id: int = None, + ): + """ + Decodes and processes resources from a hierarchical data structure, applying + a condition and transformation function to the data, and optionally adding + zero-padding. + Args: + f (Dict): The input data dictionary containing observations and metadata. + ret (Dict): The output data dictionary where processed data will be stored. + name (str): The key name under "observations" to access specific data. + slice_id (int): The slice index used to retrieve the corresponding chunk ID. + condition (callable): A function that takes the data as input and returns + a boolean indicating whether the transformation function should be applied. + function (callable): A function to transform the data if the condition is met. + padding (bool, optional): Whether to add zero-padding to the data. Defaults to True. + chunk_id (int, optional): The chunk ID to use instead of deriving it from the slice ID. + Defaults to None. + Returns: + None: The function modifies the `ret` dictionary in place. + """ + + import time + + images = f["observations"][name] + + for cam_name in images.keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + + start_time = time.time() + sliceid2chunkid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "index") + ][:] + chunkid = int(sliceid2chunkid[slice_id]) if chunk_id is None else chunk_id + data_ = images[cam_name][str(chunkid)][:] + # log_warning("".format(time.time() - start_time) + if condition(data_): + data_ = function(data_) + + if padding: + chunkid2startid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "start") + ][:] + start_idx = chunkid2startid[chunkid] + zero_padding = np.zeros_like(data_)[0:1] + zero_padding = np.repeat(zero_padding, repeats=start_idx, axis=0) + ret["observations"][name][cam_name] = np.concatenate( + [zero_padding, data_], 0 + ) + else: + if ret["observations"][name][cam_name] is None: + ret["observations"][name][cam_name] = data_ + else: + ret["observations"][name][cam_name] = np.concatenate( + [ret["observations"][name][cam_name], data_], 0 + ) + + def safe_filter(self, f: Dict, slice_id: int = None) -> Dict: + """ + Filters and processes the input data dictionary based on the configuration + and specified slice ID. + Args: + f (Dict): The input data dictionary containing observations, including + images, masks, and geomap. + slice_id (int, optional): The specific slice ID to process. If None, + processes all chunks. Defaults to None. + Returns: + Dict: The filtered and processed data dictionary with updated + observations for images, masks, and geomap. + Notes: + - The method filters out camera names containing "index" or "start". + - It initializes the return dictionary with None values for images, + masks, and geomap for the filtered camera names. + - Depending on the `slice_id`, it either processes all chunks or a + specific slice using the `CompressedVideoHDF5.decode_resources` + method. + - The processed observations are updated in the input dictionary `f`. + """ + + if self.conf is {}: + return f + + cam_names = [] + for cam_name in f["observations"][Modality.IMAGES.value].keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + cam_names.append(cam_name) + + # Only build return structure for actually present modalities, avoid errors when real data lacks mask/geomap + present_modalities = [] + if Modality.IMAGES.value in f["observations"]: + present_modalities.append(Modality.IMAGES.value) + if PrivilegeType.MASK.value in f["observations"]: + present_modalities.append(PrivilegeType.MASK.value) + if Modality.GEOMAP.value in f["observations"]: + present_modalities.append(Modality.GEOMAP.value) + + ret = {"observations": {}} + for modality_key in present_modalities: + ret["observations"][modality_key] = { + cam_name: None for cam_name in cam_names + } + + if slice_id == None: + # For all chunks + for chunk_id_ in range(self.chunks): + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + None, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + chunk_id=chunk_id_, + padding=False, + ) + + else: + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + slice_id, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + ) + if Modality.IMAGES.value in present_modalities: + f["observations"][Modality.IMAGES.value] = ret["observations"][ + Modality.IMAGES.value + ] + if PrivilegeType.MASK.value in present_modalities: + f["observations"][PrivilegeType.MASK.value] = ret["observations"][ + PrivilegeType.MASK.value + ] + if Modality.GEOMAP.value in present_modalities: + f["observations"][Modality.GEOMAP.value] = ret["observations"][ + Modality.GEOMAP.value + ] + + return f diff --git a/embodichain/data/data_engine/data_dict_extractor.py b/embodichain/data/data_engine/data_dict_extractor.py new file mode 100644 index 00000000..5c5f3c4e --- /dev/null +++ b/embodichain/data/data_engine/data_dict_extractor.py @@ -0,0 +1,801 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_warning, log_error + +try: + import h5ffmpeg as hf + + has_h5ffmpeg = True +except Exception as e: + has_h5ffmpeg = False + log_warning("Fail to import h5ffmpeg.") + +import h5py +import os +import random +import torch +import numpy as np + +from functools import cached_property +from typing import Dict, Any, List, Union, Optional +from embodichain.data.enum import ( + HandQposNormalizer, + Modality, + PrivilegeType, + JointType, + ActionMode, + EefType, + EndEffector, + ControlParts, + ArmName, +) +from embodichain.data.global_mapping import GlobalMapping +from embodichain.lab.sim.sensors import StereoCamera +from embodichain.lab.sim.objects import Robot +from embodichain.lab.gym.envs import BaseEnv, EmbodiedEnv +from embodichain.lab.gym.utils.gym_utils import map_qpos_to_eef_pose +from embodichain.utils.utility import get_right_name +from embodichain.lab.gym.robots.interface import LearnableRobot +from embodichain.lab.gym.utils.misc import is_binocularcam, _data_key_to_control_part +from embodichain.utils import logger +from embodichain.data.data_engine.indices_unifier import ( + StateUnifier, +) +from embodichain.data.data_engine.compressed_hdf5 import CompressedVideoHDF5 +from embodichain.data.enum import ( + SUPPORTED_PROPRIO_TYPES, + SUPPORTED_ACTION_TYPES, + SUPPORTED_EXTRA_VISION_TYPES, +) +from copy import deepcopy +from embodichain.lab.gym.envs.action_bank.utils import ( + get_control_part_joint_ids, +) + +DATA_FORMATS = { + "observations": { + Modality.IMAGES.value: {}, + Modality.GEOMAP.value: {}, + PrivilegeType.MASK.value: {}, + PrivilegeType.EXTEROCEPTION.value: {}, + Modality.STATES.value: {}, + }, + Modality.ACTIONS.value: {}, +} + + +class ActStateStatistic: + def __init__(self, data_dict: Dict, min_len_steps: int) -> None: + self.data_dict = data_dict + self.min_len_steps = min_len_steps + + def prepare_state_and_action( + self, + ): + proprio = self.data_dict["observations"][Modality.STATES.value][:] + num_steps = proprio.shape[0] + # [Optional] We drop too-short episode + if num_steps < self.min_len_steps: + return False, None + # [Optional] We skip the first few still steps + EPS = 1e-2 + # Get the idx of the first qpos whose delta exceeds the threshold + proprio_delta = np.abs(proprio - proprio[0:1]) + indices = np.where(np.any(proprio_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + target_actions = self.data_dict[Modality.ACTIONS.value][:] + # Parse the state and action + state = proprio[first_idx - 1 :] + action = target_actions[first_idx - 1 :] + # Return the resulting sample + + return True, {Modality.STATES.value: state, Modality.ACTIONS.value: action} + + def statistic( + self, + ) -> Dict: + EPS = 1e-8 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + _, episode = self.prepare_state_and_action() + episode_cnt += 1 + + states = episode[Modality.STATES.value] + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": np.sqrt( + np.maximum( + (z_state_sum_sq / nz_state_cnt) + - (z_state_sum / state_cnt) ** 2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + ) + ).tolist(), + "state_min": state_min.tolist(), + "state_max": state_max.tolist(), + } + + return result + + +class DataDictExtractor: + def __init__( + self, + env: Union[BaseEnv, EmbodiedEnv], + save_path: str = None, + compression_opts: int = 9, + ): + self.env = env + self.save_path = save_path + self.data = {} + self.filtered_action_types = [] + self.filtered_proprio_types = [] + + # First check if control_parts exists and is non-empty. Only filter if valid, else use the original types. + control_parts = self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", None + ) + if control_parts and len(control_parts) > 0: + control_parts_set = set(control_parts) + self.filtered_proprio_types = [ + proprio_name + for proprio_name in SUPPORTED_PROPRIO_TYPES + if any(part in proprio_name for part in control_parts_set) + ] + self.filtered_action_types = [ + action_name + for action_name in SUPPORTED_ACTION_TYPES + if any(part in action_name for part in control_parts_set) + ] + + if ( + len(self.filtered_proprio_types) == 0 + or len(self.filtered_action_types) == 0 + ): + log_warning( + "No control parts found in the robot metadata. Using all supported proprio and action types." + ) + self.filtered_proprio_types = list(SUPPORTED_PROPRIO_TYPES) + self.filtered_action_types = list(SUPPORTED_ACTION_TYPES) + + # save all supported proprio and action types. + robot_meta_config = deepcopy(self.env.metadata["dataset"]["robot_meta"]) + robot_meta_config["observation"][ + Modality.STATES.value + ] = self.filtered_proprio_types + robot_meta_config[Modality.ACTIONS.value] = self.filtered_action_types + + self.state_unifier = StateUnifier(robot_meta=robot_meta_config) + self.compression_opts = compression_opts + + @cached_property + def robot_control_parts(self) -> List[str]: + """Get the robot's control parts. + + Note: + If control_parts is specified in the robot metadata, return those parts. + Otherwise, return all control parts. + + Returns: + List[str]: The robot's control parts. + """ + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + control_parts = robot_meta_config.get("control_parts", None) + if control_parts is None: + log_warning( + "Please make sure you have configurated the control parts. This branch may cause underlying error for training data." + ) + return [] + else: + return control_parts + + def _get_arm_control_parts(self) -> List[str]: + control_parts = self.robot_control_parts + arm_control_parts = [] + for part in control_parts: + if "arm" in part: + arm_control_parts.append(part) + return arm_control_parts + + def _has_exteroception(self) -> bool: + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + return PrivilegeType.EXTEROCEPTION.value in robot_meta_config["observation"] + + def extract( + self, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + data_dict: Dict = DATA_FORMATS, + save: bool = True, + ): + if save: + assert ( + self.save_path is not None + ), "Please provide a save path for the dataset." + data_dict = deepcopy(data_dict) + + self._init_data(data_dict) + + ret = {} + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + + if isinstance(self.env, BaseEnv): + for i, (obs, action) in enumerate(zip(obs_list, action_list)): + self._extract_vision_obs(obs, data_dict) + self._extract_proprioception(obs, data_dict) + self._extract_action(action, data_dict) + action = self._collate_action(data_dict) + proprio = self._collate_proprio(data_dict) + else: + for i, (obs, action) in enumerate(zip(obs_list, action_list)): + self._extract_vision_obs_v2(obs, data_dict) + self._extract_proprioception_v2(obs, data_dict) + self._extract_action_v2(action, data_dict) + action = self._collate_action(data_dict) + proprio = self._collate_proprio(data_dict) + + robot_meta = self._collate_metainfo() + + extra_vision_config = robot_meta_config["observation"]["vision"] + obs = {"observations": {}} + images = self.collate_sub_anns( + data_dict, extra_vision_config, Modality.IMAGES.value + ) + obs["observations"].update(proprio) + obs["observations"].update(images) + + extra_vision_names = list( + set([name for list in extra_vision_config.values() for name in list]) + ) + for extra_vision_name in extra_vision_names: + extra_vision_obs = self.collate_sub_anns( + data_dict, extra_vision_config, extra_vision_name + ) + obs["observations"].update(extra_vision_obs) + + ret.update(robot_meta) + ret.update(obs) + ret.update(action) + + statistics = ActStateStatistic( + ret, self.env.metadata["dataset"]["robot_meta"]["min_len_steps"] + ).statistic() + ret.update(statistics) + + if save: + if has_h5ffmpeg: + cvhdf5 = CompressedVideoHDF5(self.save_path) + all_video_names = [Modality.IMAGES.value] + [ + name + for name in extra_vision_names + if name != PrivilegeType.EXTEROCEPTION.value + ] + all_dtypes = [ + np.uint16 if name == Modality.GEOMAP.value else np.uint8 + for name in all_video_names + ] + cvhdf5.dump(ret, video_names=all_video_names, dtypes=all_dtypes) + else: + logger.log_info( + "h5ffmpeg is not installed, saving dataset without compression." + ) + import hdfdict + + # Open the file once and pass the file object to hdfdict.dump to + # avoid opening/truncating the same file path twice which causes + # "unable to truncate a file which is already open" errors on + # some platforms and HDF5 builds. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + + return ret + + def _init_data(self, data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for proprio_name in self.filtered_proprio_types: + data_dict["observations"][Modality.STATES.value][proprio_name] = [] + for action_name in self.filtered_action_types: + data_dict[Modality.ACTIONS.value][action_name] = [] + + for camera_name, extra_vision_list in extra_vision_config.items(): + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name] = [] + if is_stereo: + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ] = [] + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + data_dict["observations"][extra_vision_name][camera_name] = [] + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ] = [] + + def _extract_vision_obs(self, obs: Dict[str, Any], data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for camera_name, extra_vision_list in extra_vision_config.items(): + if camera_name in obs["sensor"]: + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name].append( + obs["sensor"][camera_name]["rgb"] + ) + if is_stereo: + # save rgb right + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ].append(obs["sensor"][camera_name]["rgb_right"]) + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + if extra_vision_name == PrivilegeType.EXTEROCEPTION.value: + if is_stereo: + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs[extra_vision_name][camera_name]["l"]) + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append(obs[extra_vision_name][camera_name]["r"]) + elif camera_name in obs.get(extra_vision_name, {}): + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs[extra_vision_name][camera_name]) + elif extra_vision_name == PrivilegeType.MASK.value: + # save semantic mask for monocular cameras + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["semantic_mask_l"].astype( + np.uint8 + ) + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name][ + "semantic_mask_r" + ].astype(np.uint8), + ) + elif extra_vision_name == Modality.GEOMAP.value: + if not is_stereo: + log_error( + f"Camera {camera_name} is not stereo, while '{extra_vision_name}' is in gym_config.dataset.robot_meta.vision, please check again." + ) + if "depth" in obs["sensor"][camera_name]: + data_dict["observations"][extra_vision_name][ + camera_name + ].append(obs["sensor"][camera_name]["depth"]) + else: + log_error( + f"obs['sensor'][{camera_name}] has no key named 'depth' while it's required in gym_config.dataset.robot_meta.vision, please check again." + ) + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + else: + logger.log_error( + f"Camera {camera_name} not found in observations, please check your sensor configuration in gym_config.json" + ) + + def _extract_vision_obs_v2(self, obs: Dict[str, Any], data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for camera_name, extra_vision_list in extra_vision_config.items(): + if camera_name in obs["sensor"]: + is_stereo = is_binocularcam(self.env.get_sensor(camera_name)) + + data_dict["observations"][Modality.IMAGES.value][camera_name].append( + obs["sensor"][camera_name]["color"] + .squeeze(0)[:, :, :3] + .cpu() + .numpy() + ) + if is_stereo: + # save rgb right + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["color_right"] + .squeeze_(0)[:, :, :3] + .cpu() + .numpy() + ) + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + if extra_vision_name == PrivilegeType.EXTEROCEPTION.value: + if is_stereo: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name]["l"] + .cpu() + .numpy() + ) + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs[extra_vision_name][camera_name]["r"] + .cpu() + .numpy() + ) + elif camera_name in obs.get(extra_vision_name, {}): + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name].cpu().numpy() + ) + elif extra_vision_name == PrivilegeType.MASK.value: + # save semantic mask for monocular cameras + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["semantic_mask_l"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["semantic_mask_r"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + elif extra_vision_name == Modality.GEOMAP.value: + if not is_stereo: + log_error( + f"Camera {camera_name} is not stereo, while '{extra_vision_name}' is in gym_config.dataset.robot_meta.vision, please check again." + ) + if "depth" in obs["sensor"][camera_name]: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["depth"] + .squeeze_() + .numpy() + ) + else: + log_error( + f"obs['sensor'][{camera_name}] has no key named 'depth' while it's required in gym_config.dataset.robot_meta.vision, please check again." + ) + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + else: + logger.log_error( + f"Camera {camera_name} not found in observations, please check your sensor configuration in gym_config.json" + ) + + def _extract_action( + self, + action: Dict[str, Any], + data_dict: Dict, + ): + + agent: LearnableRobot = self.env.get_agent() + # extract qpos. + for key in data_dict[Modality.ACTIONS.value].keys(): + indices = agent.get_data_index(key, warning=False) + if len(indices) > 0: + action_data = action[JointType.QPOS.value][indices].copy() + action_data = HandQposNormalizer.normalize_hand_qpos( + action_data, key, agent=agent + ) + data_dict[Modality.ACTIONS.value][key].append(action_data) + qpos = action[JointType.QPOS.value] + action_eef_pose_dict = agent.map_env_qpos_to_eef_pose( + np.array([qpos]), to_dict=True + ) + for key, val in action_eef_pose_dict.items(): + data_dict[Modality.ACTIONS.value][key].append(val[0]) + + def _extract_action_v2( + self, + action: torch.Tensor, + data_dict: Dict, + ): + robot: Robot = self.env.robot + + for key in data_dict[Modality.ACTIONS.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = get_control_part_joint_ids(self.env, key) + qpos_data = ( + action[0, indices].cpu().numpy() + if isinstance(action, torch.Tensor) + else action[0, indices] + ) + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, part, robot=robot + ) + data_dict[Modality.ACTIONS.value][key].append(qpos_data) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, action, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict[Modality.ACTIONS.value][key].append( + val.squeeze_(0).cpu().numpy() + if isinstance(val, torch.Tensor) + else val.squeeze_(0) + ) + + def _extract_proprioception( + self, + obs: Dict[str, Any], + data_dict: Dict, + ): + agent: LearnableRobot = self.env.get_agent() + # extract qpos. + qpos = obs["agent"][agent.uid][JointType.QPOS.value] + for key in data_dict["observations"][Modality.STATES.value].keys(): + indices = agent.get_data_index(key, warning=False) + if len(indices) > 0: + qpos_data = qpos[ + indices + ].copy() # Deep copy to avoid modifying original data + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, key, agent=agent + ) + data_dict["observations"][Modality.STATES.value][key].append(qpos_data) + + eef_pose_dict: Dict = agent.map_env_qpos_to_eef_pose( + np.array([qpos]), to_dict=True + ) + for key, val in eef_pose_dict.items(): + data_dict["observations"][Modality.STATES.value][key].append(val[0]) + + def _extract_proprioception_v2( + self, + obs: Dict[str, Any], + data_dict: Dict, + ): + robot: Robot = self.env.robot + + qpos = obs["robot"][JointType.QPOS.value] + for key in data_dict["observations"][Modality.STATES.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = get_control_part_joint_ids(self.env, key) + qpos_data = qpos[0][indices].cpu().numpy() + qpos_data = HandQposNormalizer.normalize_hand_qpos( + qpos_data, part, robot=robot + ) + data_dict["observations"][Modality.STATES.value][key].append(qpos_data) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, qpos, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict["observations"][Modality.STATES.value][key].append( + val.squeeze_(0).cpu().numpy() + ) + + def _collate_proprio(self, data_dict: Dict) -> Dict: + proprio_dict = {} + for proprio_name in self.state_unifier.proprio_meta: + proprio = np.array( + data_dict["observations"][Modality.STATES.value][proprio_name] + ) + proprio_dict[proprio_name] = proprio + proprios = self.state_unifier.fill_in_state(proprio_dict) + return {Modality.STATES.value: proprios} + + def _collate_metainfo( + self, + ) -> Dict: + meta_info = { + "arm_dofs": self.env.metadata["dataset"]["robot_meta"].get("arm_dofs", 12), + "observation": self.env.metadata["dataset"]["robot_meta"].get( + "observation", {} + ), + "min_len_steps": self.env.metadata["dataset"]["robot_meta"].get( + "min_len_steps", 125 + ), + } + return { + "robot_meta": meta_info, + "instruction": { + "lang": self.env.metadata["dataset"]["instruction"].get("lang", "") + }, + } + + def _collate_action(self, data_dict: Dict) -> Dict: + action_data_dict = data_dict[Modality.ACTIONS.value] + for k, v in action_data_dict.items(): + action_data_dict[k] = np.array(v) + + action_dict = {} + action_dict.update(action_data_dict) + action = self.state_unifier.fill_in_action(action_dict) + return {Modality.ACTIONS.value: action} + + @staticmethod + def collate_sub_anns( + data_dict: Dict, + extra_vision_config: Dict, + key: str = Modality.IMAGES.value, + ) -> Dict: + ret = {key: {}} + for camera_name in extra_vision_config: + images_list = data_dict["observations"][key].pop(camera_name, None) + if images_list is None: + continue + if len(images_list) > 0: + ret[key][camera_name] = np.empty( + (len(images_list),) + images_list[0].shape, + dtype=images_list[0].dtype, + ) + for idx, image in enumerate(images_list): + ret[key][camera_name][idx] = image + else: + ret[key][camera_name] = np.array([]) + + del images_list + if get_right_name(camera_name) in data_dict["observations"][key]: + images_right_list = data_dict["observations"][key].pop( + get_right_name(camera_name), None + ) + if images_right_list is None: + continue + if len(images_right_list) > 0: + ret[key][get_right_name(camera_name)] = np.empty( + (len(images_right_list),) + images_right_list[0].shape, + dtype=images_right_list[0].dtype, + ) + for idx, image in enumerate(images_right_list): + ret[key][get_right_name(camera_name)][idx] = image + else: + ret[key][get_right_name(camera_name)] = np.array([]) + del images_right_list + + return ret + + +def fetch_imitation_dataset( + env: BaseEnv, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + id: str, + folder_name: str, + save: bool = True, +) -> Dict: + """ + Save imitation dataset for a single episode. + + Args: + env (BaseEnv): Environment instance. + obs_list (List[Dict]): List of observation dicts. + action_list (List[Dict]): List of action dicts. + id (str): Unique identifier for the episode. + folder_name (str): Folder name for saving the dataset. + + Returns: + dict: Contains data_path, id, current_episode, and extracted data. + """ + # Get dataset save path + dataset_path = env.metadata["dataset"].get("save_path", None) + if dataset_path is None: + from embodichain.data import database_demo_dir + + dataset_path = database_demo_dir + + # Create folder if first episode + dataset_save_path = os.path.join(dataset_path, folder_name) + if env.curr_episode == 0 and id: + os.makedirs(dataset_save_path, exist_ok=True) + + # Check robot dof validity + try: + if isinstance(env, BaseEnv): + agent: LearnableRobot = env.get_agent() + max_dofs = len(agent.get_data_index(agent.uid)) + assert ( + env.metadata["dataset"]["robot_meta"]["arm_dofs"] <= max_dofs + ), f"Control dof {env.metadata['dataset']['robot_meta']['arm_dofs']} must be less than {max_dofs}." + else: + robot: Robot = env.robot + assert ( + env.metadata["dataset"]["robot_meta"]["arm_dofs"] <= robot.dof + ), f"Control dof {env.metadata['dataset']['robot_meta']['arm_dofs']} must be less than {robot.dof}." + except Exception as e: + logger.log_error(f"Robot DOF check failed: {e}") + return None + + # Select data format + data_format = DATA_FORMATS + + # Extract and save data + if id is None: + ret = DataDictExtractor(env).extract( + obs_list, action_list, save=False, data_dict=data_format + ) + save_path = None + else: + save_path = os.path.join(dataset_save_path, id + ".hdf5") + logger.log_info(f"Save episode {env.curr_episode} to '{save_path}'") + ret = DataDictExtractor(env, save_path).extract( + obs_list, action_list, save=save, data_dict=data_format + ) + + # Update episode count + env.curr_episode += 1 + + # Return result dict + return { + "data_path": dataset_save_path, + "id": id, + "current_episode": env.curr_episode, + "data": ret, + "save_path": save_path, + } diff --git a/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py new file mode 100644 index 00000000..fdf3ad7f --- /dev/null +++ b/embodichain/data/data_engine/datasets/sim_real_unified_dict_dataset.py @@ -0,0 +1,696 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import fnmatch +from embodichain.utils.logger import log_warning, log_info + +try: + import h5ffmpeg as hf +except Exception as e: + log_warning("Fail to import h5ffmpeg.") +import h5py +import numpy as np +from typing import Dict, Callable, List, Tuple +from embodichain.utils.utility import get_right_name, pad_to_chunk, convert_bytes +from embodichain.utils.logger import log_warning, log_info +from embodichain.data.enum import Proprioception, Image, Exteroception, ModalInput +from copy import deepcopy +from typing import Dict +from embodichain.utils.utility import timer +from embodichain.data.enum import ( + Modality, + PrivilegeType, + ArmEnum, + JointType, + CameraName, + TeleoperationData, + CobotMagicTeleoperationData, +) +from embodichain.data.data_engine.indices_unifier import ActionIndicesGenerator + +DEFAULT_ONLINE_DATASET_LEN = 10000 + + +class SimRealUnifiedDictDataset: + """Dataset class for unified simulation and real-world data. + + This class handles loading, parsing, and sampling from datasets that may be + either simulation or real-world HDF5 files. It supports both offline and online + data sources, and provides utilities for extracting and standardize state, + action, and image modalities. + + Args: + data_path (str): Path to the HDF5 dataset directory. + batch_size (int): Batch size for sampling. + chunk_size (int): Number of timesteps per sample. + state (List): List of state modalities to extract. + output (List): List of output modalities to extract. + data_meta (Dict): Metadata describing the dataset. + arm_type (ArmEnum): Type of robot arm. + img_history_size (int): Number of image frames in history. + state_history_len (int): Number of state frames in history. + precomp_lang_embed (bool, optional): Whether to use precomputed language embeddings. Defaults to True. + online_config (Dict, optional): Configuration for online data engine. Defaults to None. + camera_used (List[str], optional): List of camera names to use. Defaults to None. + indices_generator (ActionIndicesGenerator, optional): Generator for action/state indices. Defaults to None. + """ + + def __init__( + self, + data_path: str, + batch_size: int, + chunk_size: int, + state: List, + output: List, + data_meta: Dict, + arm_type: ArmEnum, + robot_name: str, + img_history_size: int, + state_history_len: int, + precomp_lang_embed: bool = True, + online_engine: Dict = None, + camera_used: List[str] = None, + indices_generator=None, + ) -> None: + """Initialize the SimRealUnifiedDictDataset.""" + # [Modify] The path to the HDF5 dataset directory + # Each HDF5 file contains one episode + + self.precomp_lang_embed = precomp_lang_embed + self.batch_size = batch_size + self.chunk_size = chunk_size + self.state = state + self.output = output + self.data_meta = data_meta + self.arm_type = arm_type + self.robot_name = robot_name + self.img_history_size = img_history_size + self.state_history_len = state_history_len + self.engine = online_engine + self.camera_used = camera_used + self.indices_generator = indices_generator + if self.camera_used is not None: + for cam in CameraName: + if cam.value not in camera_used: + log_warning( + "{} does not exist in {}".format(cam.value, camera_used) + ) + + if self.engine is not None: + self.DATASET_NAME = "online_whatever" + else: + log_info("Init offline vla dataset.", color="purple") + self.data_path = data_path + assert os.path.exists(self.data_path), "{} does not exist.".format( + self.data_path + ) + if os.path.isabs(self.data_path) is False: + self.data_path = os.path.join(os.getcwd(), self.data_path) + self.DATASET_NAME = os.path.basename(self.data_path) + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info( + f"Init dataset with size of: {len(self.file_paths)}", color="purple" + ) + + def update_data_size(self): + """Update the dataset size for validation datasets generated on the fly.""" + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info(f"Update dataset with size of: {len(self.file_paths)}", color="purple") + + def __len__(self): + """Return the number of episodes in the dataset. + + Returns: + int: Number of episodes. + """ + return ( + len(self.file_paths) if self.engine is None else DEFAULT_ONLINE_DATASET_LEN + ) + + def get_item(self, index: int = None, chunk_size: int = None): + """Get a training sample at a random timestep. + + Args: + index (int, optional): The index of the episode. If not provided, a random episode will be selected. + chunk_size (int, optional): Number of timesteps per sample. Defaults to self.chunk_size. + + Returns: + dict: A dictionary containing the training sample. + """ + chunk_size = self.chunk_size if chunk_size is None else chunk_size + while True: + if self.engine is None: + # offline + if index is None: + file_path = np.random.choice(self.file_paths) + else: + file_path = self.file_paths[index] + valid, sample = self.parse_hdf5_file(file_path, chunk_size) + else: + data_dict = self.engine.sample_data() + valid, sample = self.parse_dict(data_dict, chunk_size) + + if valid: + return sample + else: + if self.engine is None: + index = np.random.randint(0, len(self.file_paths)) + + @staticmethod + def parse_exteroception( + file: Dict, + step_id: int, + chunk_size: int, + camera_used: List[str] = [], + ) -> Exteroception: + """Parse exteroception data from the file. + + Args: + file (Dict): Data dictionary. + step_id (int): Starting timestep index. + chunk_size (int): Number of timesteps to extract. + camera_used (List[str], optional): List of cameras to use. + + Returns: + Exteroception: Parsed exteroception data. + """ + exteroception = [] + for cam in camera_used: + exteroception_full = file["observations"][ + PrivilegeType.EXTEROCEPTION.value + ][cam] + exteroception.append(exteroception_full[step_id : step_id + chunk_size]) + + exteroception = np.concatenate(exteroception, 1) + _, cs, kn, _ = exteroception.shape + exteroception = pad_to_chunk(exteroception, chunk_size) + return Exteroception( + data=exteroception.reshape(chunk_size, cs, kn, 2).transpose( + 1, 0, 2, 3 + ) # cs, chunk_size, kn, 2 + ) + + @staticmethod + def parse_img( + file: Dict, + step_id: int, + first_idx: int, + cam: str, + chunk_size: int, + key: str = Modality.IMAGES.value, + camera_used: List[str] = [], + np_ops: Callable = lambda x: x, + ) -> Image: + """Parse image data for a given camera. + + Args: + file (Dict): Data dictionary. + step_id (int): Current timestep index. + first_idx (int): First index for history. + cam (str): Camera name. + chunk_size (int): Number of timesteps to extract. + key (str, optional): Key for image modality. Defaults to Modality.IMAGES.value. + camera_used (List[str], optional): List of cameras to use. + np_ops (Callable, optional): Numpy operation to apply to images. + + Returns: + Image: Parsed image data. + """ + valid_len = min(step_id - (first_idx - 1) + 1, chunk_size) + cam_mask = np.array([False] * (chunk_size - valid_len) + [True] * valid_len) + if cam in camera_used: + temp = file["observations"][key][cam][0] + imgs = np.zeros((valid_len,) + temp.shape, dtype=temp.dtype) + for t, i in enumerate(range(max(step_id - chunk_size + 1, 0), step_id + 1)): + img = file["observations"][key][cam][i] + imgs[t] = img + imgs = np_ops(imgs) + imgs = pad_to_chunk(imgs, chunk_size=chunk_size) + mask = cam_mask.copy() + else: + imgs = np.zeros((chunk_size, 0, 0, 0)) + mask = np.zeros((chunk_size,), dtype=bool) + return Image(data=imgs, mask=mask, name=cam) + + def parse_hdf5_file(self, file_path, chunk_size: int) -> Dict[str, ModalInput]: + """Parse an HDF5 file and extract modalities. + + Args: + file_path (str): Path to the HDF5 file. + chunk_size (int): Number of timesteps to extract. + + Returns: + dict: Parsed modalities. + """ + import hdfdict + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + + with h5py.File(file_path, "r") as f: + data = hdfdict.load(f) + keyname = ( + JointType.QPOS.value + if SimRealUnifiedDictDataset.is_real_datasets(data) + else Modality.STATES.value + ) + step_id = SimRealUnifiedDictDataset.random_step_id( + data, chunk_size, keyname + ) + if not SimRealUnifiedDictDataset.is_real_datasets(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + else: + # Real data: if compressed structure is detected (containing *_index/*_start), also perform decoding filtering + try: + if CompressedVideoHDF5.is_compressed_hdf5(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + except Exception: + pass + ret = self.parse_dict(data, chunk_size, step_id) + + return ret + + @staticmethod + def random_step_id( + f: Dict, chunk_size: int, key: str = Modality.STATES.value + ) -> int: + """Randomly sample a timestep index. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + key (str, optional): Key for state modality. + + Returns: + int: Randomly selected timestep index. + """ + obs = f["observations"] + proprio = obs[key][:] + num_steps = proprio.shape[0] + # We randomly sample a timestep + first_idx = 1 + step_id = np.random.randint( + first_idx, np.maximum(first_idx + 1, num_steps - 1 - chunk_size) + ) + return step_id + + @staticmethod + def is_real_datasets(f: Dict): + """Check if the dataset is a real-world dataset. + + Args: + f (Dict): Data dictionary. + + Returns: + bool: True if real-world dataset, False if simulation. + """ + return "robot_meta" not in f.keys() + + def parse_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + if not SimRealUnifiedDictDataset.is_real_datasets(f): + log_warning("Using simulation hdf5 datasets.") + return self.parse_sim_dict(f, chunk_size, step_id) + else: + log_warning("Using real world offline hdf5 datasets.") + return self.parse_real_dict(f, chunk_size, step_id) + + def parse_real_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a real-world data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + ( + actions, + proprio, + meta, + camera_used_from_dualsys_to_real, + camera_used, + ) = RobotRealDataRouter(robot_name=self.robot_name).realdata2simdata( + f, + chunk_size, + given_camera_used=self.camera_used, + dataset_name=self.DATASET_NAME, + step_id=step_id, + ) + first_idx = 1 + parse_dict = self.parse_core(proprio, actions, step_id, chunk_size) + parse_dict.update({"meta": meta}) + for cam in self.camera_used: + if cam in camera_used: + parse_dict[cam] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + camera_used_from_dualsys_to_real[cam], + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used_from_dualsys_to_real[cam], + ) + else: + raise ValueError( + "cam name {} is not in all cam names {} for this datasets.".format( + cam, camera_used + ) + ) + return True, parse_dict + + @timer + def parse_sim_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + """Parse a simulation data dictionary and extract modalities. + + Args: + f (Dict): Data dictionary. + chunk_size (int): Number of timesteps to extract. + step_id (int, optional): Timestep index. + + Returns: + dict: Parsed modalities. + """ + if step_id is None: + step_id = SimRealUnifiedDictDataset.random_step_id(f, chunk_size) + + obs = f["observations"] + metadata = dict(f["robot_meta"]) + first_idx = 1 + + proprio = obs[Modality.STATES.value][:] + num_steps = proprio.shape[0] + min_len_step = metadata["min_len_steps"] + # [Optional] We drop too-short episode + if num_steps < min_len_step: + return False, None + + # We randomly sample a timestep + + camera_used = ( + convert_bytes(list(metadata["observation"]["vision"].keys())) + if self.camera_used is None + else self.camera_used + ) + + # Assemble the meta + meta = { + "dataset_name": self.DATASET_NAME, + "#steps": num_steps, + "step_id": step_id, + "instruction": "", + "camera_used": camera_used, + "instruction": ( + f["language_prompt"] if f.get("language_prompt", None) else "" + ), + } + + assert ( + self.indices_generator.dof == metadata["arm_dofs"] + ), "Train dof {} but dataset dof {}.".format( + self.indices_generator.dof, metadata["arm_dofs"] + ) + parse_dict = self.parse_core( + proprio, f[Modality.ACTIONS.value], step_id, chunk_size + ) + parse_dict.update({"meta": meta}) + + for cam in camera_used: + cam_r = get_right_name(cam) + if cam_r in obs[Modality.IMAGES.value] and cam_r not in camera_used: + # insert camera name after cam + camera_used.insert(camera_used.index(cam) + 1, cam_r) + + for cam in camera_used: + parse_dict[cam] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used, + ) + if PrivilegeType.MASK.value in self.data_meta.get("privileges", []): + parse_dict[cam + "_{}".format(PrivilegeType.MASK.value)] = ( + SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + PrivilegeType.MASK.value, + camera_used=camera_used, + ) + ) + if PrivilegeType.EXTEROCEPTION.value in self.data_meta.get("privileges", []): + if obs[PrivilegeType.EXTEROCEPTION.value][camera_used[0]].shape[0] != 0: + parse_dict[PrivilegeType.EXTEROCEPTION.value] = ( + SimRealUnifiedDictDataset.parse_exteroception( + f, + step_id, + chunk_size, + camera_used=camera_used, + ) + ) + + if Modality.GEOMAP.value in self.data_meta.get("additional_modality", []): + if ( + hasattr(obs[Modality.GEOMAP.value][camera_used[0]], "shape") + and obs[Modality.GEOMAP.value][camera_used[0]].shape[0] != 0 + ): + parse_dict[Modality.GEOMAP.value] = SimRealUnifiedDictDataset.parse_img( + f, + step_id, + first_idx, + CameraName.HEAD.value, + self.img_history_size, + Modality.GEOMAP.value, + camera_used=camera_used, + np_ops=lambda x: np.tile(np.expand_dims(x, -1), [1, 1, 1, 3]), + ) + + # Return the resulting sample + # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0) + # E.g., return np.zeros((self.img_history_size, 0, 0, 0)) for the key "cam_left_wrist", + # if the left-wrist camera is unavailable on your robot + return True, parse_dict + + def parse_core( + self, proprio: np.ndarray, actions: np.ndarray, step_id: int, chunk_size: int + ): + """Parse and normalize state and action data. + + Args: + proprio (np.ndarray): Proprioceptive state data. + actions (np.ndarray): Action data. + step_id (int): Current timestep index. + chunk_size (int): Number of timesteps to extract. + + Returns: + dict: Dictionary containing normalized state, action, and statistics. + """ + # Parse the state and action + state = proprio[np.maximum(step_id - self.state_history_len, 0) : step_id] + state = np.concatenate( + [np.tile(state[0:1], [self.state_history_len - state.shape[0], 1]), state], + 0, + ) + self.indices_generator: ActionIndicesGenerator + global_mapping = self.indices_generator.global_mapping + state_indices = global_mapping.get_indices( + convert_bytes(self.state), + ) + state_indicator = np.zeros_like(state, dtype=np.int8) + state_indicator[:, state_indices] = 1 + state *= state_indicator + proprio *= state_indicator[0:1] + state_std = np.std(proprio, axis=0) + state_mean = np.mean(proprio, axis=0) + state_norm = np.sqrt(np.mean(proprio**2, axis=0)) + action_indices = self.indices_generator.get( + self.output, + ) + actions = deepcopy(actions[step_id : step_id + chunk_size]) + delta_qpos_indices = self.indices_generator.get_all_delta_qpos( + handness=self.arm_type + ) + qpos_indices = self.indices_generator.get_all_qpos(handness=self.arm_type) + # NOTE: Ops `cumsum` equal to action[:horizon]-action[0:1]. + # TODO: action = action_chunk - current_obs. + actions[:, delta_qpos_indices] = ( + actions[:, qpos_indices] - state[-1:, qpos_indices] + ) + actions = pad_to_chunk(actions, chunk_size=chunk_size) + + action_indicator = np.zeros_like(actions, dtype=np.int8) + action_indicator[:, action_indices] = 1 + actions *= action_indicator[0:1] + + parse_dict = { + "state_std": state_std, + "state_mean": state_mean, + "state_norm": state_norm, + Modality.STATES.value: Proprioception(data=state, mask=state_indicator), + Modality.ACTIONS.value: Proprioception(data=actions, mask=action_indicator), + PrivilegeType.PROGRESS.value: step_id / proprio.shape[0], + } + return parse_dict + + +class RobotRealDataRouter: + def __init__(self, robot_name: str): + from embodichain.data.enum import ( + ControlParts, + EndEffector, + JointType, + ) + + assert robot_name in [ + "CobotMagic", + "DexforceW1", + ], "Robot type {} not supported.".format(robot_name) + self.robot_name = robot_name + + if robot_name == "CobotMagic": + self._REAL_SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, + ] + self.qpos_index_dict = { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: CobotMagicTeleoperationData.LEFT_ARM_QPOS_INDICES.value, + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: CobotMagicTeleoperationData.RIGHT_ARM_QPOS_INDICES.value, + ControlParts.LEFT_EEF.value + + EndEffector.GRIPPER.value: CobotMagicTeleoperationData.LEFT_EEF_GRIPPER_INDICES.value, + ControlParts.RIGHT_EEF.value + + EndEffector.GRIPPER.value: CobotMagicTeleoperationData.RIGHT_EEF_GRIPPER_INDICES.value, + } + self.arm_dofs = 12 + self.camera_used_from_real_to_dualsys = { + CameraName.LEFT_WRIST.value: CameraName.LEFT_WRIST.value, + CameraName.RIGHT_WRIST.value: CameraName.RIGHT_WRIST.value, + CameraName.HEAD.value: CameraName.HEAD.value, + } + elif robot_name == "DexforceW1": + self._REAL_SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.HEAD.value + JointType.QPOS.value, + ControlParts.WAIST.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ] + self.qpos_index_dict = { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: TeleoperationData.LEFT_ARM_QPOS_INDICES.value, + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: TeleoperationData.RIGHT_ARM_QPOS_INDICES.value, + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.LEFT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.RIGHT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.HEAD.value + + JointType.QPOS.value: TeleoperationData.HEAD_QPOS_INDICES.value, + ControlParts.WAIST.value + + JointType.QPOS.value: TeleoperationData.WAIST_QPOS_INDICES.value, + } + self.arm_dofs = 14 + self.camera_used_from_real_to_dualsys = { + "cam_hand_left": CameraName.LEFT_WRIST.value, + "cam_hand_right": CameraName.RIGHT_WRIST.value, + "cam_high_left": CameraName.HEAD.value, + } + + def realdata2simdata( + self, + f: Dict, + chunk_size: int = -1, + given_camera_used: List[str] = [], + dataset_name: str = "", + step_id: int = None, + ): + + from embodichain.data.data_engine.indices_unifier import ( + StateUnifier, + ) + + if step_id is None: + step_id = VLADataset.random_step_id(f, chunk_size, "qpos") + obs = f["observations"] + proprio = obs["qpos"][:] + num_steps = proprio.shape[0] + camera_used_in_real = list(obs[Modality.IMAGES.value].keys()) + camera_used_from_dualsys_to_real = { + val: key for key, val in self.camera_used_from_real_to_dualsys.items() + } + # Now assume it is from W1. + camera_used = [ + self.camera_used_from_real_to_dualsys[cam] + for cam in camera_used_in_real + if cam in self.camera_used_from_real_to_dualsys + ] + + # Assemble the meta + meta = { + "dataset_name": dataset_name, + "#steps": num_steps, + "step_id": step_id, + "camera_used": [ + cam_name for cam_name in given_camera_used if cam_name in camera_used + ], + "instruction": ( + f["language_prompt"] if f.get("language_prompt", None) else "" + ), + } + # save all supported proprio and action types. + robot_meta_config = {"arm_dofs": self.arm_dofs, "observation": {}} + + robot_meta_config["observation"][ + Modality.STATES.value + ] = self._REAL_SUPPORTED_PROPRIO_TYPES + robot_meta_config[Modality.ACTIONS.value] = self._REAL_SUPPORTED_PROPRIO_TYPES + state_unifier = StateUnifier(robot_meta=robot_meta_config) + + qpos_dict = {} + for key, indices in self.qpos_index_dict.items(): + qpos_dict[key] = proprio[:, indices] + actions = state_unifier.fill_in_action(qpos_dict) + proprio = state_unifier.fill_in_state(qpos_dict) + return actions, proprio, meta, camera_used_from_dualsys_to_real, camera_used diff --git a/embodichain/data/data_engine/indices_unifier.py b/embodichain/data/data_engine/indices_unifier.py new file mode 100644 index 00000000..3cfd025b --- /dev/null +++ b/embodichain/data/data_engine/indices_unifier.py @@ -0,0 +1,395 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.data.global_indices import ( + GLOBAL_INDICES, + STATE_VEC_LEN, +) +from embodichain.data.global_mapping import GlobalMapping +import numpy as np +from typing import List, Dict, Tuple, Union +from embodichain.data.enum import ( + ArmEnum, + Modality, + JointType, + ActionMode, + EefType, + ControlParts, + EndEffector, + Modality, +) +from embodichain.utils.logger import log_info, log_warning + +DEFAULT_EMPTY_STATE = -1 + +__all__ = ["StateUnifier", "ActionIndicesGenerator"] + +"""Unified state utilities for EmbodiChain. + +This module provides helpers to construct and query a unified state/action +vector representation used across EmbodiChain environments and agents. + +Classes: + StateUnifier: Fill sparse per-modality state/action dictionaries into a + fixed-length unified state vector where unspecified entries are set + to a sentinel value (DEFAULT_EMPTY_STATE). + + ActionIndicesGenerator: Query index ranges in the unified vector for + common action/state groups (e.g. qpos, delta qpos, end-effector pose). + +Constants: + DEFAULT_EMPTY_STATE (int): Sentinel value used to mark unspecified + entries in the unified vector. +""" + + +class StateUnifier: + """Convert per-modality state/action arrays into a unified vector. + + The StateUnifier is constructed with ``robot_meta`` (the robot's + metadata) which should contain an ``observation`` mapping with keys for + modalities (e.g. ``Modality.STATES``) and an ``actions`` specification. + + Attributes: + metadata (dict): Robot metadata passed at construction. + arm_dofs (int): Degrees of freedom for the arm (default: 12). + indices_generator (ActionIndicesGenerator): Helper for action indices. + proprio_meta: Metadata list for proprioceptive modalities. + global_mapping (GlobalMapping): Mapping from names to unified indices. + output: Action output specification from metadata. + state_dim (int): Fixed length of the unified state vector. + """ + + def __init__(self, robot_meta: Dict) -> None: + assert "arm_dofs" in robot_meta + assert "observation" in robot_meta + assert Modality.ACTIONS.value in robot_meta + + self.arm_dofs = robot_meta["arm_dofs"] + self.indices_generator = ActionIndicesGenerator(self.arm_dofs) + self.proprio_meta = robot_meta["observation"][Modality.STATES.value] + self.global_mapping = GlobalMapping(self.arm_dofs) + self.output = robot_meta[Modality.ACTIONS.value] + + self.state_dim = STATE_VEC_LEN + + def fill_in_state( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified state vector from given values. + + Args: + values (np.ndarray or dict): If ``values`` is a numpy array it is + assumed to already be aligned to the unified layout and will + be placed into the output container. If it is a ``dict``, + keys should match entries from the robot metadata + ``observation[Modality.STATES]`` and values are numpy arrays + with a trailing dimension matching each state's width. + + Returns: + np.ndarray: An array with shape ``(..., STATE_VEC_LEN)`` containing + the unified state with unspecified entries set to + ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.global_mapping.get_indices(self.proprio_meta) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for val in values.values(): + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for state_name in self.proprio_meta: + state_indices = self.global_mapping.get_indices([state_name]) + if values[state_name].size != 0: + uni_vec[..., state_indices] = values[state_name] + + return uni_vec + + def fill_in_action( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified action vector from given action values. + + This mirrors :meth:`fill_in_state` but uses the metadata's action + output specification to determine which named outputs map into the + unified vector. + + Args: + values (np.ndarray or dict): Action values aligned to the unified + layout or a mapping from output names to numpy arrays. + + Returns: + np.ndarray: Unified vector shaped ``(..., STATE_VEC_LEN)`` with + unspecified entries filled with ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.indices_generator.get(self.output) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for key, val in values.items(): + + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for out_name in self.output: + state_indices = self.global_mapping.get_indices([out_name]) + if out_name in values and values[out_name].size != 0: + uni_vec[..., state_indices] = values[out_name] + return uni_vec + + +class ActionIndicesGenerator: + """Utility for generating index lists for action/state groups. + + The ActionIndicesGenerator wraps :class:`GlobalMapping` to provide + common queries like retrieving indices for all joint positions (qpos), + delta qpos (relative mode), end-effector transforms/poses, and + hand-specific selections (left/right/both). + + Args: + dof (int, optional): If provided, a :class:`GlobalMapping` is + constructed and reused for queries. + """ + + def __init__(self, dof: int = None): + self.global_mapping = None + self.dof = dof + if dof is not None: + self.global_mapping = GlobalMapping(dof) + + def get_all_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices covering all joint position entries. + + Args: + dof (int, optional): Degrees of freedom to construct a temporary + :class:`GlobalMapping` if the generator was not initialized + with a ``dof``. + handness (str): One of values from :class:`ArmEnum` specifying + which arm(s) to include. + + Returns: + List[int]: Ordered list of indices in the unified vector + corresponding to qpos entries for the requested arm + selection. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [qpos_name], [delta_qpos_name]) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + inv_handness = ControlParts.RIGHT_ARM.value + return self.get( + all_names, dof, [qpos_name], [delta_qpos_name, inv_handness + qpos_name] + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + inv_handness = ControlParts.LEFT_ARM.value + return self.get( + all_names, dof, [qpos_name], [delta_qpos_name, inv_handness + qpos_name] + ) + + def get_all_delta_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices for delta (relative) joint position entries. + + Args and return are the same as :meth:`get_all_qpos` but select the + ``ActionMode.RELATIVE`` named entries. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [delta_qpos_name], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + inv_handness = ControlParts.RIGHT_ARM.value + return self.get( + all_names, dof, [delta_qpos_name], [inv_handness + delta_qpos_name] + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + inv_handness = ControlParts.LEFT_ARM.value + return self.get( + all_names, dof, [delta_qpos_name], [inv_handness + delta_qpos_name] + ) + + def get_all_eef( + self, + dof: int = None, + eef_effector: str = "", + handness: str = ArmEnum.DUAL_ARM.value, + ) -> List[int]: + """Retrieves the indices of all end-effectors (EEF) based on the specified parameters. + + Args: + dof (int, optional): Degree of freedom to use for mapping. If None, uses default. + eef_effector (str, optional): Type of end-effector. Must be one of + EndEffector.DEXTROUSHAND.value, EndEffector.GRIPPER.value, or "" (empty string). + handness (str, optional): Specifies which arm(s) to consider. Must be one of + ArmEnum.DUAL_ARM.value, ArmEnum.LEFT_ARM_ONLY.value, or ArmEnum.RIGHT_ARM_ONLY.value. + + Returns: + List[int]: List of indices corresponding to the selected end-effectors. + + Raises: + AssertionError: If an invalid end-effector type is provided. + """ + assert eef_effector in [ + EndEffector.DEXTROUSHAND.value, + EndEffector.GRIPPER.value, + "", + ], "Invalid end-effector effector type {}.".format(eef_effector) + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get( + all_names, + dof, + [ + ControlParts.LEFT_EEF.value + eef_effector, + ControlParts.RIGHT_EEF.value + eef_effector, + ], + [], + ) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_EEF.value + return self.get( + all_names, + dof, + [handness + eef_effector], + [], + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_EEF.value + return self.get( + all_names, + dof, + [handness + eef_effector], + [], + ) + + def get_all_eef_pose( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices specifically for EEF pose entries. + + Args: + dof (int, optional): Degrees of freedom for mapping lookup. + handness (str): Which arm(s) to include (left/right/both). + + Returns: + List[int]: Indices corresponding to EEF poses. + """ + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [EefType.POSE.value], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + + def get_mapping(self, dof: int = None): + """Return the :class:`GlobalMapping` used by this generator. + + If a mapping was created during initialization (because ``dof`` was + provided), ensure any provided ``dof`` argument matches it. Otherwise + construct and return a temporary :class:`GlobalMapping` for the + requested ``dof``. + + Args: + dof (int, optional): Degrees of freedom to construct a mapping + if one was not provided at initialization. + + Returns: + GlobalMapping: Mapping instance for name->index lookups. + """ + if self.global_mapping is not None: + assert dof is None or dof == self.dof + global_mapping = self.global_mapping + else: + assert ( + dof is not None + ), "Dof must be set when dof is not provided in initialization." + global_mapping = GlobalMapping(dof) + return global_mapping + + def get( + self, + output: List[str], + dof: int = None, + white_list: List[str] = None, + black_list: List[str] = None, + ) -> List[int]: + """Select and return indices from ``output`` names applying optional + white/black list filters. + + Args: + output (List[str]): Names (keys) in a :class:`GlobalMapping` + whose indices should be collected. + dof (int, optional): Degrees of freedom used to construct a + temporary :class:`GlobalMapping` if needed. + white_list (List[str], optional): If provided, only include names + that contain any of these substrings. + black_list (List[str], optional): If provided, exclude names + that contain any of these substrings. + + Returns: + List[int]: Ordered list of unified-vector indices for the + selected names. + """ + + action_indices = [] + global_mapping = self.get_mapping(dof) + + for action_type in output: + if isinstance(white_list, list) and isinstance(black_list, list): + if any([temp in action_type for temp in white_list]) and all( + [temp not in action_type for temp in black_list] + ): + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + else: + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + + return action_indices # keep order. diff --git a/embodichain/data/data_engine/online/engine.py b/embodichain/data/data_engine/online/engine.py new file mode 100644 index 00000000..fbdba4be --- /dev/null +++ b/embodichain/data/data_engine/online/engine.py @@ -0,0 +1,513 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import time +import sys +import numpy as np +from threading import Thread +from typing import Dict, Tuple, Any, List, Callable, Optional +from copy import deepcopy +import threading +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +import torch +import torch.multiprocessing as mp +import copy +from embodichain.utils.logger import ( + log_info, + log_warning, + decorate_str_color, + log_debug, +) + +# Must call cuda init to prevent cuda error in subprocess. +torch._C._cuda_init() + +from dexsim.utility import NumpyRNG + +import threading +from multiprocessing import shared_memory +import pickle +from datetime import datetime +import zmq + +rng = NumpyRNG.get_rng() + +log_info_produce = lambda x: log_info(decorate_str_color(x, "cyan")) +log_info_consume = lambda x: log_info(decorate_str_color(x, "orange")) + +MAX_LOOP_TIMES = 40000 + + +def init_context(port): + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://localhost:{}".format(port)) + return socket + + +class DataPoolCont: + data: Any + count: int = 0 + tag: str + + @staticmethod + def from_list(data_pool: List[Dict]) -> List["DataPoolCont"]: + ret = [] + for data in data_pool: + dcnt = DataPoolCont() + dcnt.data = data + dcnt.count = 0 + dcnt.tag = str(datetime.now()).split(".")[0] + ret.append(dcnt) + return ret + + @staticmethod + def clean_data_pool_in_place( + data_pool: List["DataPoolCont"], clean_indices: List[int] + ): + if clean_indices is None: + data_pool = [] + else: + if len(clean_indices) > 0: + log_debug( + "Clean data pool with data indices {}, counts {}.".format( + clean_indices, + [data_pool[index].count for index in clean_indices], + ), + color="purple", + ) + for i in list(np.sort(clean_indices)[::-1]): + data_pool.pop(i) + + +def fetch_data( + queue_data: mp.Queue, data_pool: List[DataPoolCont], worker_info, debug: bool = True +) -> bool: + start_time = time.time() + try: + existing_shm = queue_data.get(timeout=5) + except Exception as error: + log_debug("Timeout! {}.".format(str(error)), color="red") + return False + log_debug( + "[Thread {}][Worker {}][Get] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + start_time = time.time() + scene_data = pickle.loads(existing_shm.buf[:]) + log_debug( + "[Thread {}][Worker {}][Pickle] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + if np.random.random() > 0.5 or queue_data.qsize() == 0: + start_time = time.time() + queue_data.put(existing_shm) # put back + log_debug( + "[Thread {}][Worker {}][Put] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + assert isinstance(scene_data, list), "Invalid data format {}.".format( + type(scene_data) + ) + start_time = time.time() + data = DataPoolCont.from_list(scene_data) + data_pool.extend(data) + + log_debug( + "[Thread {}][Worker {}][Other] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + return True + + +class RestockCriterion: + def __init__(self, data_pool_limit: int, buffer_size: int, max_sample_num: int): + self.data_pool_limit = data_pool_limit + self.buffer_size = buffer_size + self.max_sample_num = max_sample_num + + def restock_condition(self, data_pool: List, queue: mp.Queue) -> bool: + return len(data_pool) < self.data_pool_limit + + def expired_condition( + self, data_pool: List[DataPoolCont], inverse: bool = False + ) -> List[bool]: + + if len(data_pool) == 0: + return [] + + if inverse: + return [data.count <= self.max_sample_num for data in data_pool] + else: + return [data.count > self.max_sample_num for data in data_pool] + + +class OnlineEngine: + """Data manager for online data production and training. + + The objectives of this class are: + - Manage the fetch data in a separate thread. + - Perform data synchronization between the data production process and + the training process (main process). + - Provide data sampling interface for the training process, which is designed + to return a batch of synthetic data with the different scene id. + - Data lifecycle management. + + To achieve the above objectives, the following functions should be implemented: + - from_shm_thread (static method) + + Args: + insight_config (List[CfgNode]): The config of insight pipeline. + episode_limit (int, optional): The maximum number of frames in the data pool. Defaults to 24. + max_sample_num (int, optional): The maximum number of times that a data can be sampled. + Defaults to 2. + target_device (torch.device, optional): The target device of the data. Defaults to torch.device('cpu'). + annos_param (Dict[str, Any], optional): The parameters of the annotations. Defaults to None. + data_gen_func (Callable, optional): The data generation function. Defaults to None. + unique_scene_frame (int, optional): The number of unique scene frame to be sampled. Defaults to None. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + + def __init__( + self, + episode_limit: int = 24, + max_sample_num: int = 2, + port: int = 5555, + buffer_size: int = 10, + multiprocess: bool = False, + **kwargs, + ) -> None: + + self.episode_limit = episode_limit + self._max_sample_num = max_sample_num + self.port = port + + self._data_pool = [] + + self._duration = 0.01 + + self._context = mp.get_context("forkserver") + + self._queue_data = self._context.Queue() + self._queue_data.cancel_join_thread() + + self.buffer_size = buffer_size + + self._data_gen_proc = None + self._fetch_data_thread = None + self._restock_data_pool = None + + self._is_started = False + self._is_restocked = False + self._socket = init_context(port + 1 if multiprocess else port) + + self._restock_criterion = RestockCriterion( + data_pool_limit=episode_limit, + buffer_size=buffer_size, + max_sample_num=max_sample_num, + ) + self._lock = threading.RLock() + + def start( + self, + ) -> None: + """Start the data production process and the data synchronization thread. + + Args: + wait_for_limit (bool, optional): Whether to wait for the data pool to reach + the frame limit. Defaults to False. + """ + + self._signal_gen = self._context.Value("b", True) + self._signal_fetch = self._context.Value("b", True) + + self._fetch_data_thread = Thread( + target=self.from_shm_thread, + args=( + self._socket, + self._queue_data, + self._duration, + self.buffer_size, + ), + daemon=True, + ) + self._fetch_data_thread.start() + self._is_started = True + log_info( + "Now start the thread to fetch data from share memory.", color="purple" + ) + + def start_restock(self, static: bool = False): + if static: + self._restock_data_pool = Thread( + target=self.restock_data_pool_static, + args=( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ), + daemon=True, + ) + else: + self._restock_data_pool = Thread( + target=self.restock_data_pool, + daemon=True, + ) + + self._restock_data_pool.start() + self._is_restocked = True + + def stop(self) -> None: + if self.is_started: + self._is_started = False + self._signal_fetch.value = 2 + self._fetch_data_thread.join() + self.empty_queue(self._queue_data, self._context) + self.clean_data_pool_in_place() + self._signal_gen.value = 2 + else: + log_info( + "The data generation process has not been started.", color="purple" + ) + + @property + def is_started(self) -> bool: + return self._is_started + + @property + def data_size(self) -> int: + with self._lock: + return len(self._data_pool) + + @property + def queue_size(self) -> int: + return self._queue.qsize() + + @property + def unique_scene_frame(self) -> int: + return self._unique_scene_frame + + @staticmethod + def empty_queue(queue: mp.Queue, context: mp) -> None: + while queue.qsize() > 0: + try: + queue.get() + except Exception as e: + log_info("queue put invaild data format") + queue.close() + queue.join_thread() + queue = context.Queue() + break + return queue + + @staticmethod + def empty_share_memory(queue: mp.Queue) -> None: + while queue.qsize() > 0: + shm_name = queue.get() + shm = shared_memory.SharedMemory(shm_name) + shm.close() + shm.unlink() + + def restock_data_pool(self): + return OnlineEngine.restock_data_pool_static( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ) + + @staticmethod + def restock_data_pool_static( + data_pool: List[DataPoolCont], + queue_data: mp.Queue, + duration: float, + restock_criterion: RestockCriterion, + context, + thread_lock, + ): + counts = 0 + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + while True: + time.sleep(duration) + # always clean the data pool first. + + start_time = time.time() + with thread_lock: + # delete + clean_indices = list( + np.argwhere(restock_criterion.expired_condition(data_pool)).reshape( + -1 + ) + ) + DataPoolCont.clean_data_pool_in_place( + data_pool, + clean_indices, + ) + if len(clean_indices) > 0: + log_debug( + "[Thread {}][Delete][Cost {}s]".format( + threading.current_thread().ident, time.time() - start_time + ) + ) + + # after clean, we check whether to restock data. + while restock_criterion.restock_condition(data_pool, queue_data): + + prev_data_size = len(data_pool) + should_fetch = False + for i in range(worker_info.num_workers): + if queue_data.qsize() > 0 and worker_info.id == i: + should_fetch = True + if should_fetch: + start_time = time.time() + with thread_lock: + # add + fetch_data( + data_pool=data_pool, + queue_data=queue_data, + worker_info=worker_info, + ) + log_debug( + "[Thread {}][Worker {}][ToDataPool] Produce data: {}->{}. Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + prev_data_size, + len(data_pool), + time.time() - start_time, + ) + ) + counts = 0 + else: + counts += 1 + + if counts % MAX_LOOP_TIMES == 0 and counts != 0: + log_info("Can not find the shm after {} times.".format(counts)) + # queue_data = OnlineEngine.empty_queue(queue_data, context) + + @staticmethod + def from_shm_thread( + socket, + queue_data: mp.Queue, + duration: float = 0.001, + buffer_size: int = 10, + ) -> None: + """The data fetching thread for data synchronization. + + The queue_data_size is used to control the data fetching thread. + If queue_data_size < buffer_size, the data fetching thread will fetch data from the queue. + If queue_data_size >= buffer_size, the data fetching thread will stop fetch data. + + Args: + socket (zmq.Context): The socket send signal for connect fetch and generator. + queue_data (mp.Queue): This queue contains information about shared memory. + duration (float, optional): _description_. Defaults to 0.001. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + counts = 0 + while True: + time.sleep(duration) + counts += 1 + if queue_data.qsize() < buffer_size: + socket.send_string(ConsumerTeleEnum.SHAKEHAND.value) + message = socket.recv() + try: + message_str = message.decode() + except Exception as e: + log_debug(str(e), color="red") + message_str = "" + if message_str != ProducerTeleEnum.NOREADY.value: + log_debug("Receive data.", color="purple") + shm_name = pickle.loads(message).popleft() + existing_shm = shared_memory.SharedMemory(name=shm_name) + queue_data.put(existing_shm) + log_debug( + "[FromShmThread] Produce queue: {}->{};".format( + queue_data.qsize() - 1, queue_data.qsize() + ) + ) + else: + if counts % MAX_LOOP_TIMES == 0: + log_debug("Queue is full. Skip this stage.", "purple") + + def sample_data( + self, + ): + + if self._is_restocked: + pass + else: + log_debug("Now start the thread to restock data.", color="purple") + self.start_restock(static=False) + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + counts = 0 + while True: + time.sleep(self._duration) + if len(self._data_pool) > 0: + start_time = time.time() + with self._lock: + index = rng.integers(0, len(self._data_pool)) + data = deepcopy(self._data_pool[index]) + self._data_pool[index].count += 1 + log_debug( + "[SampleData, worker {}] Consume data {}: index {}; times: {}->{}; Show queue size: {}; Cost time: {}s.".format( + worker_info.id, + data.tag, + index, + data.count, + data.count + 1, + self._queue_data.qsize(), + np.round(time.time() - start_time, 4), + ) + ) + counts = 0 + return data.data + else: + counts += 1 + if counts % MAX_LOOP_TIMES == 0: + log_info("Data pool is always empty after {} times.".format(counts)) diff --git a/embodichain/data/data_engine/online/enum.py b/embodichain/data/data_engine/online/enum.py new file mode 100644 index 00000000..fe521bdb --- /dev/null +++ b/embodichain/data/data_engine/online/enum.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from enum import Enum + + +class ConsumerTeleEnum(Enum): + SHAKEHAND = "Data is ready?" + CONSUME = "Fetch data!" + NOCONSUME = "Data_pool is full." + GOTDATA = "Feched data!" + NOGOTDATA = "Not fetching data." + + +class ProducerTeleEnum(Enum): + READY = "Yes" + NOREADY = "No ready" + FULL = "Data_pool is full" + FAIL = "Failed" + SEND = "Send!" + EMPTYSTR = "Empty String." diff --git a/embodichain/data/data_engine/online/online_generator.py b/embodichain/data/data_engine/online/online_generator.py new file mode 100644 index 00000000..b1364f4f --- /dev/null +++ b/embodichain/data/data_engine/online/online_generator.py @@ -0,0 +1,181 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import time +import zmq +import random +from multiprocessing import shared_memory +import pickle +from collections import deque +from typing import List +from threading import Thread +import multiprocessing as mp +import traceback +from embodichain.utils.logger import log_info, log_warning, log_error, log_debug +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +torch._C._cuda_init() + + +class OnlineGenerator: + """Callback collection for online training mode.""" + + def __init__( + self, port: int, max_limit_gb: int = 50, multiprocess: bool = False, **kwargs + ) -> None: + self.shm_val = None + max_limit = max_limit_gb * 1024**3 + self._context = mp.get_context("forkserver") + self.port = port + self.socket = self.init_context(self.port, multiprocess) + self._duration = 0.01 + self.queue = deque() + self.queue_memroy = deque() + self.max_limit = max_limit + + self.validation_config = kwargs.get("validation", {}) + + def get_validation_config(self): + return self.validation_config + + def init_context(self, port, multiprocess: bool = False): + context = zmq.Context() + socket = context.socket(zmq.REP) + if multiprocess: + socket.connect(f"tcp://127.0.0.1:{port}") + else: + socket.bind(f"tcp://*:{port}") + + return socket + + def generator(self, generate_func, loop_times: int = -1, **kwargs): + self.signal = self._context.Value("b", True) + + self._zmq_send = Thread( + target=self.zmq_send, args=(self.queue, self.signal), daemon=True + ) + self._zmq_send.start() + log_debug("Start zmq sending.") + scene_id = 0 + + # -1 means infinite loop + while scene_id < loop_times or loop_times == -1: + if self.signal.value == 1: + first_time = True + try: + t0 = time.time() + return_list = generate_func( + time_id=scene_id, **self.validation_config + ) + + # TODO: support multiple trajectories for each scene. + if len(return_list) > 1: + log_error( + "Only support one trajectory for each scene in online generation mode." + ) + + data_dict_list = [return_list[0]["data"]] + + if ( + scene_id == 0 + and self.validation_config.get("num_samples", 0) > 0 + and "data_path" in return_list[0] + ): + # create shared memory to store the validation dataset path, which will be accessed by training process. + import sys + + data_path = return_list[0]["data_path"] + + shared_name = self.validation_config.get( + "dataset_name", "val_data_path" + ) + log_info( + f"Create shared memory for validation data path: {shared_name}", + color="green", + ) + self.shm_val = shared_memory.SharedMemory( + name=shared_name, + create=True, + size=len(data_path.encode()) + sys.getsizeof(""), + ) + self.shm_val.buf[: len(data_path.encode())] = data_path.encode() + log_info( + f"Craete shared memory for validation data path: {data_path}" + ) + + log_info( + f"Generate scene {scene_id + 1} time cost: {time.time() - t0}" + ) + serialized_data = pickle.dumps(data_dict_list) + shm = shared_memory.SharedMemory( + create=True, size=len(serialized_data) + ) + self.queue.append(shm.name) + self.queue_memroy.append( + {"name": shm.name, "size": len(serialized_data)} + ) + shm.buf[: len(serialized_data)] = serialized_data + except Exception as e: + log_error(f"Error in data generation process: {e}.") + traceback.print_exc() + self._zmq_send.join() + break + scene_id += 1 + self.empty_memory() + elif self.signal.value == 0: + if first_time: + log_warning("zmq recive full signal, wait generator signal") + first_time = False + log_warning("Signal value is 0.") + time.sleep(self._duration) + continue + else: + log_error("Unknown signal, data generator stop") + break + + def zmq_send(self, queue, signal): + while True: + try: + message = self.socket.recv_string() + if message == ConsumerTeleEnum.SHAKEHAND.value: + if len(queue) > 0: + log_warning( + "Recieve {} and send [data] to consumer.".format(message) + ) + self.socket.send(pickle.dumps(queue)) + queue.clear() + else: + self.socket.send(ProducerTeleEnum.NOREADY.value.encode()) + signal.value = 1 + except Exception as e: + print(e) + traceback.print_exc() + break + + def empty_memory(self): + total_size = sum([x["size"] for x in self.queue_memroy]) + log_info(f"share memory size is {total_size/(1024**3)} GB") + while total_size >= self.max_limit: + shm_name = self.queue_memroy.popleft() + if shm_name["name"] in self.queue: + log_info(f"remove {shm_name['name']} from queue") + self.queue.remove(shm_name["name"]) + try: + shm = shared_memory.SharedMemory(shm_name["name"]) + except: + continue + shm.close() + shm.unlink() + total_size = sum([x["size"] for x in self.queue_memroy]) + + def __del__(self): + if self.shm_val: + self.shm_val.close() + self.shm_val.unlink() diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py index cf758192..b629a2f1 100644 --- a/embodichain/data/enum.py +++ b/embodichain/data/enum.py @@ -34,6 +34,18 @@ class SemanticMask(IntEnum): ROBOT = 2 +class Modality(Enum): + STATES = "states" + STATE_INDICATOR = "state_indicator" + ACTIONS = "actions" + ACTION_INDICATOR = "action_indicator" + IMAGES = "images" + LANG = "lang" + LANG_INDICATOR = "lang_indicator" + GEOMAP = "geomap" # e.g., depth, point cloud, etc. + VISION_LANGUAGE = "vision_language" # e.g., image + lang + + class EndEffector(Enum): GRIPPER = "gripper" DEXTROUSHAND = "hand" @@ -53,6 +65,16 @@ class ControlParts(Enum): WAIST = "waist" +class ControlPartsMappingW1(Enum): + ANKLE_IN_TORSO = 0 + KNEE_IN_TORSO = 1 + BUTTOCK_IN_TORSO = 2 + WAIST_IN_TORSO = 3 + + NECK1_IN_HEAD = 0 + NECK2_IN_HEAD = 1 + + class Hints(Enum): EEF = ( ControlParts.LEFT_EEF.value, @@ -74,3 +96,118 @@ class EefType(Enum): class ActionMode(Enum): ABSOLUTE = "" RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. + + +class PrivilegeType(Enum): + EXTEROCEPTION = "exteroception" + MASK = "mask" + STATE = "state" + PROGRESS = "progress" + + +SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + EefType.POSE.value, + ControlParts.RIGHT_ARM.value + EefType.POSE.value, + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.HEAD.value + JointType.QPOS.value, + ControlParts.WAIST.value + JointType.QPOS.value, +] +SUPPORTED_ACTION_TYPES = SUPPORTED_PROPRIO_TYPES + [ + ControlParts.LEFT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, +] +SUPPORTED_EXTRA_VISION_TYPES = [ + Modality.GEOMAP.value, + PrivilegeType.EXTEROCEPTION.value, + PrivilegeType.MASK.value, +] + + +class ArmEnum(IntEnum): + LEFT_ARM_ONLY = 1 + RIGHT_ARM_ONLY = 2 + DUAL_ARM = 3 + + +class ArmName(Enum): + LEFT_ARM_ONLY = "left_arm" + RIGHT_ARM_ONLY = "right_arm" + + +def is_dual_arms(dofs: int) -> bool: + return dofs > 10 + + +class HandQposNormalizer: + """ + A class for normalizing and denormalizing dexterous hand qpos data. + """ + + def __init__(self): + pass + + @staticmethod + def normalize_hand_qpos( + qpos_data: np.ndarray, + key: str, + agent=None, + robot=None, + ) -> np.ndarray: + """ + Clip and normalize dexterous hand qpos data. + + Args: + qpos_data: Raw qpos data + key: Control part key + agent: LearnableRobot instance (for V2 API) + robot: Robot instance (for V3 API) + + Returns: + Normalized qpos data in range [0, 1] + """ + if isinstance(qpos_data, torch.Tensor): + qpos_data = qpos_data.cpu().numpy() + + if agent is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + return qpos_data + indices = agent.get_data_index(key, warning=False) + full_limits = agent.get_joint_limits(agent.uid) + limits = full_limits[indices] # shape: [num_joints, 2] + elif robot is not None: + if key not in [ + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ]: + if key in [ControlParts.LEFT_EEF.value, ControlParts.RIGHT_EEF.value]: + # Note: In V3, robot does not distinguish between GRIPPER EEF and HAND EEF in uid, + # _data_key_to_control_part maps both to EEF. Under current conditions, normalization + # will not be performed. Please confirm if this is intended. + pass + return qpos_data + indices = robot.get_joint_ids(key, remove_mimic=True) + limits = robot.body_data.qpos_limits[0][indices] # shape: [num_joints, 2] + else: + raise ValueError("Either agent or robot must be provided") + + if isinstance(limits, torch.Tensor): + limits = limits.cpu().numpy() + + qpos_min = limits[:, 0] # Lower limits + qpos_max = limits[:, 1] # Upper limits + + # Step 1: Clip to valid range + qpos_clipped = np.clip(qpos_data, qpos_min, qpos_max) + + # Step 2: Normalize to [0, 1] + qpos_normalized = (qpos_clipped - qpos_min) / (qpos_max - qpos_min + 1e-8) + + return qpos_normalized diff --git a/embodichain/data/global_indices.py b/embodichain/data/global_indices.py new file mode 100644 index 00000000..b7cb8dcf --- /dev/null +++ b/embodichain/data/global_indices.py @@ -0,0 +1,122 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np + +GLOBAL_INDICES = { + # [0, 10): right arm joint positions + **{"arm_joint_{}_pos".format(i): i for i in range(10)}, + **{"right_arm_joint_{}_pos".format(i): i for i in range(10)}, + # [10, 15): right gripper joint positions + **{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + **{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + "gripper_open": 10, # alias of right_gripper_joint_0_pos + "right_gripper_open": 10, + # [15, 25): right arm joint velocities + **{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + **{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + # [25, 30): right gripper joint velocities + **{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + **{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + "gripper_open_vel": 25, # alias of right_gripper_joint_0_vel + "right_gripper_open_vel": 25, + # [30, 33): right end effector positions + "eef_pos_x": 30, + "right_eef_pos_x": 30, + "eef_pos_y": 31, + "right_eef_pos_y": 31, + "eef_pos_z": 32, + "right_eef_pos_z": 32, + # [33, 39): right end effector 6D pose + "eef_angle_0": 33, + "right_eef_angle_0": 33, + "eef_angle_1": 34, + "right_eef_angle_1": 34, + "eef_angle_2": 35, + "right_eef_angle_2": 35, + "eef_angle_3": 36, + "right_eef_angle_3": 36, + "eef_angle_4": 37, + "right_eef_angle_4": 37, + "eef_angle_5": 38, + "right_eef_angle_5": 38, + # [39, 42): right end effector velocities + "eef_vel_x": 39, + "right_eef_vel_x": 39, + "eef_vel_y": 40, + "right_eef_vel_y": 40, + "eef_vel_z": 41, + "right_eef_vel_z": 41, + # [42, 45): right end effector angular velocities + "eef_angular_vel_roll": 42, + "right_eef_angular_vel_roll": 42, + "eef_angular_vel_pitch": 43, + "right_eef_angular_vel_pitch": 43, + "eef_angular_vel_yaw": 44, + "right_eef_angular_vel_yaw": 44, + # [45, 50): reserved + # [50, 60): left arm joint positions + **{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)}, + # [60, 65): left gripper joint positions + **{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)}, + "left_gripper_open": 60, # alias of left_gripper_joint_0_pos + # [65, 75): left arm joint velocities + **{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)}, + # [75, 80): left gripper joint velocities + **{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)}, + "left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel + # [80, 83): left end effector positions + "left_eef_pos_x": 80, + "left_eef_pos_y": 81, + "left_eef_pos_z": 82, + # [83, 89): left end effector 6D pose + "left_eef_angle_0": 83, + "left_eef_angle_1": 84, + "left_eef_angle_2": 85, + "left_eef_angle_3": 86, + "left_eef_angle_4": 87, + "left_eef_angle_5": 88, + # [89, 92): left end effector velocities + "left_eef_vel_x": 89, + "left_eef_vel_y": 90, + "left_eef_vel_z": 91, + # [92, 95): left end effector angular velocities + "left_eef_angular_vel_roll": 92, + "left_eef_angular_vel_pitch": 93, + "left_eef_angular_vel_yaw": 94, + # [95, 100): reserved + # [100, 102): base linear velocities + "base_vel_x": 100, + "base_vel_y": 101, + # [102, 103): base angular velocities + "base_angular_vel": 102, + # [103, 115): dextrous hand joint positions + **{"left_hand_joint_{}_pos".format(i): i + 103 for i in range(6)}, + **{"right_hand_joint_{}_pos".format(i): i + 109 for i in range(6)}, + # [115, 119): torso joint positions + **{"torso_joint_{}_pos".format(i): i + 115 for i in range(4)}, + # [119, 121): head joint positions + **{"head_joint_{}_pos".format(i): i + 119 for i in range(2)}, + "waist": 115, + # [121, 123): head joint velocities + **{"head_joint_{}_vel".format(i): i + 121 for i in range(2)}, + "waist_vel": 123, + # [124, 128): reserved +} + + +STATE_VEC_LEN = 128 + + +def get_all_left_related_indices(including_end: bool = True): + if including_end: + return np.arange(50, 128, step=1) + else: + return np.arange(50, 100) + + +def get_all_right_related_indices(): + return np.arange(0, 50) diff --git a/embodichain/data/global_mapping.py b/embodichain/data/global_mapping.py new file mode 100644 index 00000000..0be40a02 --- /dev/null +++ b/embodichain/data/global_mapping.py @@ -0,0 +1,151 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from embodichain.data.enum import ( + ControlParts, + ActionMode, + EndEffector, + JointType, + EefType, + is_dual_arms, +) +from embodichain.data.global_indices import GLOBAL_INDICES +import numpy as np +from typing import List + + +class GlobalMapping: + def __init__(self, dof: int): + self_attrs = GlobalMapping.__dict__ + num_arm = 2 if is_dual_arms(dofs=dof) else 1 + single_dof = dof // num_arm + function_dict = {} + for k, v in self_attrs.items(): + if isinstance(v, staticmethod) and "__" not in k: + function_dict.update(v.__func__(dof=single_dof, num_arm=num_arm)) + self.mapping_from_name_to_indices = function_dict + + @staticmethod + def get_qpos_indices(dof: int, num_arm, **kwrags): + + return { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.HEAD.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_pos".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + JointType.QPOS.value: [GLOBAL_INDICES["waist"]], + } + + @staticmethod + def get_gripper_open_state_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open"]], + ControlParts.RIGHT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open"]], + } + + @staticmethod + def get_hand_qpos_indices(num_arm: int, hand_dof: int = 6, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"left_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"right_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + } + + @staticmethod + def get_gripper_open_vel_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open_vel"]], + ControlParts.RIGHT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open_vel"]], + } + + @staticmethod + def get_delta_qpos_indices(dof: int, num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.HEAD.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_vel".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [GLOBAL_INDICES["waist_vel"]], + } + + @staticmethod + def get_eef_pose_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["left_eef_pos_x"], + GLOBAL_INDICES["left_eef_pos_y"], + GLOBAL_INDICES["left_eef_pos_z"], + GLOBAL_INDICES["left_eef_angle_0"], + GLOBAL_INDICES["left_eef_angle_1"], + GLOBAL_INDICES["left_eef_angle_2"], + GLOBAL_INDICES["left_eef_angle_3"], + GLOBAL_INDICES["left_eef_angle_4"], + GLOBAL_INDICES["left_eef_angle_5"], + ], + ControlParts.RIGHT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["right_eef_pos_x"], + GLOBAL_INDICES["right_eef_pos_y"], + GLOBAL_INDICES["right_eef_pos_z"], + GLOBAL_INDICES["right_eef_angle_0"], + GLOBAL_INDICES["right_eef_angle_1"], + GLOBAL_INDICES["right_eef_angle_2"], + GLOBAL_INDICES["right_eef_angle_3"], + GLOBAL_INDICES["right_eef_angle_4"], + GLOBAL_INDICES["right_eef_angle_5"], + ], + } + + def get_indices(self, state_meta: List[str]): + state_indices = [] + + for proprio_name in state_meta: + state_indices += self.mapping_from_name_to_indices[proprio_name] + + return state_indices + + def ret_all_state( + self, + ): + state_indices = [] + + for val in self.mapping_from_name_to_indices.values(): + state_indices += val + + return state_indices diff --git a/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt new file mode 100644 index 00000000..6b2205e2 --- /dev/null +++ b/embodichain/database/agent_prompt/DualPourWaterAgent-v3/task_prompt.txt @@ -0,0 +1,5 @@ +Task: +Use both arms to pour water from the bottle into the cup. +First, grasp the bottle with right arm and the cup with left arm simultaneously. Then lift the cup by 0.10 m, and then move it to [0.55, 0.05] to prepare for pouring. +Then hold the bottle at a relative offset of [0.05, −0.10, 0.125] with respect to the cup for starting pouring. +After pouring, place the bottle to [0.7, −0.1] and place the cup at [0.6, 0.1] simultaneously. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt new file mode 100644 index 00000000..be7eb3b1 --- /dev/null +++ b/embodichain/database/agent_prompt/PourWaterAgent-v3/task_prompt.txt @@ -0,0 +1,5 @@ +Task: +Use a single robotic arm to pour water from the bottle into the cup. +Position the bottle at an offset of [0.05, −0.10, 0.125] relative to the cup’s position during pouring. +After completing the pour, return the bottle to the location [0.7, −0.1]. +The cup should remain stationary throughout the task. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py new file mode 100644 index 00000000..36de8c16 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_code.py @@ -0,0 +1,29 @@ +# Step 1: Grasp the fork with the left arm and the spoon with the right arm +drive( + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), + right_arm_action=grasp( + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), +) + +# Step 2: Reorient both end-effectors to a downward-facing pose +drive( + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), +) + +# Step 3: Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center +drive( + left_arm_action=place_on_table( + robot_name="left_arm", obj_name="fork", x=0.0, y=0.16, pre_place_dis=0.08 + ), + right_arm_action=place_on_table( + robot_name="right_arm", obj_name="spoon", x=0.0, y=-0.16, pre_place_dis=0.08 + ), +) + +# Step 4: Return both arms to their initial poses +drive( + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt new file mode 100644 index 00000000..d5ef2e70 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/agent_generated_plan.txt @@ -0,0 +1,19 @@ +**[PLANS]:** + +Step 1: Grasp the fork with the left arm and the spoon with the right arm — `drive(left_arm_action=grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10), right_arm_action=grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10))` + +Step 2: Reorient both end-effectors to a downward-facing pose — `drive(left_arm_action=orient_eef(robot_name='left_arm', direction='down'), right_arm_action=orient_eef(robot_name='right_arm', direction='down'))` + +Step 3: Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center — `drive(left_arm_action=place_on_table(robot_name='left_arm', obj_name='fork', x=0.0, y=0.16, pre_place_dis=0.08), right_arm_action=place_on_table(robot_name='right_arm', obj_name='spoon', x=0.0, y=-0.16, pre_place_dis=0.08))` + +Step 4: Return both arms to their initial poses — `drive(left_arm_action=back_to_initial_pose(robot_name='left_arm'), right_arm_action=back_to_initial_pose(robot_name='right_arm'))` + +**[VALIDATION_CONDITIONS]:** + +Step 1: The left arm should be holding the fork, and the right arm should be holding the spoon. + +Step 2: Both end-effectors should be facing downward, with the fork and spoon still held. + +Step 3: The fork should be placed at y = +0.16, and the spoon should be placed at y = −0.16 relative to the plate’s center. Both arms should have released their objects. + +Step 4: Both arms should be in their initial poses, with no objects held. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py new file mode 100644 index 00000000..c94c82cd --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_code.py @@ -0,0 +1,34 @@ +# Step 1 — Grasp the Fork and Spoon Simultaneously +drive( + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), + right_arm_action=grasp( + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), +) + +# Step 2 — Reorient End-Effectors to Downward-Facing Pose +drive( + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), +) + +# Step 3 — Place the Fork and Spoon on Opposite Sides of the Plate +drive( + left_arm_action=move_relative_to_object( + robot_name="left_arm", obj_name="plate", x_offset=0, y_offset=0.16, z_offset=0 + ), + right_arm_action=move_relative_to_object( + robot_name="right_arm", obj_name="plate", x_offset=0, y_offset=-0.16, z_offset=0 + ), +) + +drive( + left_arm_action=open_gripper(robot_name="left_arm"), + right_arm_action=open_gripper(robot_name="right_arm"), +) + +# Step 4 — Return Both Arms to Initial Pose +drive( + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt new file mode 100644 index 00000000..a53d6c6e --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_105727/agent_generated_plan.txt @@ -0,0 +1,38 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned on either side of the table, ready to interact with the objects. + +### Task Plan + +1. **Grasp the Fork and Spoon Simultaneously** + - **Left Arm (Fork):** + - Use `grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10)` to grasp the fork. + - **Right Arm (Spoon):** + - Use `grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10)` to grasp the spoon. + - Execute: `drive(left_arm_action=grasp_left, right_arm_action=grasp_right)` + +2. **Reorient End-Effectors to Downward-Facing Pose** + - **Left Arm:** + - Use `orient_eef(robot_name='left_arm', direction='down')` to reorient the left arm. + - **Right Arm:** + - Use `orient_eef(robot_name='right_arm', direction='down')` to reorient the right arm. + - Execute: `drive(left_arm_action=orient_left, right_arm_action=orient_right)` + +3. **Place the Fork and Spoon on Opposite Sides of the Plate** + - **Left Arm (Fork):** + - Use `move_relative_to_object(robot_name='left_arm', obj_name='plate', x_offset=0, y_offset=0.16, z_offset=0)` to position the fork. + - Use `open_gripper(robot_name='left_arm')` to release the fork. + - **Right Arm (Spoon):** + - Use `move_relative_to_object(robot_name='right_arm', obj_name='plate', x_offset=0, y_offset=-0.16, z_offset=0)` to position the spoon. + - Use `open_gripper(robot_name='right_arm')` to release the spoon. + - Execute: `drive(left_arm_action=move_left, right_arm_action=move_right)` + - Execute: `drive(left_arm_action=open_left, right_arm_action=open_right)` + +4. **Return Both Arms to Initial Pose** + - **Left Arm:** + - Use `back_to_initial_pose(robot_name='left_arm')` to return the left arm. + - **Right Arm:** + - Use `back_to_initial_pose(robot_name='right_arm')` to return the right arm. + - Execute: `drive(left_arm_action=back_left, right_arm_action=back_right)` + +This plan ensures the fork and spoon are rearranged on opposite sides of the plate, with the robot arms returning to their initial positions at the end. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py new file mode 100644 index 00000000..661b61b8 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_code.py @@ -0,0 +1,34 @@ +# Step 1 — Grasp the fork and spoon +drive( + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), + right_arm_action=grasp( + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), +) + +# Step 2 — Reorient end-effectors to downward-facing pose +drive( + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), +) + +# Step 3 — Place fork and spoon on opposite sides of the plate +drive( + left_arm_action=move_relative_to_object( + robot_name="left_arm", obj_name="plate", y_offset=0.16 + ), + right_arm_action=move_relative_to_object( + robot_name="right_arm", obj_name="plate", y_offset=-0.16 + ), +) + +drive( + left_arm_action=open_gripper(robot_name="left_arm"), + right_arm_action=open_gripper(robot_name="right_arm"), +) + +# Step 4 — Return arms to initial pose +drive( + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt new file mode 100644 index 00000000..185df534 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110033/agent_generated_plan.txt @@ -0,0 +1,49 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned with the left arm near the fork and the right arm near the spoon. + +### Task Plan + +1. **Grasp the Fork and Spoon:** + - Use the left arm to grasp the fork. + - Use the right arm to grasp the spoon. + + ```python + left_grasp_action = grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10) + right_grasp_action = grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10) + drive(left_arm_action=left_grasp_action, right_arm_action=right_grasp_action) + ``` + +2. **Reorient End-Effectors to Downward-Facing Pose:** + - Reorient both end-effectors to face downward. + + ```python + left_orient_action = orient_eef(robot_name='left_arm', direction='down') + right_orient_action = orient_eef(robot_name='right_arm', direction='down') + drive(left_arm_action=left_orient_action, right_arm_action=right_orient_action) + ``` + +3. **Place Fork and Spoon on Opposite Sides of the Plate:** + - Place the fork at y = +0.16 relative to the plate’s center. + - Place the spoon at y = −0.16 relative to the plate’s center. + + ```python + left_place_action = move_relative_to_object(robot_name='left_arm', obj_name='plate', y_offset=0.16) + right_place_action = move_relative_to_object(robot_name='right_arm', obj_name='plate', y_offset=-0.16) + drive(left_arm_action=left_place_action, right_arm_action=right_place_action) + + left_release_action = open_gripper(robot_name='left_arm') + right_release_action = open_gripper(robot_name='right_arm') + drive(left_arm_action=left_release_action, right_arm_action=right_release_action) + ``` + +4. **Return Arms to Initial Pose:** + - Return both arms to their initial configurations. + + ```python + left_initial_action = back_to_initial_pose(robot_name='left_arm') + right_initial_action = back_to_initial_pose(robot_name='right_arm') + drive(left_arm_action=left_initial_action, right_arm_action=right_initial_action) + ``` + +This plan ensures the fork and spoon are rearranged on opposite sides of the plate, with the fork on the left and the spoon on the right, relative to the robot base frame. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py new file mode 100644 index 00000000..26648600 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_code.py @@ -0,0 +1,29 @@ +# Step 1 — Grasp the Fork and Spoon Simultaneously +drive( + left_arm_action=grasp(robot_name="left_arm", obj_name="fork", pre_grasp_dis=0.10), + right_arm_action=grasp( + robot_name="right_arm", obj_name="spoon", pre_grasp_dis=0.10 + ), +) + +# Step 2 — Reorient Both End-Effectors to a Downward-Facing Pose +drive( + left_arm_action=orient_eef(robot_name="left_arm", direction="down"), + right_arm_action=orient_eef(robot_name="right_arm", direction="down"), +) + +# Step 3 — Place the Fork and Spoon at Specified Positions +drive( + left_arm_action=place_on_table( + robot_name="left_arm", obj_name="fork", x=None, y=0.16, pre_place_dis=0.08 + ), + right_arm_action=place_on_table( + robot_name="right_arm", obj_name="spoon", x=None, y=-0.16, pre_place_dis=0.08 + ), +) + +# Step 4 — Return Both Arms to Initial Pose +drive( + left_arm_action=back_to_initial_pose(robot_name="left_arm"), + right_arm_action=back_to_initial_pose(robot_name="right_arm"), +) diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt new file mode 100644 index 00000000..7d5b20ea --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/feedback_logs/20260120_110212/agent_generated_plan.txt @@ -0,0 +1,47 @@ +### Scene Analysis + +The image shows a table with a plate in the center, a fork on the left side, and a spoon on the right side. The robot arms are positioned above the table, ready to interact with the objects. + +### Task Plan + +1. **Grasp the Fork and Spoon Simultaneously:** + - Use the left arm to grasp the fork. + - Use the right arm to grasp the spoon. + + ```python + left_grasp_action = grasp(robot_name='left_arm', obj_name='fork', pre_grasp_dis=0.10) + right_grasp_action = grasp(robot_name='right_arm', obj_name='spoon', pre_grasp_dis=0.10) + drive(left_arm_action=left_grasp_action, right_arm_action=right_grasp_action) + ``` + +2. **Reorient Both End-Effectors to a Downward-Facing Pose:** + - Reorient the left arm's end-effector to face downward. + - Reorient the right arm's end-effector to face downward. + + ```python + left_orient_action = orient_eef(robot_name='left_arm', direction='down') + right_orient_action = orient_eef(robot_name='right_arm', direction='down') + drive(left_arm_action=left_orient_action, right_arm_action=right_orient_action) + ``` + +3. **Place the Fork and Spoon at Specified Positions:** + - Place the fork at y = +0.16 relative to the plate’s center using the left arm. + - Place the spoon at y = −0.16 relative to the plate’s center using the right arm. + + ```python + left_place_action = place_on_table(robot_name='left_arm', obj_name='fork', x=None, y=0.16, pre_place_dis=0.08) + right_place_action = place_on_table(robot_name='right_arm', obj_name='spoon', x=None, y=-0.16, pre_place_dis=0.08) + drive(left_arm_action=left_place_action, right_arm_action=right_place_action) + ``` + +4. **Return Both Arms to Initial Pose:** + - Return the left arm to its initial pose. + - Return the right arm to its initial pose. + + ```python + left_initial_action = back_to_initial_pose(robot_name='left_arm') + right_initial_action = back_to_initial_pose(robot_name='right_arm') + drive(left_arm_action=left_initial_action, right_arm_action=right_initial_action) + ``` + +This plan ensures that both arms perform the required actions simultaneously, maintaining synchronization throughout the task. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt b/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt new file mode 100644 index 00000000..280bd227 --- /dev/null +++ b/embodichain/database/agent_prompt/RearrangementAgent-v3/task_prompt.txt @@ -0,0 +1,8 @@ +Task: +Use both arms to rearrange a fork and a spoon on opposite sides of a plate. Perform the following steps **simultaneously**. + +1. Grasp the fork with the left arm and the spoon with the right arm. + +2. Reorient both end-effectors to a downward-facing pose. + +3. Place the fork at y = +0.16 and the spoon at y = −0.16 relative to the plate’s center. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/atom_actions.txt b/embodichain/database/agent_prompt/atom_actions.txt new file mode 100644 index 00000000..257464c5 --- /dev/null +++ b/embodichain/database/agent_prompt/atom_actions.txt @@ -0,0 +1,136 @@ +### Atom Functions for Robot Arm Control +Each atomic function returns a list of joint-space trajectories (list[np.ndarray]). +All functions support an optional argument: + +force_valid: bool + If True, the system will automatically correct an invalid target pose by + projecting it to the nearest valid pose. Use this option carefully: + enable it only for actions where small spatial deviations are acceptable + and will not compromise task correctness. Default is False. + +Use the following functions exactly as defined. Do not invent new APIs or parameters. + +"grasp": + def grasp(robot_name: str, obj_name: str, pre_grasp_dis: float, **kwargs) -> list[np.ndarray] + + Moves the specified arm to the target object’s affordance-based grasp pose and executes a grasp by closing the gripper. + + The function plans a two-stage trajectory: + (1) from the current pose to a pre-grasp pose offset from the object, and + (2) from the pre-grasp pose to the final grasp pose, followed by gripper closure. + + Upon completion, the gripper is closed and the target object is expected to be stably held by the gripper. + + Example: + grasp(robot_name='right_arm', obj_name='bottle', pre_grasp_dis=0.10) # Moves the right arm to a pre-grasp pose 10 cm from the bottle, then to the grasp pose and closes the gripper to grasp the bottle. + +"place_on_table": + def place_on_table(robot_name: str, obj_name: str, x: float, y: float, pre_place_dis: float, **kwargs) -> list[np.ndarray] + + Moves the specified robot arm with the target object to the desired [x, y] location on the table and opens the gripper to place the object. + The z-coordinate is automatically adjusted based on the table height and the object’s dimensions. + This function assumes that the robot is already holding the object and that the task is to place it on the table at the specified coordinates. + Remember that when you need to place some objects on the table at specific coordinates, use this function without using other movement atom actions. + Otherwise, **if you need to place some objects relative to some place, then use "move_relative_to_object" first to move to the desired position, then use "open_gripper" to release the object.** + + Example: + place_on_table(robot_name='right_arm', obj_name='bottle', x=0.1, y=0.5, pre_place_dis=0.08) # Moves the right arm to a pre-place position 8 cm from the table, then places the bottle at the specified [0.1, 0.5] location on the table and opens the gripper. + +"move_relative_to_object": + def move_relative_to_object(robot_name: str, obj_name: str, + x_offset=0, y_offset=0, z_offset=0, + **kwargs) -> list[np.ndarray] + Moves the end-effector to a pose defined relative to the target object: + target = object_position + [x_offset, y_offset, z_offset] + Orientation is preserved. + Example: + move_relative_to_object(robot_name='right_arm', obj_name='cup', + x_offset=0.05, y_offset=0.10, z_offset=0.10) # Moves the right arm’s end-effector to a spot located 5 cm forward, 10 cm to the left, and 10 cm above the cup, while preserving the current gripper orientation. + move_relative_to_object(robot_name='right_arm', obj_name='cup', + x_offset=-0.05, y_offset=-0.10, z_offset=0.10) # Moves the right arm’s end-effector to a spot located 5 cm backward, 10 cm to the right, and 10 cm above the cup, while preserving the current gripper orientation. + +"move_to_absolute_position": + def move_to_absolute_position(robot_name: str, + x=None, y=None, z=None, + **kwargs) -> list[np.ndarray] + Moves the end-effector to an absolute (x, y, z) position in world coordinates. + Any coordinate set to None remains unchanged. + Orientation is preserved. + Example: + move_to_absolute_position(robot_name='right_arm', x=0.10, y=0.10, z=None) # Moves the end-effector to the absolute world position (x=0.10 m, y=0.10 m) while leaving z unchanged, and preserves the orientation. + +"move_by_relative_offset": + def move_by_relative_offset(robot_name: str, + dx=0.0, dy=0.0, dz=0.0, mode='extrinsic', + **kwargs) -> list[np.ndarray] + Moves the end-effector by a relative translation: + new_position = current_position + [dx, dy, dz] + The offset is applied along the specified axes using the given mode, while preserving the original end-effector orientation. + Mode can be 'extrinsic' (world frame) or 'intrinsic' (end-effector frame). If you want to move along the world axes, use 'extrinsic'. If you want to move along the end-effector’s local axes, use "intrinsic". + Example: + move_by_relative_offset(robot_name='right_arm', dx=0.05, dy=-0.10, dz=0.20, mode='extrinsic') # Translates the end-effector by +5 cm in x (front), −10 cm in y (right), +20 cm in z (above) in the world coordinate, relative to its current position, with orientation preserved. + move_by_relative_offset(robot_name='right_arm', dx=0, dy=0, dz=0.1, mode='intrinsic') # Translates the end-effector by +10 cm in z (forward) in the EEF coordinate, meaning that it moves forward relative to its current facing direction, with orientation preserved. + +"rotate_eef" + def rotate_eef(robot_name: str, degree: float, **kwargs) -> list[np.ndarray] + Rotates the wrist roll joint (joint index 5) of the specified arm by the + given number of degrees. End-effector position is preserved. + Example: + rotate_eef(robot_name='right_arm', degree=-90) # Rotates the right arm’s wrist-roll joint by −45° (counterclockwise), while keeping the end-effector position unchanged. This is a joint-level rotation, not a full orientation override. + rotate_eef(robot_name='right_arm', degree=90) # Rotates the right arm’s wrist-roll joint by 45° (clockwise), while keeping the end-effector position unchanged. This is a joint-level rotation, not a full orientation override. + Typical use cases: + Pouring or tilting a grasped object. + Rotating the gripper around its forward axis without translating the end effector. + After rotating, you typically need to apply an opposite rotation back to return to the original pose. + Usage notes: + Rotation sign convention: negative = counterclockwise, positive = clockwise, viewed along the end-effector forward axis. + For pouring with the right arm, a common pattern is: first apply a negative rotation to start pouring, then apply a positive rotation to return. + For the left arm, the sign convention is typically reversed. + +"orient_eef": + def orient_eef(robot_name: str, + direction: str = 'front', # 'front' or 'down' + **kwargs) -> list[np.ndarray] + Reorients the end-effector to a predefined canonical orientation in the + WORLD coordinate frame, while keeping the EE’s current position fixed. + This function replaces the entire 3×3 orientation matrix of the current + end-effector pose. + Usage notes: + This function should only be used when you explicitly need to override the end-effector’s full orientation. + This differs from rotate_eef(). orient_eef performs a full orientation override of the end-effector, not a single-joint rotation. For tasks like pouring, no need to use it. + For general wrist rotation, prefer using rotate_eef instead. + For aligning tasks, use "front" or "down" orientations as needed. + Supported orientations: + • 'front' : Align the end-effector so its direction faces forward. + • 'down' : Align the end-effector so its direction faces downward. + Example: + orient_eef(robot_name='right_arm', direction='front') # Reorients the right arm’s end-effector so it faces forward + +"back_to_initial_pose": + def back_to_initial_pose(robot_name: str, **kwargs) -> list[np.ndarray] + Returns the specified arm to its predefined initial joint configuration + stored in the environment. + Example: + back_to_initial_pose(robot_name='right_arm') # Returns the right arm back to its predefined initial joint configuration stored in the environment, regardless of its current pose. + +"close_gripper": + def close_gripper(robot_name: str, **kwargs) -> list[np.ndarray] + Closes the arm’s gripper using a short (10-step) gripper-only trajectory. + Example: + close_gripper(robot_name='right_arm') # Closes the right gripper using a short, smooth 10-step gripper-only trajectory. + +"open_gripper": + def open_gripper(robot_name: str, **kwargs) -> list[np.ndarray] + Opens the arm’s gripper using a short (10-step) gripper-only trajectory. + Example: + open_gripper(robot_name='right_arm') # Opens the right gripper using a 10-step gripper-only trajectory. + +### Drive Function (Trajectory Synchronization) +"drive": + def drive(left_arm_action=None, right_arm_action=None, **kwargs) -> list[torch.Tensor] + Wraps one or two arm trajectories into synchronized full-robot actions. + • If only one arm action is provided, the other arm stays idle. + • If both are provided, they are temporally aligned and executed together. + • The actions are obtained from the output of the above functions. + Example: + drive(left_arm_action=left_actions, right_arm_action=right_actions) \ No newline at end of file diff --git a/embodichain/database/agent_prompt/basic_background.txt b/embodichain/database/agent_prompt/basic_background.txt new file mode 100644 index 00000000..dc6d1c30 --- /dev/null +++ b/embodichain/database/agent_prompt/basic_background.txt @@ -0,0 +1,42 @@ +The environment uses a right-handed world coordinate system, where 1 unit equals 1 meter. +All robot poses are represented as 4×4 homogeneous transformation matrices. + +The robot base coordinate frame is the ONLY authoritative frame for all spatial reasoning, planning, and action generation. + +CAMERA AND IMAGE INTERPRETATION + +The camera is positioned in front of the robot, facing the robot arm and looking toward the robot base. +Because of this viewpoint, the rendered image is horizontally mirrored relative to the robot base frame. +This mirroring affects LEFT–RIGHT only. There is NO vertical or depth inversion. + +Mirror mapping (image → robot base frame): + +* Image left corresponds to robot right +* Image right corresponds to robot left +* Image up corresponds to robot up +* Image down corresponds to robot down + +REQUIRED REASONING PERSPECTIVE (NON-NEGOTIABLE) + +You must ignore the camera and rendered image orientation when reasoning. +All spatial reasoning must be performed as if you are physically located at the robot base, looking outward along the robot’s +x (forward) direction. + +Do NOT reason from the camera viewpoint. +Do NOT trust left/right as shown in the image. +Always remap image left/right before reasoning. + +ROBOT BASE COORDINATE DEFINITIONS + +All directions below are defined strictly in the robot base frame: + +* Moving forward increases x +* Moving backward decreases x +* Moving left increases y (appears as right in the image) +* Moving right decreases y (appears as left in the image) +* Moving up increases z +* Moving down decreases z + +ROBOT INITIALIZATION AND TERMINATION + +Both robot arms start in predefined initial configurations with their end-effectors open. +At task completion, both arms must be returned to their initial poses. \ No newline at end of file diff --git a/embodichain/database/agent_prompt/code_example.txt b/embodichain/database/agent_prompt/code_example.txt new file mode 100644 index 00000000..c2952fed --- /dev/null +++ b/embodichain/database/agent_prompt/code_example.txt @@ -0,0 +1,35 @@ +# Python scripts +# Use the right arm to grasp bottle, move to the target location (x=0.2, y=0.1), and then open the gripper to release the object. + +```python +# Step 1 — Reach and grasp the bottle +drive( + right_arm_action=grasp( + robot_name="right_arm", + obj_name="bottle", + ), +) + +# Step 2 — Move to target location +drive( + right_arm_action=move_to_absolute_position( + robot_name="right_arm", + x=0.2, + y=0.1, + ), +) + +# Step 3 — Open gripper to release the object +drive( + right_arm_action=open_gripper( + robot_name="right_arm", + ), +) + +# Step 4 — Return the arm to the initial pose +drive( + right_arm_action=back_to_initial_pose( + robot_name="right_arm", + ), +) +``` \ No newline at end of file diff --git a/embodichain/database/agent_prompt/code_prompt.txt b/embodichain/database/agent_prompt/code_prompt.txt new file mode 100644 index 00000000..3fadf1c9 --- /dev/null +++ b/embodichain/database/agent_prompt/code_prompt.txt @@ -0,0 +1,7 @@ +Constraints: +- Every atomic action MUST be executed via a single drive(...) call. +- Each drive(...) call must directly contain the atomic action(s); do NOT define actions separately and then pass them into drive. +- For single-arm execution: specify the active arm’s action and explicitly set the unused arm to None within the same drive(...) call. +- For dual-arm execution: both arms’ actions MUST be specified within the same drive(...) call. +- Use exactly one drive(...) call per step; no exceptions. +- Output MUST be executable Python code only: no explanations, no comments, no markdown, and no extra text. \ No newline at end of file diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py index 88257690..e4a16daa 100644 --- a/embodichain/lab/gym/envs/__init__.py +++ b/embodichain/lab/gym/envs/__init__.py @@ -26,6 +26,14 @@ PourWaterEnv, ) from embodichain.lab.gym.envs.tasks.tableware.scoop_ice import ScoopIce +from embodichain.lab.gym.envs.tasks.tableware.pour_water_v3 import ( + PourWaterEnv3, + PourWaterAgentEnv3, +) +from embodichain.lab.gym.envs.tasks.tableware.rearrangement_v3 import ( + RearrangementEnv3, + RearrangementAgentEnv3, +) # Reinforcement learning environments from embodichain.lab.gym.envs.tasks.rl.push_cube import PushCubeEnv diff --git a/embodichain/lab/gym/envs/action_bank/utils.py b/embodichain/lab/gym/envs/action_bank/utils.py index 255e5b8f..1ca395f1 100644 --- a/embodichain/lab/gym/envs/action_bank/utils.py +++ b/embodichain/lab/gym/envs/action_bank/utils.py @@ -29,6 +29,41 @@ def get_init_affordance(scope: str, tag: str = "init") -> str: return "{}_{}_qpos".format(scope, tag) +def get_control_part(env, agent_uid): + + from embodichain.lab.gym.utils.misc import _data_key_to_control_part + + control_parts = env.metadata["dataset"]["robot_meta"].get("control_parts", []) + + if agent_uid in control_parts: + return agent_uid + else: + return _data_key_to_control_part( + robot=env.robot, + control_parts=control_parts, + data_key=agent_uid, + ) + + +def get_control_part_joint_ids(env, key: str) -> List[int]: + from embodichain.data.enum import ( + ControlParts, + ControlPartsMappingW1, + ) + + control_part = get_control_part(env, key) + if control_part == ControlParts.WAIST.value: + waist_joint_id = env.robot.get_joint_ids(name=ControlParts.TORSO.value)[ + ControlPartsMappingW1.WAIST_IN_TORSO.value + ] + if not isinstance(waist_joint_id, (list)): + return [waist_joint_id] + return waist_joint_id + + else: + return env.robot.get_joint_ids(name=control_part, remove_mimic=True) + + def generate_affordance_from_src( env, src_key: str, diff --git a/embodichain/lab/gym/envs/managers/dataset_manager.py b/embodichain/lab/gym/envs/managers/dataset_manager.py index a0ca1688..d1ea3049 100644 --- a/embodichain/lab/gym/envs/managers/dataset_manager.py +++ b/embodichain/lab/gym/envs/managers/dataset_manager.py @@ -87,20 +87,28 @@ def __init__(self, cfg: object, env: EmbodiedEnv): super().__init__(cfg, env) ## TODO: fix configurable_action.py to avoid getting env.metadata['dataset'] - # Extract robot_meta from first functor and add to env.metadata for backward compatibility + # Extract robot_meta and instruction from first functor and add to env.metadata for backward compatibility # This allows legacy code (like action_bank) to access robot_meta via env.metadata["dataset"]["robot_meta"] for mode_cfgs in self._mode_functor_cfgs.values(): for functor_cfg in mode_cfgs: - if "robot_meta" in functor_cfg.params: + if ( + "robot_meta" in functor_cfg.params + or "instruction" in functor_cfg.params + ): if not hasattr(env, "metadata"): env.metadata = {} if "dataset" not in env.metadata: env.metadata["dataset"] = {} - env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ - "robot_meta" - ] + if "robot_meta" in functor_cfg.params: + env.metadata["dataset"]["robot_meta"] = functor_cfg.params[ + "robot_meta" + ] + if "instruction" in functor_cfg.params: + env.metadata["dataset"]["instruction"] = functor_cfg.params[ + "instruction" + ] logger.log_info( - "Added robot_meta to env.metadata for backward compatibility" + "Added robot_meta and instruction to env.metadata for backward compatibility" ) break else: diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py new file mode 100644 index 00000000..e9e12329 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -0,0 +1,260 @@ +import torch +from embodichain.utils import logger +import traceback +from embodichain.data import database_agent_prompt_dir +from pathlib import Path +import tempfile +import numpy as np +import random +import os +from embodichain.toolkits.interfaces import extract_drive_calls, draw_axis +from embodichain.agents.hierarchy.code_agent import format_execution_history +from embodichain.agents.hierarchy.validation_agent import ( + save_obs_image, + get_obj_position_info, +) + + +class BaseAgentEnv: + + def _init_agents(self, agent_config, task_name): + from embodichain.agents.hierarchy.task_agent import TaskAgent + from embodichain.agents.hierarchy.code_agent import CodeAgent + from embodichain.agents.hierarchy.validation_agent import ValidationAgent + from embodichain.agents.hierarchy.llm import ( + create_llm, + task_llm, + code_llm, + validation_llm, + ) + + if agent_config.get("TaskAgent") is not None: + self.task_agent = TaskAgent( + task_llm, + **agent_config["Agent"], + **agent_config["TaskAgent"], + task_name=task_name, + ) + self.code_agent = CodeAgent( + code_llm, + **agent_config["Agent"], + **agent_config.get("CodeAgent"), + task_name=task_name, + ) + self.validation_agent = ValidationAgent( + validation_llm, + task_name=task_name, + task_description=self.code_agent.prompt_kwargs.get("task_prompt")[ + "content" + ], + basic_background=self.code_agent.prompt_kwargs.get("basic_background")[ + "content" + ], + atom_actions=self.code_agent.prompt_kwargs.get("atom_actions")["content"], + ) + + def get_states(self): + # TODO: only support num_env = 1 for now + # store robot states in each env.reset + self.init_qpos = self.robot.get_qpos().squeeze(0) + + self.left_arm_joints = self.robot.get_joint_ids(name="left_arm") + self.right_arm_joints = self.robot.get_joint_ids(name="right_arm") + self.left_eef_joints = self.robot.get_joint_ids(name="left_eef") + self.right_eef_joints = self.robot.get_joint_ids(name="right_eef") + + self.left_arm_init_qpos = self.init_qpos[self.left_arm_joints] + self.right_arm_init_qpos = self.init_qpos[self.right_arm_joints] + + self.left_arm_init_xpos = self.robot.compute_fk( + self.left_arm_init_qpos, name="left_arm", to_matrix=True + ).squeeze(0) + self.right_arm_init_xpos = self.robot.compute_fk( + self.right_arm_init_qpos, name="right_arm", to_matrix=True + ).squeeze(0) + + self.left_arm_current_qpos = self.left_arm_init_qpos + self.right_arm_current_qpos = self.right_arm_init_qpos + + self.left_arm_current_xpos = self.left_arm_init_xpos + self.right_arm_current_xpos = self.right_arm_init_xpos + + self.left_arm_base_pose = self.robot.get_control_part_base_pose( + "left_arm", to_matrix=True + ).squeeze(0) + self.right_arm_base_pose = self.robot.get_control_part_base_pose( + "right_arm", to_matrix=True + ).squeeze(0) + + self.open_state = torch.tensor([0.05]) + self.close_state = torch.tensor([0.0]) + self.left_arm_current_gripper_state = self.open_state + self.right_arm_current_gripper_state = self.open_state + + # store some useful obj information + init_obj_info = {} + obj_uids = self.sim.get_rigid_object_uid_list() + for obj_name in obj_uids: + obj = self.sim.get_rigid_object(obj_name) + obj_pose = obj.get_local_pose(to_matrix=True).squeeze(0) + obj_height = obj_pose[2, 3] # Extract the height (z-coordinate) + obj_grasp_pose = self.affordance_datas.get( + f"{obj_name}_grasp_pose_object", None + ) + init_obj_info[obj_name] = { + "pose": obj_pose, # Store the full pose (4x4 matrix) + "height": obj_height, # Store the height (z-coordinate) + "grasp_pose_obj": ( + obj_grasp_pose.squeeze(0) if obj_grasp_pose is not None else None + ), # Store the grasp pose if available + } + self.init_obj_info = init_obj_info + + # -------------------- Common getters / setters -------------------- + + def get_obs_for_agent(self): + obs = self.get_obs(get_valid_sensor_data=True) + rgb = obs["sensor"]["cam_high"]["color"].squeeze(0) + valid_rgb_1 = obs["sensor"]["valid_cam_1"]["color"].squeeze(0) + valid_rgb_2 = obs["sensor"]["valid_cam_2"]["color"].squeeze(0) + valid_rgb_3 = obs["sensor"]["valid_cam_3"]["color"].squeeze(0) + + # obs_image_path = save_obs_image(obs_image=self.get_obs_for_agent()["rgb_1"], save_dir='./', step_id=0) + + return { + "rgb": rgb, + "valid_rgb_1": valid_rgb_1, + "valid_rgb_2": valid_rgb_2, + "valid_rgb_3": valid_rgb_3, + } + + # depth = obs["sensor"]["cam_high"]["depth"].squeeze(0) + # mask = obs["sensor"]["cam_high"]["mask"].squeeze(0) + # semantic_mask = obs["sensor"]["cam_high"]["semantic_mask_l"].squeeze(0) + # return {"rgb": rgb, "depth": depth, "mask": mask, "semantic_mask": semantic_mask} + + def get_current_qpos_agent(self): + return self.left_arm_current_qpos, self.right_arm_current_qpos + + def set_current_qpos_agent(self, arm_qpos, is_left): + if is_left: + self.left_arm_current_qpos = arm_qpos + else: + self.right_arm_current_qpos = arm_qpos + + def get_current_xpos_agent(self): + return self.left_arm_current_xpos, self.right_arm_current_xpos + + def set_current_xpos_agent(self, arm_xpos, is_left): + if is_left: + self.left_arm_current_xpos = arm_xpos + else: + self.right_arm_current_xpos = arm_xpos + + def get_current_gripper_state_agent(self): + return self.left_arm_current_gripper_state, self.right_arm_current_gripper_state + + def set_current_gripper_state_agent(self, arm_gripper_state, is_left): + if is_left: + self.left_arm_current_gripper_state = arm_gripper_state + else: + self.right_arm_current_gripper_state = arm_gripper_state + + # -------------------- IK / FK -------------------- + def get_arm_ik(self, target_xpos, is_left, qpos_seed=None): + control_part = "left_arm" if is_left else "right_arm" + ret, qpos = self.robot.compute_ik( + name=control_part, pose=target_xpos, joint_seed=qpos_seed + ) + return ret.all().item(), qpos.squeeze(0) + + def get_arm_fk(self, qpos, is_left): + control_part = "left_arm" if is_left else "right_arm" + xpos = self.robot.compute_fk( + name=control_part, qpos=torch.as_tensor(qpos), to_matrix=True + ) + return xpos.squeeze(0) + + # -------------------- get only code for action list -------------------- + def generate_code_for_actions(self, regenerate=False, **kwargs): + logger.log_info( + f"Generate code for creating action list for {self.code_agent.task_name}.", + color="green", + ) + + # Task planning + print(f"\033[92m\nStart task planning.\n\033[0m") + + task_agent_input = self.task_agent.get_composed_observations( + env=self, regenerate=regenerate, **kwargs + ) + task_plan = self.task_agent.generate(**task_agent_input) + + # Code generation + print(f"\033[94m\nStart code generation.\n\033[0m") + code_agent_input = self.code_agent.get_composed_observations( + env=self, regenerate=regenerate, **kwargs + ) + code_agent_input["task_plan"] = task_plan + + code_file_path, kwargs, code = self.code_agent.generate(**code_agent_input) + return code_file_path, kwargs, code + + # -------------------- get action list -------------------- + def create_demo_action_list(self, regenerate=False): + code_file_path, kwargs, _ = self.generate_code_for_actions( + regenerate=regenerate + ) + action_list = self.code_agent.act(code_file_path, **kwargs) + return action_list + + def to_dataset( + self, + id: str = None, + obs_list: list = None, + action_list: list = None, + ): + from embodichain.data.data_engine.data_dict_extractor import ( + fetch_imitation_dataset, + ) + + from embodichain.lab.gym.robots.interface import LearnableRobot + + # Get episode data from env if not provided + if obs_list is None: + obs_list = getattr(self, "_episode_obs_list", []) + if action_list is None: + action_list = getattr(self, "_episode_action_list", []) + + if len(obs_list) == 0 or len(action_list) == 0: + logger.log_warning("No episode data found. Returning empty dataset.") + return { + "data_path": None, + "id": id, + "current_episode": getattr(self, "curr_episode", 0), + "data": None, + "save_path": None, + } + + dataset_path = self.metadata["dataset"].get("save_path", None) + if dataset_path is None: + from embodichain.data import database_demo_dir + + dataset_path = database_demo_dir + + # TODO: create imitation dataset folder with name "{task_name}_{robot_type}_{num_episodes}" + from embodichain.lab.gym.utils.misc import camel_to_snake + + if not hasattr(self, "folder_name") or self.curr_episode == 0: + robot_class_name = ( + self.robot.__class__.__name__ + if hasattr(self, "robot") and self.robot is not None + else "Robot" + ) + self.folder_name = f"{camel_to_snake(self.__class__.__name__)}_{camel_to_snake(robot_class_name)}" + if os.path.exists(os.path.join(dataset_path, self.folder_name)): + self.folder_name = f"{self.folder_name}_{random.randint(0, 1000)}" + + return fetch_imitation_dataset( + self, obs_list, action_list, id, self.folder_name + ) diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py new file mode 100644 index 00000000..c4b469b6 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water_v3.py @@ -0,0 +1,78 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, Union, Optional, Sequence, Tuple, List + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import configclass, logger + +from embodichain.lab.gym.envs.tasks.tableware.base_agent_env import BaseAgentEnv + +__all__ = ["PourWaterEnv3"] + + +@register_env("PourWater-v3", max_episode_steps=600) +class PourWaterEnv3(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + action_config = kwargs.get("action_config", None) + if action_config is not None: + self.action_config = action_config + + def is_task_success(self, **kwargs) -> torch.Tensor: + """Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. + + Args: + **kwargs: Additional arguments for task-specific success criteria. + + Returns: + torch.Tensor: A boolean tensor indicating success for each environment in the batch. + """ + + bottle = self.sim.get_rigid_object("bottle") + cup = self.sim.get_rigid_object("cup") + + bottle_final_xpos = bottle.get_local_pose(to_matrix=True) + cup_final_xpos = cup.get_local_pose(to_matrix=True) + + bottle_ret = self._is_fall(bottle_final_xpos) + cup_ret = self._is_fall(cup_final_xpos) + + return ~(bottle_ret | cup_ret) + + def _is_fall(self, pose: torch.Tensor) -> torch.Tensor: + # Extract z-axis from rotation matrix (last column, first 3 elements) + pose_rz = pose[:, :3, 2] + world_z_axis = torch.tensor([0, 0, 1], dtype=pose.dtype, device=pose.device) + + # Compute dot product for each batch element + dot_product = torch.sum(pose_rz * world_z_axis, dim=-1) # Shape: (batch_size,) + + # Clamp to avoid numerical issues with arccos + dot_product = torch.clamp(dot_product, -1.0, 1.0) + + # Compute angle and check if fallen + angle = torch.arccos(dot_product) + return angle >= torch.pi / 4 + + +@register_env("PourWaterAgent-v3", max_episode_steps=600) +class PourWaterAgentEnv3(BaseAgentEnv, PourWaterEnv3): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + super()._init_agents(**kwargs) + + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): + obs, info = super().reset(seed=seed, options=options) + super().get_states() + return obs, info diff --git a/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py new file mode 100644 index 00000000..4f14656a --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/rearrangement_v3.py @@ -0,0 +1,97 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, Union, Optional, Sequence, Tuple, List + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import configclass, logger + +from embodichain.lab.gym.envs.tasks.tableware.base_agent_env import BaseAgentEnv + +__all__ = ["RearrangementEnv3"] + + +@register_env("Rearrangement-v3", max_episode_steps=600) +class RearrangementEnv3(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + action_config = kwargs.get("action_config", None) + if action_config is not None: + self.action_config = action_config + + def is_task_success(self) -> bool: + fork = self.sim.get_rigid_object("fork") + spoon = self.sim.get_rigid_object("spoon") + plate = self.sim.get_rigid_object("plate") + plate_pose = plate.get_local_pose(to_matrix=True) + # TODO: now only for 1 env + ( + spoon_place_target_x, + spoon_place_target_y, + spoon_place_target_z, + ) = self.affordance_datas["spoon_place_pose"][:3, 3] + ( + fork_place_target_x, + fork_place_target_y, + fork_place_target_z, + ) = self.affordance_datas["fork_place_pose"][:3, 3] + + spoon_pose = spoon.get_local_pose(to_matrix=True) + spoon_x, spoon_y, spoon_z = spoon_pose[0, :3, 3] + + fork_pose = fork.get_local_pose(to_matrix=True) + fork_x, fork_y, fork_z = fork_pose[0, :3, 3] + + tolerance = self.metadata.get("success_params", {}).get("tolerance", 0.02) + + # spoon and fork should with the x y range of tolerance related to plate. + return ~( + abs(spoon_x - spoon_place_target_x) > tolerance + or abs(spoon_y - spoon_place_target_y) > tolerance + or abs(fork_x - fork_place_target_x) > tolerance + or abs(fork_y - fork_place_target_y) > tolerance + ) + + +@register_env("RearrangementAgent-v3", max_episode_steps=600) +class RearrangementAgentEnv3(BaseAgentEnv, RearrangementEnv3): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + super()._init_agents(**kwargs) + + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): + obs, info = super().reset(seed=seed, options=options) + super().get_states() + return obs, info + + def is_task_success(self): + fork = self.sim.get_rigid_object("fork") + spoon = self.sim.get_rigid_object("spoon") + plate = self.sim.get_rigid_object("plate") + + plate_pose = plate.get_local_pose(to_matrix=True) + spoon_place_target_y = plate_pose[0, 1, 3] - 0.16 + fork_place_target_y = plate_pose[0, 1, 3] + 0.16 + + spoon_pose = spoon.get_local_pose(to_matrix=True) + spoon_y = spoon_pose[0, 1, 3] + + fork_pose = fork.get_local_pose(to_matrix=True) + fork_y = fork_pose[0, 1, 3] + + tolerance = self.metadata.get("success_params", {}).get("tolerance", 0.02) + + # spoon and fork should with the y range of tolerance related to plate. + return ( + abs(spoon_y - spoon_place_target_y) <= tolerance + or abs(fork_y - fork_place_target_y) <= tolerance + ) diff --git a/embodichain/lab/gym/motion_generation/action/action.py b/embodichain/lab/gym/motion_generation/action/action.py new file mode 100644 index 00000000..74f5b55e --- /dev/null +++ b/embodichain/lab/gym/motion_generation/action/action.py @@ -0,0 +1,77 @@ +import numpy as np + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.gym.envs import BaseEnv +from embodichain.utils import logger + + +class Action(ABC): + r"""Base class for action terms. + + The action term is responsible for processing the raw actions sent to the environment + and applying them to the asset managed by the term. The action term is comprised of two + operations: + + """ + + env = None + scene = None + + def __init__(self, env, **kwargs) -> None: + self.env: BaseEnv = env + self.scene = self.env.scene + + # def reset(self, env_ids: Sequence[int] | None = None) -> None: + # r"""Resets the manager term. + + # Args: + # env_ids: The environment ids. Defaults to None, in which case + # all environments are considered. + # """ + # pass + + # @abstractmethod + # def process_actions(self, actions: torch.Tensor): + # r"""Processes the actions sent to the environment. + + # Note: + # This function is called once per environment step by the manager. + + # Args: + # actions: The actions to process. + # """ + # raise NotImplementedError + + # @abstractmethod + # def apply_actions(self): + # r"""Applies the actions to the asset managed by the term. + + # Note: + # This is called at every simulation step by the manager. + # """ + # raise NotImplementedError + + def __call__(self, *args) -> Any: + """Returns the value of the term required by the manager. + + In case of a class implementation, this function is called by the manager + to get the value of the term. The arguments passed to this function are + the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`). + + .. attention:: + To be consistent with memory-less implementation of terms with functions, it is + recommended to ensure that the returned mutable quantities are cloned before + returning them. For instance, if the term returns a tensor, it is recommended + to ensure that the returned tensor is a clone of the original tensor. This prevents + the manager from storing references to the tensors and altering the original tensors. + + Args: + *args: Variable length argument list. + + Returns: + The value of the term. + """ + raise NotImplementedError diff --git a/embodichain/lab/gym/motion_generation/action/arm_action.py b/embodichain/lab/gym/motion_generation/action/arm_action.py new file mode 100644 index 00000000..bacc98b1 --- /dev/null +++ b/embodichain/lab/gym/motion_generation/action/arm_action.py @@ -0,0 +1,710 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, List, Tuple, Union +from embodichain.utils import logger +from embodichain.lab.gym.motion_generation.action.action import Action +from embodichain.lab.gym.motion_generation.planner.utils import ( + TrajectorySampleMethod, +) +from embodichain.lab.gym.motion_generation.planner.toppra_planner import ( + ToppraPlanner, +) + + +class ArmAction(Action): + r"""Initialize the ArmAction class.""" + + def __init__(self, env, robot_uid, **kwargs) -> None: + super().__init__(env, **kwargs) + self.agent_uid = robot_uid + if "LeftManipulator" == robot_uid or "RightManipulator" == robot_uid: + self.agent = self.scene.get_robot("DualManipulator") + else: + self.agent = self.scene.get_robot(self.agent_uid) + + init_qpos = self.agent.get_init_qpos(self.agent_uid) + self.init_ee_xpos = self.agent.get_fk(qpos=init_qpos, uid=self.agent_uid) + self.init_base_xpos = self.agent.get_base_xpos(self.agent_uid) + + self.drive_controller = self.agent.drive_controllers[self.agent_uid] + + def move( + self, + xpos_list: np.ndarray, + is_linear: bool = False, + is_wait: bool = True, + is_world_coordinates=False, + **kwargs, + ): + r"""Move the robot to a specified position. + + Args: + xpos_list (np.ndarray): List of target positions. + is_linear (bool): If True, move in a linear path. + is_wait (bool): If True, wait until the movement is completed. + is_world_coordinates (bool): If True, interpret positions in world coordinates. + kwargs (dict): Additional arguments. + + Returns: + bool: True if movement is successful, else False. + """ + if hasattr(self.agent, "move"): + res = self.agent.move( + xpos_list, + is_linear=is_linear, + is_wait=is_wait, + is_world_coordinates=is_world_coordinates, + ) + + return res + else: + return False + + def move_in_joints( + self, + qpos_list: np.ndarray, + is_linear: bool = False, + is_wait: bool = True, + **kwargs, + ): + r"""Move the robot joints to specified positions. + + Args: + qpos_list (np.ndarray): List of target joint positions. + is_linear (bool): If True, move joints in a linear path. + is_wait (bool): If True, wait until the movement is completed. + kwargs (dict): Additional arguments. + + Returns: + bool: True if movement is successful, else False. + """ + if hasattr(self.agent, "move_in_joints"): + res = self.agent.move_in_joints( + qpos_list, is_linear=is_linear, is_wait=is_wait + ) + return res + else: + return False + + def apply_transform( + self, xpos: np.ndarray, lift_vector: np.ndarray, is_local: bool = False + ): + """Apply a lift to the given pose in either local or world coordinates. + + Args: + pick_xpos (np.ndarray): The original 4x4 transformation matrix. + lift_vector (np.ndarray): A 3-element vector representing the lift in [x, y, z] directions. + is_local (bool): If True, apply the lift in local coordinates; + if False, apply in world coordinates. + + Returns: + np.ndarray: The new 4x4 transformation matrix after applying the lift. + """ + xpos = np.array(xpos) + lift_vector = np.array(lift_vector) + # Assert to ensure xpos is a 4x4 matrix and lift_vector has three components + assert xpos.shape == (4, 4), "Target pose must be a 4x4 matrix." + assert lift_vector.shape == ( + 3, + ), "Lift vector must have three components [x, y, z]." + + # Create a copy of the xpos + new_xpos = deepcopy(xpos) + + # Create a translation matrix for lifting in world coordinates + translation_matrix = np.array( + [ + [1, 0, 0, lift_vector[0]], + [0, 1, 0, lift_vector[1]], + [0, 0, 1, lift_vector[2]], + [0, 0, 0, 1], + ] + ) + + if is_local: + # Apply lift in local coordinates + new_xpos = new_xpos @ translation_matrix + else: + # Apply the translation in the world coordinate system + new_xpos = translation_matrix @ new_xpos + + return new_xpos + + @staticmethod + def create_discrete_trajectory( + agent, + uid, + xpos_list: List[np.ndarray] = None, + is_use_current_qpos: bool = True, + is_linear: bool = False, + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.QUANTITY, + sample_num: Union[float, int] = 20, + qpos_seed: np.ndarray = None, + **kwargs, + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + r"""Generate a discrete trajectory between waypoints using cartesian or joint space interpolation. + + This method supports two trajectory planning approaches: + 1. Linear interpolation: Fast, uniform spacing, no dynamics constraints + 2. ToppraPlanner: Smooth, considers velocity/acceleration limits, realistic motion + + Args: + agent: The robot agent instance + uid: Unique identifier for the robot agent + xpos_list: List of waypoints as 4x4 transformation matrices + is_use_current_qpos: Whether to use current joint angles as IK seed + is_linear: If True, use cartesian linear interpolation, else joint space + sample_method: Sampling method for ToppraPlanner (QUANTITY or TIME) + sample_num: Number of interpolated points for final trajectory + qpos_seed: Initial joint configuration for IK solving + **kwargs: Additional arguments: + - qpos_list: Optional list of joint configurations + + Returns: + A tuple containing: + - List[np.ndarray]: Joint space trajectory as a list of joint configurations + - List[np.ndarray]: Cartesian space trajectory as a list of 4x4 matrices + """ + from scipy.spatial.transform import Rotation, Slerp + import numpy as np + + def interpolate_xpos( + current_xpos: np.ndarray, target_xpos: np.ndarray, num_samples: int + ) -> list[np.ndarray]: + if num_samples < 2: + num_samples = 2 + + slerp = Slerp( + [0, 1], + Rotation.from_matrix([current_xpos[:3, :3], target_xpos[:3, :3]]), + ) + interpolated_poses = [] + for s in np.linspace(0, 1, num_samples): + interp_rot = slerp(s).as_matrix() + interp_trans = (1 - s) * current_xpos[:3, 3] + s * target_xpos[:3, 3] + interp_pose = np.eye(4) + interp_pose[:3, :3] = interp_rot + interp_pose[:3, 3] = interp_trans + interpolated_poses.append(interp_pose) + return interpolated_poses + + def calculate_point_allocations( + xpos_list, step_size=0.002, angle_step=np.pi / 90 + ): + point_allocations = [] + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + if isinstance(start_pose, torch.Tensor): + start_pose = start_pose.squeeze().cpu().numpy() + if isinstance(end_pose, torch.Tensor): + end_pose = end_pose.squeeze().cpu().numpy() + + pos_dist = np.linalg.norm(end_pose[:3, 3] - start_pose[:3, 3]) + pos_points = max(1, int(pos_dist / step_size)) + + angle_diff = Rotation.from_matrix( + start_pose[:3, :3].T @ end_pose[:3, :3] + ) + angle = abs(angle_diff.as_rotvec()).max() + rot_points = max(1, int(angle / angle_step)) + + num_points = max(pos_points, rot_points) + point_allocations.append(num_points) + + return point_allocations + + def create_qpos_dict(position: np.ndarray, dof: int) -> Dict: + """Create qpos dictionary with zero velocity and acceleration""" + return { + "position": ( + position.tolist() if isinstance(position, np.ndarray) else position + ), + "velocity": [0.0] * dof, + "acceleration": [0.0] * dof, + } + + if hasattr(agent, "get_dof"): + agent_dof = agent.get_dof(uid) + elif hasattr(agent, "control_parts"): + agent_dof = len(agent.control_parts[uid]) + + # TODO(@Jietao Chen): max_constraints should be read from URDF file + max_constraints = kwargs.get("max_constraints", None) + if max_constraints is None: + max_constraints = { + "velocity": [0.2] * agent_dof, + "acceleration": [0.5] * agent_dof, + } + planner = ToppraPlanner(agent_dof, max_constraints) + + out_qpos_list = [] + out_xpos_list = [] + + # Handle input arguments + qpos_list = kwargs.get("qpos_list", None) + if qpos_list is not None: + qpos_list = np.asarray(qpos_list) + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in qpos_list] + elif hasattr(agent, "compute_fk"): + qpos_list = ( + torch.tensor(qpos_list) + if not isinstance(qpos_list, torch.Tensor) + else qpos_list + ) + xpos_list = [ + agent.compute_fk(qpos=q, name=uid, to_matrix=True) + for q in qpos_list + ] + else: + logger.log_warning("Agent does not support FK computation") + + if is_use_current_qpos: + current_qpos = agent.get_current_qpos(uid) + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + current_xpos = agent.get_fk(uid=uid, qpos=current_qpos) + elif hasattr(agent, "compute_fk"): + current_xpos = agent.compute_fk( + qpos=current_qpos, name=uid, to_matrix=True + ) + else: + logger.log_warning("Agent does not support FK computation") + return [], [] + + pos_diff = np.linalg.norm(current_xpos[:3, 3] - xpos_list[0][:3, 3]) + rot_diff = np.linalg.norm(current_xpos[:3, :3] - xpos_list[0][:3, :3]) + + if pos_diff > 0.001 or rot_diff > 0.01: + xpos_list = np.concatenate( + [current_xpos[None, :, :], xpos_list], axis=0 + ) + if qpos_list is not None: + qpos_list = np.concatenate( + [current_qpos[None, :], qpos_list], axis=0 + ) + + if qpos_seed is None and qpos_list is not None: + qpos_seed = qpos_list[0] + + # Input validation + if xpos_list is None or len(xpos_list) < 2: + logger.log_warning("xpos_list must contain at least 2 points") + return [], [] + + # Calculate point allocations for interpolation + interpolated_point_allocations = calculate_point_allocations( + xpos_list, step_size=0.002, angle_step=np.pi / 90 + ) + + # Generate trajectory + interpolate_qpos_list = [] + if is_linear or qpos_list is None: + # Linear cartesian interpolation + for i in range(len(xpos_list) - 1): + interpolated_poses = interpolate_xpos( + xpos_list[i], xpos_list[i + 1], interpolated_point_allocations[i] + ) + + for xpos in interpolated_poses: + # TODO: It will use computed ik in the future + if hasattr(agent, "get_ik"): + success, qpos = agent.get_ik(xpos, qpos_seed=qpos_seed, uid=uid) + elif hasattr(agent, "compute_ik"): + success, qpos = agent.compute_ik( + pose=xpos, joint_seed=qpos_seed, name=uid + ) + else: + logger.log_warning("Agent does not support IK computation") + + if not success: + logger.log_warning(f"IK solving failed for pose {xpos}") + return [], [] + interpolate_qpos_list.append(qpos) + qpos_seed = qpos + else: + # Joint space interpolation + interpolate_qpos_list = qpos_list + + # Create trajectory dictionary + current_qpos_dict = create_qpos_dict(interpolate_qpos_list[0], agent_dof) + target_qpos_dict_list = [ + create_qpos_dict(pos, agent_dof) for pos in interpolate_qpos_list[1:] + ] + + # Plan trajectory + res, out_qpos_list, *_ = planner.plan( + current_qpos_dict, + target_qpos_dict_list, + sample_method=sample_method, + sample_interval=sample_num, + ) + if not res: + logger.log_warning("Failed to plan trajectory with ToppraPlanner") + return [], [] + + # TODO: It will use computed fk in the future + if hasattr(agent, "get_fk"): + out_xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in out_qpos_list] + elif hasattr(agent, "compute_fk"): + out_qpos_list = ( + torch.tensor(out_qpos_list) + if not isinstance(out_qpos_list, torch.Tensor) + else out_qpos_list + ) + out_xpos_list = [ + agent.compute_fk(qpos=q, name=uid, to_matrix=True) + for q in out_qpos_list + ] + else: + logger.log_warning("Agent does not support FK computation") + + return out_qpos_list, out_xpos_list + + @staticmethod + def estimate_trajectory_sample_count( + agent, + uid, + xpos_list: List[np.ndarray] = None, + qpos_list: List[np.ndarray] = None, + step_size: float = 0.01, + angle_step: float = np.pi / 90, + **kwargs, + ) -> int: + """Estimate the number of trajectory sampling points required. + + This function estimates the total number of sampling points needed to generate + a trajectory based on the given waypoints and sampling parameters. It can be + used to predict computational load and memory requirements before actual + trajectory generation. + + Args: + agent: Robot agent instance + uid: Unique identifier for the robot agent + xpos_list: List of 4x4 transformation matrices representing waypoints + qpos_list: List of joint positions (optional) + is_linear: Whether to use linear interpolation + step_size: Maximum allowed distance between consecutive points (in meters) + angle_step: Maximum allowed angular difference between consecutive points (in radians) + **kwargs: Additional parameters for further customization + + Returns: + int: Estimated number of trajectory sampling points + """ + + def rotation_matrix_to_angle(self, rot_matrix: np.ndarray) -> float: + cos_angle = (np.trace(rot_matrix) - 1) / 2 + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return np.arccos(cos_angle) + + # Input validation + if xpos_list is None and qpos_list is None: + return 0 + + # If joint position list is provided but end effector position list is not, + # convert through forward kinematics + if qpos_list is not None and xpos_list is None: + if len(qpos_list) < 2: + return 1 if len(qpos_list) == 1 else 1 + try: + if hasattr(agent, "get_fk_batch"): + xpos_list = agent.get_fk_batch(uid=uid, qpos_list=qpos_list) + else: + xpos_list = [agent.get_fk(uid=uid, qpos=q) for q in qpos_list] + except Exception as e: + logger.log_warning(f"Forward kinematics failed: {e}") + return 0 + + if xpos_list is None or len(xpos_list) == 0: + return 1 + + if len(xpos_list) == 1: + return 1 + + total_samples = 1 # Starting point + angle_step_inv = 1.0 / angle_step + + total_pos_dist = 0.0 + total_angle = 0.0 + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + pos_diff = end_pose[:3, 3] - start_pose[:3, 3] + total_pos_dist += np.linalg.norm(pos_diff) + + try: + rot_matrix = start_pose[:3, :3].T @ end_pose[:3, :3] + angle = rotation_matrix_to_angle(rot_matrix) + total_angle += angle + except Exception: + pass + + pos_samples = max(1, int(total_pos_dist / step_size)) + rot_samples = max(1, int(total_angle / angle_step)) + + total_samples = max(pos_samples, rot_samples) + + return max(2, total_samples) + + def create_action_dict_list( + self, + xpos_list: List[np.ndarray], + qpos_list: List[np.ndarray], + ee_state: float = 0.0, + ) -> List[Dict]: + """Constructs a list of actions based on the given end effector poses on agent base coordinates and joint positions. + + Args: + xpos_list (List[np.ndarray]): A list of end effector poses. + qpos_list (List[np.ndarray]): A list of joint positions. + ee_state (float, optional): The state of the end effector (e.g., open or closed). Defaults to 0.0. + + Returns: + List[Dict]: A list of actions, where each action contains: + - "ef_pose": The end effector pose at the step. + - "qpos": The joint positions corresponding to the step. + - "ee_state": The state of the end effector (e.g., open or closed). + """ + # Check if xpos_list or qpos_list is None + if xpos_list is None or qpos_list is None: + return [] + + # Check if xpos_list and qpos_list have the same length + if len(xpos_list) != len(qpos_list): + logger.log_warning("The xpos_list and qpos_list must have the same length.") + return [] + + action_list = [ + { + "ef_pose": xpos_list[i], + "qpos": qpos_list[i], + "ee_state": ee_state, + } + for i in range(len(xpos_list)) + ] + + return action_list + + def create_back_action_list( + self, + start_xpos: np.ndarray = None, + is_move_linear: bool = False, + qpos_seed: np.ndarray = None, + lift_height: float = 0.25, + reference_xpos: np.ndarray = None, + back_distance_z: float = 0.02, + traj_num: int = 20, + **kwargs, + ) -> List[Dict]: + r"""Generate a list of actions for the robot to move back to its initial joint position after completing a task. + + Args: + start_xpos (np.ndarray, optional): The starting position of the end effector (EE) in agent base coordinates, + represented as a 4x4 transformation matrix. If None, the agent's current EE + position is used. Defaults to None. + is_move_linear (bool, optional): True for linear movement and False for joint space interpolation. Defaults to False. + qpos_seed (np.ndarray, optional): Qpos seed for solving Inverse Kinematics (IK). Defaults to None, which uses the current Qpos. + lift_height (float, optional): The vertical distance the EE should be lifted. Defaults to 0.25 meters. + reference_xpos (np.ndarray, optional): An optional reference position used to compute the back path. If None, the path will be a simple lift and return. Defaults to None. + back_distance_z (float, optional): Distance to offset reference_xpos in the -z direction. Defaults to 0.02. + traj_num (int, optional): The number of discrete steps (trajectory points) to generate for the move back action. + More steps result in a smoother trajectory. Defaults to 20. + **kwargs: Additional parameters for further customization. + + Returns: + List[Dict]: A list of actions, where each action is represented as a dictionary containing: + - "ef_pose": The end effector pose at the step. + - "qpos": The joint positions corresponding to the step. + - "ee_state": The state of the end effector (e.g., open or closed). + + Note: + - The initial joint position ('init_qpos') is the home configuration of the robot's joints, representing the agent's Qpos + at the start or in a safe/rest position. It serves as a fallback in case no valid IK solutions are found for the lifting position, + ensuring that the robot can still return to its last known configuration before the task. + """ + if start_xpos is None: + start_xpos = self.agent.get_current_xpos( + name=self.agent_uid, is_world_coordinates=False + ) + + if reference_xpos is None: + lift_xpos = self.apply_transform( + xpos=start_xpos, lift_vector=[0.0, 0.0, lift_height], is_local=False + ) + back_path = [start_xpos, lift_xpos] + else: + z_back_xpos = self._compute_offset_xpos( + start_xpos, reference_xpos, back_distance_z + ) + lift_xpos = self.apply_transform( + xpos=z_back_xpos, lift_vector=[0.0, 0.0, lift_height], is_local=False + ) + back_path = [start_xpos, z_back_xpos, lift_xpos] + + back_qpos_path = [] + + for p in back_path: + res, qpos = self.drive_controller.get_ik(p, qpos_seed) + if res: + back_qpos_path.append(qpos) + + init_qpos = self.agent.get_current_qpos(self.agent_uid) + if back_qpos_path: + back_qpos_path.append(init_qpos) + else: + back_qpos_path = [init_qpos, init_qpos] + + qpos_list, xpos_list = self.drive_controller.create_discrete_trajectory( + qpos_list=back_qpos_path, + is_use_current_qpos=False, + is_linear=is_move_linear, + sample_num=traj_num, + qpos_seed=qpos_seed, + ) + + action_list = self.create_action_dict_list( + xpos_list=xpos_list, + qpos_list=qpos_list, + ee_state=self.end_effector.open_state, + ) + if not action_list: + logger.log_warning( + "Create approach action list failed. Please check the approach path!" + ) + + return action_list + + def supplyment_action_data( + self, action_list: Dict, connected_qpos: np.ndarray = None + ) -> Dict: + r"""Supplement the action data for a DualManipulator agent. + + This function checks if the agent is a DualManipulator and determines the + appropriate end effector index based on the end effector's unique identifier. + It retrieves the current open states of the end effectors and updates the + provided action list with new joint positions and end effector states. + + Args: + action_list (Dict): A list of actions to be modified, where each action + contains 'qpos' for joint positions and 'ee_state' for end effector states. + connected_qpos (connected_qpos): + + Returns: + Dict: The modified action list with updated joint positions and end + ` effector states. If the agent is not a DualManipulator or if the end + effector UID is invalid, the original action list is returned. + """ + if self.agent.__class__.__name__ != "DualManipulator": + return action_list + + # TODO: Does the number here really correspond to the ee obtained by step_action? + if "Left" in self.end_effector_uid: + ee_idx = 0 + elif "Right" in self.end_effector_uid: + ee_idx = 1 + else: + logger.log_warning("There is no left or right gripper, no processing.") + return action_list + + all_ee_state = np.array([]) + # TODO: Here we assume that the results are obtained in the order of left and then right. + ee_list = self.env.get_end_effector() + for ee in ee_list: + ee_open_state = ee.get_open_state() + all_ee_state = np.append(all_ee_state, ee_open_state) + + if connected_qpos is None: + current_qpos = self.agent.get_current_qpos("DualManipulator") + else: + current_qpos = connected_qpos + + target_joint_ids = self.agent.get_joint_ids(self.agent_uid) + + left_current_xpos = self.agent.get_current_xpos("LeftManipulator") + right_current_xpos = self.agent.get_current_xpos("RightManipulator") + + all_xpos = np.array([left_current_xpos, right_current_xpos]) + + for action in action_list: + new_qpos = np.copy(current_qpos) + new_qpos[target_joint_ids] = action["qpos"] + action["qpos"] = new_qpos + + all_xpos[ee_idx] = action["ef_pose"] + action["ef_pose"] = all_xpos + + new_ee_state = np.copy(all_ee_state) + new_ee_state[ee_idx] = action["ee_state"] + action["ee_state"] = new_ee_state + + return action_list + + def merge_action_data( + self, left_action_list: Dict, right_action_list: Dict + ) -> Dict: + r"""Merge action data from left and right action lists. + + This function is designed to combine action data from two separate action + lists (left and right) into a single unified action list. The implementation + details for merging the action lists will depend on the specific requirements + of the application. + + Args: + left_action_list (Dict): A dictionary containing actions for the left end effector. + right_action_list (Dict): A dictionary containing actions for the right end effector. + + Returns: + Dict: A merged dictionary containing combined action data from both + left and right action lists. The exact structure of the returned + dictionary will depend on the merging logic implemented in this method. + """ + merged_action_list = [] + if self.agent.__class__.__name__ != "DualManipulator": + return merged_action_list + + current_qpos = self.agent.get_current_qpos("DualManipulator") + # Assuming both action lists have the same length + for left_action, right_action in zip(left_action_list, right_action_list): + merged_action = {} + + # Get joint IDs for left and right actions + left_joint_ids = self.agent.get_joint_ids("LeftManipulator") + right_joint_ids = self.agent.get_joint_ids("RightManipulator") + + # Initialize new qpos and ee_state for the merged action + new_qpos = np.zeros( + len(current_qpos) + ) # Assuming total count includes both left and right + + # Set joint positions based on left action + new_qpos[left_joint_ids] = left_action["qpos"][left_joint_ids] + + # Set joint positions based on right action + new_qpos[right_joint_ids] = right_action["qpos"][right_joint_ids] + + # Set end effector states + new_ee_state = [] # Assuming two end effectors: left and right + new_ee_state.extend(left_action["ee_state"]) + new_ee_state.extend(right_action["ee_state"]) + + # Construct the merged action + merged_action["qpos"] = new_qpos + merged_action["ee_state"] = new_ee_state + + # Append the merged action to the list + merged_action_list.append(merged_action) + + return merged_action_list diff --git a/embodichain/lab/gym/motion_generation/planner/toppra_planner.py b/embodichain/lab/gym/motion_generation/planner/toppra_planner.py new file mode 100644 index 00000000..0fc2b79c --- /dev/null +++ b/embodichain/lab/gym/motion_generation/planner/toppra_planner.py @@ -0,0 +1,258 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +from embodichain.lab.gym.motion_generation.planner.utils import ( + TrajectorySampleMethod, +) + +from typing import TYPE_CHECKING, Union + +try: + import toppra as ta + import toppra.constraint as constraint +except ImportError: + raise ImportError("toppra not installed. Install with `pip install toppra==0.6.3`") + +ta.setup_logging(level="WARN") + + +class ToppraPlanner: + def __init__(self, DOFs, max_constraints): + r"""Initialize the TOPPRA trajectory planner. + + Args: + DOFs: Number of degrees of freedom + max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints + """ + + self.DOFs = DOFs + self.time_step = 0.01 + self.max_constraints = max_constraints + + # Create TOPPRA constraints + self.vlims = np.array([[-v, v] for v in max_constraints["velocity"]]) + self.alims = np.array([[-a, a] for a in max_constraints["acceleration"]]) + + def plan( + self, + current_state: dict, + target_states: list[dict], + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.TIME, + sample_interval: Union[float, int] = 0.01, + ): + r"""Execute trajectory planning. + + Args: + current_state: Dictionary containing 'position', 'velocity', 'acceleration' for current state + target_states: List of dictionaries containing target states + + Returns: + Tuple of (success, positions, velocities, accelerations, times, duration) + """ + if not isinstance(sample_interval, (float, int)): + raise TypeError( + f"sample_interval must be float/int, got {type(sample_interval)}" + ) + if sample_method == TrajectorySampleMethod.TIME and sample_interval <= 0: + raise ValueError("Time interval must be positive") + elif sample_method == TrajectorySampleMethod.QUANTITY and sample_interval < 2: + raise ValueError("At least 2 sample points required") + + # Check waypoints + if len(current_state["position"]) != self.DOFs: + print(f"Current wayponit does not align") + return False, None, None, None, None, None + for target in target_states: + if len(target["position"]) != self.DOFs: + print(f"Target Wayponits does not align") + return False, None, None, None, None, None + + if ( + len(target_states) == 1 + and np.sum( + np.abs( + np.array(target_states[0]["position"]) + - np.array(current_state["position"]) + ) + ) + < 1e-3 + ): + print(f"Only two same waypoints, do not plan") + return ( + True, + np.array([current_state["position"], target_states[0]["position"]]), + np.array([[0.0] * self.DOFs, [0.0] * self.DOFs]), + np.array([[0.0] * self.DOFs, [0.0] * self.DOFs]), + 0, + 0, + ) + + # Build waypoints + waypoints = [np.array(current_state["position"])] + for target in target_states: + waypoints.append(np.array(target["position"])) + waypoints = np.array(waypoints) + + # Create spline interpolation + # NOTE(fsh):适合密集的点 + ss = np.linspace(0, 1, len(waypoints)) + + # NOTE(fsh):适合稀疏的点,密集点容易不满足CubicSpline严格递增的条件 + # len_total = 0 + # len_from_start = [0] + # for i in range(len(waypoints)-1): + # len_total += np.sum(np.abs(waypoints[i+1] - waypoints[i])) + # len_from_start.append(len_total) + # ss = np.array([cur/len_total for cur in len_from_start]) + + path = ta.SplineInterpolator(ss, waypoints) + + # Set constraints + pc_vel = constraint.JointVelocityConstraint(self.vlims) + pc_acc = constraint.JointAccelerationConstraint(self.alims) + + # Create TOPPRA instance + instance = ta.algorithm.TOPPRA( + [pc_vel, pc_acc], + path, + parametrizer="ParametrizeConstAccel", + gridpt_min_nb_points=max(100, 10 * len(waypoints)), + ) + # NOTES:合理设置gridpt_min_nb_points对加速度约束很重要 + + # Compute parameterized trajectory + jnt_traj = instance.compute_trajectory() + if jnt_traj is None: + # raise RuntimeError("Unable to find feasible trajectory") + print(f"Unable to find feasible trajectory") + return False, None, None, None, None, None + + duration = jnt_traj.duration + # Sample trajectory points + if duration <= 0: + raise ValueError(f"Duration must be positive, got {duration}") + if sample_method == TrajectorySampleMethod.TIME: + n_points = max(2, int(np.ceil(duration / sample_interval)) + 1) + ts = np.linspace(0, duration, n_points) + else: + ts = np.linspace(0, duration, num=int(sample_interval)) + + positions = [] + velocities = [] + accelerations = [] + + for t in ts: + positions.append(jnt_traj.eval(t)) + velocities.append(jnt_traj.evald(t)) + accelerations.append(jnt_traj.evaldd(t)) + + return ( + True, + np.array(positions), + np.array(velocities), + np.array(accelerations), + ts, + duration, + ) + + def is_satisfied_constraint(self, velocities, accelerations) -> bool: + r"""Check if the trajectory satisfies velocity and acceleration constraints. + + Args: + velocities: array + accelerations: array + """ + # NOTE(fsh):密集点过多的情况下,当前实现容易求解无法严格满足约束,会有一定的越界 + vlims = self.vlims * (1 + 0.1) # 允许10%误差 + alims = self.alims * (1 + 0.25) # 允许25%误差 + + vel_check = np.all((velocities >= vlims[:, 0]) & (velocities <= vlims[:, 1])) + acc_check = np.all( + (accelerations >= alims[:, 0]) & (accelerations <= alims[:, 1]) + ) + + # 超限情况 + if not vel_check: + vel_exceed_info = [] + min_vel = np.min(velocities, axis=0) + max_vel = np.max(velocities, axis=0) + for i in range(self.DOFs): + exceed_percentage = 0 + if min_vel[i] < self.vlims[i, 0]: + exceed_percentage = (min_vel[i] - self.vlims[i, 0]) / self.vlims[ + i, 0 + ] + if max_vel[i] > self.vlims[i, 1]: + temp = (max_vel[i] - self.vlims[i, 1]) / self.vlims[i, 1] + if temp > exceed_percentage: + exceed_percentage = temp + vel_exceed_info.append(exceed_percentage * 100) + print(f"Velocity exceed info: {vel_exceed_info} percentage") + + if not acc_check: + acc_exceed_info = [] + min_acc = np.min(accelerations, axis=0) + max_acc = np.max(accelerations, axis=0) + for i in range(self.DOFs): + exceed_percentage = 0 + if min_acc[i] < self.alims[i, 0]: + exceed_percentage = (min_acc[i] - self.alims[i, 0]) / self.alims[ + i, 0 + ] + if max_acc[i] > self.alims[i, 1]: + temp = (max_acc[i] - self.alims[i, 1]) / self.alims[i, 1] + if temp > exceed_percentage: + exceed_percentage = temp + acc_exceed_info.append(exceed_percentage * 100) + print(f"Acceleration exceed info: {acc_exceed_info} percentage") + + return vel_check and acc_check + + def plot_trajectory(self, positions, velocities, accelerations): + r"""Plot trajectory data. + + Args: + positions: Position array + velocities: Velocity array + accelerations: Acceleration array + """ + import matplotlib.pyplot as plt + + time_steps = np.arange(positions.shape[0]) * self.time_step + fig, axs = plt.subplots(3, 1, figsize=(10, 8)) + + for i in range(self.DOFs): + axs[0].plot(time_steps, positions[:, i], label=f"Joint {i+1}") + axs[1].plot(time_steps, velocities[:, i], label=f"Joint {i+1}") + axs[2].plot(time_steps, accelerations[:, i], label=f"Joint {i+1}") + + axs[1].plot( + time_steps, + [self.vlims[0][0]] * len(time_steps), + "k--", + label="Max Velocity", + ) + axs[1].plot(time_steps, [self.vlims[0][1]] * len(time_steps), "k--") + axs[2].plot( + time_steps, + [self.alims[0][0]] * len(time_steps), + "k--", + label="Max Accleration", + ) + axs[2].plot(time_steps, [self.alims[0][1]] * len(time_steps), "k--") + + axs[0].set_title("Position") + axs[1].set_title("Velocity") + axs[2].set_title("Acceleration") + + for ax in axs: + ax.set_xlabel("Time [s]") + ax.legend() + ax.grid() + + plt.tight_layout() + plt.show() diff --git a/embodichain/lab/gym/motion_generation/planner/utils.py b/embodichain/lab/gym/motion_generation/planner/utils.py new file mode 100644 index 00000000..b68a9ae3 --- /dev/null +++ b/embodichain/lab/gym/motion_generation/planner/utils.py @@ -0,0 +1,43 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from enum import Enum +from typing import Union + + +class TrajectorySampleMethod(Enum): + r"""Enumeration for different trajectory sampling methods. + + This enum defines various methods for sampling trajectories, + providing meaningful names for different sampling strategies. + """ + + TIME = "time" + """Sample based on time intervals.""" + + QUANTITY = "quantity" + """Sample based on a specified number of points.""" + + DISTANCE = "distance" + """Sample based on distance intervals.""" + + @classmethod + def from_str( + cls, value: Union[str, "TrajectorySampleMethod"] + ) -> "TrajectorySampleMethod": + if isinstance(value, cls): + return value + try: + return cls[value.upper()] + except KeyError: + valid_values = [e.name for e in cls] + raise ValueError( + f"Invalid version '{value}'. Valid values are: {valid_values}" + ) + + def __str__(self): + """Override string representation for better readability.""" + return self.value.capitalize() diff --git a/embodichain/lab/gym/robots/interface.py b/embodichain/lab/gym/robots/interface.py new file mode 100644 index 00000000..c41b1b4c --- /dev/null +++ b/embodichain/lab/gym/robots/interface.py @@ -0,0 +1,243 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import pytorch_kinematics as pk + +from abc import abstractmethod +from typing import List, Dict, Tuple, Union + +from gymnasium import spaces + +from embodichain.data.enum import ControlParts, EndEffector, JointType +from embodichain.lab.sim.objects import Robot +from embodichain.utils import logger +from embodichain.data.enum import JointType, EefType, ActionMode + + +class LearnableRobot(Robot): + """The interface class for the learnable robot agent. + + There are three types of actions should be explained: + - Real robot actions: The actions that the real robot can execute. + - Control actions: The actions that the robot interacts with the policy. + - Environment actions: The actions that the robot executes in the simulation environment. + """ + + def get_single_action_space(self) -> spaces.Dict: + limits = self.get_joint_limits(self.uid) + low, high = limits[:, 0], limits[:, 1] + single_action_space = spaces.Dict( + { + JointType.QPOS.value: spaces.Box(low=low, high=high, dtype=np.float32), + } + ) + + return single_action_space + + def step_env_action(self, action: Dict): + qpos = action[JointType.QPOS.value] + + self.set_current_qpos(self.uid, qpos) + return action + + def get_debug_xpos_dict( + self, + ) -> Dict[str, np.ndarray]: + """Get the debug xpos list.""" + return {} + + def get_data_index(self, name: str, warning: bool = True) -> List[int]: + """ + Get the data index for the control part. Subclasses must implement the index_map attribute. + + Args: + name (str): The name of the control part. + warning (bool, optional): Whether to log a warning if the control part is not supported. Defaults to True. + + Returns: + List[int]: The list of indices for the control part. Returns an empty list if not found. + + Raises: + NotImplementedError: If the subclass does not define the index_map attribute. + """ + if not hasattr(self, "index_map"): + raise NotImplementedError("Subclasses must define the index_map attribute.") + if name in self.index_map: + return self.index_map[name] + else: + if warning: + logger.log_warning(f"Control part {name} is not supported.") + return [] + + def map_ee_state_to_env_actions( + self, ee_state: np.ndarray, env_actions: np.ndarray + ) -> np.ndarray: + """Map the end-effector state to the environment actions of robot agent. + + Args: + ee_state (np.ndarray): The end-effector state. + env_actions (np.ndarray): The environment actions of the robot agent. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + return env_actions + + def map_real_actions_to_control_actions(self, actions: np.ndarray) -> np.ndarray: + """Map the real robot actions to the control actions of robot agent, which + should has the same dimension. + + The control actions should be the actions that match the articulation joint limits. + + Note: + Real robot may have gap in the action compared to the simulation robot agent. The + method provides a place the process the gap. + + Args: + actions (np.ndarray): The real robot actions collected from the robot. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + return actions + + def map_control_actions_to_env_actions( + self, + actions: np.ndarray, + env_action_dim: int, + action_type: str = JointType.QPOS.value, + ) -> np.ndarray: + """Map the control actions to the environment actions of robot agent. + + Args: + actions (np.ndarray): The control actions. + env_action_dim (int): The dimension of the environment action space. + action_type (str, optional): The type of action. Defaults to JointType.QPOS.value. + + Returns: + np.ndarray: The environment actions of the robot agent. + """ + control_index = self.get_data_index(self.uid) + if action_type != EefType.POSE.value and actions.shape[1] != len(control_index): + logger.log_error( + f"The policy action dimension {actions.shape[1]} does not match the control index dimension {len(control_index)}." + ) + + length = len(actions) + env_actions = np.zeros((length, env_action_dim)) + if action_type == JointType.QPOS.value: + env_actions[:, control_index] = actions + elif action_type == EefType.POSE.value: + # TODO: the eef state is also mapped in this function, which should be separated. + env_actions = self.map_eef_pose_to_env_qpos(actions, env_actions) + else: + logger.log_error(f"Invalid action type: {action_type}") + + return env_actions + + def map_env_qpos_to_eef_pose( + self, env_qpos: np.ndarray, to_dict: bool = False, ret_mat: bool = False + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Map environment joint positions to end-effector pose representation. + + Args: + env_qpos (np.ndarray): Joint positions from the environment, shape [batch, total_dof]. + + Returns: + np.ndarray: End-effector pose, shape [batch, 18]. + [left_pos(3), left_x(3), left_y(3), right_pos(3), right_x(3), right_y(3)] + """ + num_pose = env_qpos.shape[0] + left_indices = self.get_joint_ids( + ControlParts.LEFT_ARM.value + JointType.QPOS.value + ) + right_indices = self.get_joint_ids( + ControlParts.RIGHT_ARM.value + JointType.QPOS.value + ) + + left_env_qpos = torch.as_tensor(env_qpos[:, left_indices], dtype=torch.float32) + right_env_qpos = torch.as_tensor( + env_qpos[:, right_indices], dtype=torch.float32 + ) + + left_ret = ( + self.pk_serial_chain[ControlParts.LEFT_ARM.value + JointType.QPOS.value] + .forward_kinematics(left_env_qpos, end_only=True) + .get_matrix() + ) + right_ret = ( + self.pk_serial_chain[ControlParts.RIGHT_ARM.value + JointType.QPOS.value] + .forward_kinematics(right_env_qpos, end_only=True) + .get_matrix() + ) + + eef_pose = np.zeros((num_pose, 18)) + eef_pose[..., :3] = left_ret[..., :3, 3] + eef_pose[..., 3:6] = left_ret[..., :3, 0] + eef_pose[..., 6:9] = left_ret[..., :3, 1] + eef_pose[..., 9:12] = right_ret[..., :3, 3] + eef_pose[..., 12:15] = right_ret[..., :3, 0] + eef_pose[..., 15:18] = right_ret[..., :3, 1] + + if to_dict: + from embodichain.data.enum import EefType + + if not ret_mat: + return { + ControlParts.LEFT_ARM.value + EefType.POSE.value: eef_pose[..., :9], + ControlParts.RIGHT_ARM.value + + EefType.POSE.value: eef_pose[..., 9:], + } + else: + return { + ControlParts.LEFT_ARM.value + EefType.POSE.value: left_ret, + ControlParts.RIGHT_ARM.value + EefType.POSE.value: right_ret, + } + else: + return eef_pose + + def map_eef_pose_to_env_qpos( + self, eef_pose: np.ndarray, env_qpos: np.ndarray + ) -> np.ndarray: + """Map the end-effector pose to the environment actions. + + Args: + eef_pose (np.ndarray): The end-effector pose. + env_qpos (np.ndarray): The env qpos to be mapped. + + Returns: + np.ndarray: The environment actions. + """ + return env_qpos + + def clip_env_qpos(self, env_qpos: np.ndarray) -> np.ndarray: + """Clip the environment qpos based on the robot joint limits. + + Args: + env_qpos (np.ndarray): The environment qpos to be clipped. + + Returns: + np.ndarray: The clipped environment qpos. + """ + limits = self.get_joint_limits(self.uid) + low, high = limits[:, 0], limits[:, 1] + env_qpos = np.clip(env_qpos, low, high) + return env_qpos + + @staticmethod + def build_pk_serial_chain(**kwargs) -> Dict[str, pk.SerialChain]: + """Build the serial chain from the URDF file. + + Args: + **kwargs: Additional arguments for building the serial chain. + + Returns: + Dict[str, pk.SerialChain]: The serial chain of the robot. + """ + return {} diff --git a/embodichain/lab/gym/structs/__init__.py b/embodichain/lab/gym/structs/__init__.py new file mode 100644 index 00000000..a240d034 --- /dev/null +++ b/embodichain/lab/gym/structs/__init__.py @@ -0,0 +1 @@ +from .object import Object diff --git a/embodichain/lab/gym/structs/object.py b/embodichain/lab/gym/structs/object.py new file mode 100644 index 00000000..08a8785a --- /dev/null +++ b/embodichain/lab/gym/structs/object.py @@ -0,0 +1,311 @@ +import numpy as np +import open3d as o3d +import os +import hashlib + +from typing import List, Dict, Union + +from embodichain.data import get_data_path +from embodichain.utils import logger + +from embodichain.toolkits.processor.function.mesh_processor.base import ( + MeshProcessorList, +) +from embodichain.toolkits.graspkit.pg_grasp import ( + AntipodalGenerator, +) + + +def generate_pickpose_sampler( + file_name: str, mesh: o3d.t.geometry.TriangleMesh, params: dict +) -> None: + logger.log_info(f"Generating object mesh {file_name} pick poses.") + if os.path.exists(file_name): + try: + if os.path.exists(file_name) and not file_name.endswith(".dae"): + pickpose_sampler = AntipodalGenerator( + mesh, + **params, + unique_id=hashlib.md5(file_name.encode()).hexdigest(), + ) + elif file_name.endswith(".dae"): + pickpose_sampler = None + else: + logger.log_warning( + f"Failed to build AntipodalGenerator because {file_name} is invalid!" + ) + except Exception as e: + logger.log_warning(f"Failed to build AntipodalGenerator: {str(e)}") + else: + logger.log_warning( + f"Failed to build AntipodalGenerator cause {file_name} unvalid!" + ) + return pickpose_sampler + + +# TODO: We should refactor the Object to support Group object design. +class Object: + name: str = "Undefined" + description: str = "Undefined" + pick_poses: List[np.ndarray] = None + pose: np.ndarray + parts: List["Object"] + articulations: Dict[str, np.ndarray] + mesh: o3d.geometry.TriangleMesh + scale: Union[List, np.ndarray] = [1, 1, 1] + mesh_file: str # to be depracted. + active_state: bool = False + unit: str = "m" + pickpose_sampler: AntipodalGenerator = None + pickpose_sampler_params: dict = None + # for object group + folder_path: str + mesh_processor: MeshProcessorList = None + + def get_mesh_file(self): + if hasattr(self, "mesh_file"): + return self.mesh_file + else: + obj_cad_files = [ + file + for file in os.listdir(self.folder_path) + if file.startswith("mesh_processed_") is False + ] + + target_file = np.random.choice(obj_cad_files) + + return self.select_mesh_file_from_folder(target_file) + + def select_mesh_file_from_folder(self, target_file: str): + from embodichain.toolkits.processor.component import TriangleComponent + from embodichain.toolkits.processor.entity import MeshEntity + + cache_fpath = os.path.join(self.folder_path, f"mesh_processed_{target_file}") + if os.path.exists(cache_fpath) is False: + tri_component = TriangleComponent.from_fpath( + os.path.join(self.folder_path, target_file) + ) + mesh_entity = MeshEntity("mesh", tri_component) + mesh = self.mesh_processor.apply([mesh_entity])[0] + mesh.save_mesh(cache_fpath) + + if self.pickpose_sampler_params is not None: + mesh = o3d.t.io.read_triangle_mesh(cache_fpath) + self.pickpose_sampler = generate_pickpose_sampler( + cache_fpath, mesh, self.pickpose_sampler_params + ) + + return cache_fpath + + @staticmethod + def from_folder(path: str, obj_data: dict) -> "Object": + from embodichain.toolkits.processor.function.mesh_processor import ( + build_mesh_processors, + ) + + obj = Object() + obj.folder_path = path + obj.description = obj_data.get("description", "Undefined") + obj.name = obj_data.get("name", "Undefined") + obj.unit = obj_data.get("unit", "m") + obj.pickpose_sampler_params = obj_data.get("auto_pickpose_generator", None) + obj.pose = obj_data.get("pose", np.eye(4)) + mesh_processor_config = obj_data.get("mesh_processor", None) + if mesh_processor_config is not None: + obj.mesh_processor = build_mesh_processors(mesh_processor_config) + + return obj + + @staticmethod + def from_mesh(path: str, downsample: bool = False, local_dir: str = "") -> "Object": + r"""Create an Object instance from a mesh file. + + Args: + path (str): The file path to the mesh key, value in obj_related[k].items(): + setattr(obj, key, value) + objs.append(obj) + + # Order the objects based file. + downsample (bool, optional): Whether to downsample the mesh. Defaults to False. + + Returns: + Object: An `Object` instance containing the mesh data. + """ + obj = Object() + + if not os.path.exists(path): + if local_dir is not None and local_dir != "": + path = os.path.join(local_dir, path) + else: + path = get_data_path(path) + + mesh = o3d.io.read_triangle_mesh(path) + if downsample: + obj.mesh = mesh.simplify_quadric_decimation(target_number_of_triangles=4000) + else: + obj.mesh = mesh + obj.mesh_file = path + return obj + + @staticmethod + def from_urdf(path: str) -> List["Object"]: + r"""Create a list of Object instances from a URDF file. + + This method reads a URDF (Unified Robot Description Format) file, extracts + the geometry data, and returns a list of `Object` instances representing + the visual elements described in the URDF. + + Args: + path (str): The file path to the URDF file. + + Returns: + List[Object]: A list of `Object` instances representing the visual elements. + """ + import pinocchio + import copy + + data_path = copy.deepcopy(path) + if not os.path.exists(data_path): + data_path = get_data_path(path) + package_dirs = [os.path.dirname(data_path)] + + model, collision_model, visual_model = pinocchio.buildModelsFromUrdf( + data_path, package_dirs=package_dirs + ) + + urdf_dir = os.path.dirname(data_path) + objs = [] + + # Parse the geometry data from URDF + for geom in visual_model.geometryObjects.tolist(): + if hasattr(geom, "meshPath"): + mesh_path = geom.meshPath + if not os.path.isabs(mesh_path): + mesh_path = os.path.join(urdf_dir, mesh_path) + obj = Object.from_mesh(mesh_path) + obj.name = geom.name + objs.append(obj) + + return objs + + @staticmethod + def _save_mesh_or_urdf(obj: "Object", new_file_name: str): + r"""Save the mesh or URDF file with error handling. + + Args: + obj (Object): The object containing the mesh or URDF data. + new_file_name (str): The new file path where the data will be saved. + """ + try: + if new_file_name.endswith(".urdf"): + obj.save_as_urdf(new_file_name) + else: + o3d.io.write_triangle_mesh(new_file_name, obj.mesh) + except Exception as e: + logger.log_error(f"Failed to save the file {new_file_name}: {str(e)}") + + @staticmethod + def _generate_new_filename(file_path: str, extension: str) -> str: + r"""Generate a new filename with a specified extension. + + Args: + file_path (str): The original file path. + extension (str): The new extension to append to the filename. + + Returns: + str: The generated file path with the new extension. + """ + _, file_extension = os.path.splitext(file_path) + return os.path.join( + os.path.dirname(file_path), + os.path.basename(file_path).split(".")[0] + f"_{extension}{file_extension}", + ) + + @staticmethod + def _apply_common_settings(obj: "Object", obj_data: dict, file_path: str): + r"""Apply common settings such as unit conversion and pose generation to the object. + + Args: + obj (Object): The object to which settings will be applied. + obj_data (dict): Dictionary containing object-specific configuration data. + file_path (str): Object file path. + """ + if "unit" in obj_data and obj_data["unit"] == "mm": + obj.mesh.scale(1e-3, center=np.zeros((3))) + obj.unit = "m" + obj_data.pop("unit") + new_file_name = Object._generate_new_filename(obj.mesh_file, extension="m") + Object._save_mesh_or_urdf(obj, new_file_name) + obj.mesh_file = new_file_name + + obj.scale = obj_data.get("scale", [1, 1, 1]) + from dexsim.utility.meshproc import scale_trianglemesh + + obj.mesh = scale_trianglemesh(obj.mesh, obj.scale) + + if "auto_pickpose_generator" in obj_data: + mesh = o3d.t.geometry.TriangleMesh.from_legacy(obj.mesh) + obj.pickpose_sampler = generate_pickpose_sampler( + obj.mesh_file, mesh, obj_data["auto_pickpose_generator"] + ) + + for key, value in obj_data.items(): + setattr(obj, key, value) + + @staticmethod + def from_config( + path: Union[str, Dict], downsample: bool = False, local_dir: str = "" + ) -> List["Object"]: + r"""Create a list of Object instances from a configuration file. + + Args: + path (Union[str, Dict]): The file path to the configuration file or a dictionary containing the configuration. + downsample (bool, optional): Whether to downsample the mesh. Defaults to False. + + Returns: + List[Object]: A list of `Object` instances as specified in the configuration file. + """ + from embodichain.utils.utility import load_json + + if isinstance(path, str): + config = load_json(path) + else: + config = path + + obj_related = config["obj_list"] + objs = [] + + for k, obj_data in obj_related.items(): + if obj_data.get("mesh_file", None) is not None: + file = obj_data.pop("mesh_file") + file_ext = os.path.splitext(file)[-1].lower() + + if local_dir is not None and local_dir != "": + data_file_path = os.path.join(local_dir, file) + else: + data_file_path = get_data_path(file) + + if file_ext == ".urdf": + urdf_objs = Object.from_urdf(file) + for obj in urdf_objs: + Object._apply_common_settings(obj, obj_data, data_file_path) + objs.extend(urdf_objs) + else: + obj = Object.from_mesh(file, downsample, local_dir=local_dir) + Object._apply_common_settings(obj, obj_data, data_file_path) + objs.append(obj) + else: + folder_path = obj_data.get("folder_path", None) + if folder_path is None: + logger.log_error( + f"Object configuration {k} does not contain a valid mesh file or folder path." + ) + obj = Object.from_folder(folder_path, obj_data) + objs.append(obj) + + # TODO: to be improved. + if len(objs) == len(obj_related): + order = [int(k) for k in obj_related.keys()] + return [objs[i] for i in np.argsort(order)] + else: + return objs diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 40ced87f..b6aee4a0 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -437,6 +437,9 @@ class ComponentCfg: env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) + # load dataset config + env_cfg.dataset = config["env"].get("dataset", None) + # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. env_cfg.dataset = ComponentCfg() @@ -471,6 +474,7 @@ class ComponentCfg: "embodichain.lab.gym.envs.managers.randomization", "embodichain.lab.gym.envs.managers.record", "embodichain.lab.gym.envs.managers.events", + "embodichain.lab.gym.envs.managers.real2sim", ] # parser env events config diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py index b75b70af..9f839491 100644 --- a/embodichain/lab/gym/utils/misc.py +++ b/embodichain/lab/gym/utils/misc.py @@ -28,7 +28,7 @@ from collections import OrderedDict from importlib import import_module from scipy.spatial.transform import Rotation as R -from typing import Any, Dict, List, Tuple, Union, Sequence, Callable, Mapping +from typing import Any, Dict, List, Tuple, Union, Sequence, Callable, Mapping, Optional import numpy as np @@ -757,6 +757,13 @@ def resolve_env_attr(obj: Any, env: Any) -> Any: _EXPR = re.compile(r"\$\{([^}]+)\}") # For searching ${...} marker +def is_binocularcam(sensor): + from dexsim.sensor import BinocularCam + from embodichain.lab.sim.sensors import StereoCamera + + return isinstance(sensor, BinocularCam) or isinstance(sensor, StereoCamera) + + def resolve_formatted_string(obj, local_vars=None, global_vars=None): """Given a dict carrys "${...}"-like strings , `eval` the "${...}$" values while keep the dict structure. @@ -1382,3 +1389,43 @@ def is_stereocam(sensor) -> bool: from embodichain.lab.sim.sensors import StereoCamera return isinstance(sensor, StereoCamera) + + +def _data_key_to_control_part(robot, control_parts, data_key: str) -> Optional[str]: + # TODO: Temporary workaround, should be removed after refactoring data dict extractor. + # @lru_cache(max_size=None) # NOTE: no way to pass a hashable parameter + def is_eef_hand_func(robot, control_parts) -> bool: + # TODO: This is a temporary workaround, should be used a more general method to check + # whether the end-effector is a hand. + for part in control_parts: + if "eef" in part: + joint_ids = robot.get_joint_ids(part, remove_mimic=True) + return len(joint_ids) >= 2 + return False + + from embodichain.data.enum import ( + ControlParts, + EndEffector, + ActionMode, + JointType, + EefType, + ) + + is_eef_hand = is_eef_hand_func(robot, control_parts) + + for control_part in ControlParts: + if EndEffector.DEXTROUSHAND.value in data_key and is_eef_hand: + return data_key.replace(EndEffector.DEXTROUSHAND.value, "") + elif EndEffector.DEXTROUSHAND.value in data_key and not is_eef_hand: + continue + elif EndEffector.GRIPPER.value in data_key and not is_eef_hand: + return data_key.replace(EndEffector.GRIPPER.value, "") + elif EndEffector.GRIPPER.value in data_key and is_eef_hand: + continue + elif ActionMode.RELATIVE.value + JointType.QPOS.value in data_key: + continue + elif EefType.POSE.value in data_key: + continue + elif control_part.value in data_key: + return control_part.value + return None diff --git a/embodichain/lab/scripts/generate_video.py b/embodichain/lab/scripts/generate_video.py new file mode 100644 index 00000000..d0a89d7c --- /dev/null +++ b/embodichain/lab/scripts/generate_video.py @@ -0,0 +1,175 @@ +from embodichain.utils.logger import log_info, log_warning + +try: + import h5ffmpeg as hf +except Exception as e: + log_warning("Fail to import h5ffmpeg.") +import h5py +import argparse +import numpy as np +import os +from tqdm import tqdm +from dexsim.utility import images_to_video +from typing import Dict, Callable, Tuple +from embodichain.utils.visualizer import draw_keypoints, draw_action_distribution +from embodichain.data.enum import EefType, JointType, Modality, PrivilegeType +from embodichain.data.data_engine.indices_unifier import ActionIndicesGenerator + + +class VideoCreator: + def __init__(self) -> None: + pass + + @staticmethod + def _sub_function( + images, + output_path, + video_key, + exteroceptions: Dict = None, + multiplier: int = 1, + drawer: Callable = lambda x: x, + ): + for key in images.keys(): + imgs = images[key] + if imgs is None: + log_warning(f"No images found for key: {key}. Skipping.") + continue + img_list = [] + for i in tqdm(range(imgs.shape[0])): + image_i = drawer(imgs[i] * multiplier) + if exteroceptions is not None and len(exteroceptions[key]) != 0: + image_i = draw_keypoints( + image_i, exteroceptions[key][i].reshape(-1, 2) + ) + img_list.append(image_i) + + images_to_video(img_list, output_path, f"{key}_{video_key}") + + @staticmethod + def monocular_save( + observations: Dict, + video_key: str, + output_path: str, + multiplier: int = 1, + drawer: Callable = lambda x: x, + draw_exteroception: bool = True, + ): + images = observations[video_key] + if ( + PrivilegeType.EXTEROCEPTION.value in observations.keys() + and draw_exteroception + ): + exteroceptions = observations[PrivilegeType.EXTEROCEPTION.value] + else: + exteroceptions = None + VideoCreator._sub_function( + images, + output_path, + video_key, + exteroceptions, + multiplier, + drawer, + ) + + +def visualize_data_dict(f: Dict, output_path: str): + observations = f["observations"] + + if PrivilegeType.MASK.value in observations.keys(): + VideoCreator.monocular_save( + observations, + PrivilegeType.MASK.value, + output_path, + 255, + draw_exteroception=False, + ) + + if Modality.GEOMAP.value in observations.keys(): + from embodichain.utils.utility_3d import gen_disp_colormap + + VideoCreator.monocular_save( + observations, + Modality.GEOMAP.value, + output_path, + 1, + lambda x: (gen_disp_colormap(x).transpose(1, 2, 0) * 255).astype(np.uint8), + draw_exteroception=False, + ) + + VideoCreator.monocular_save(observations, Modality.IMAGES.value, output_path) + + +def main(args): + + data_path = args.data_path + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + assert data_path.endswith(".hdf5"), "Data path must have format of .hdf5" + with h5py.File(data_path, "r") as f: + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + import hdfdict + + data = hdfdict.load(data_path) + data = CompressedVideoHDF5( + output_path, chunks=data["chunks"].item() + ).safe_filter(data) + + # NOTE: DO NOT USE THIS IN SCRIPT, IT IS FOR DEBUGGING PURPOSES ONLY + # slice_id = 20 + # data_copy = hdfdict.load(data_path) + # data_copy = CompressedVideoHDF5(output_path).safe_filter(data_copy, slice_id=slice_id) + + # a = data["observations"]["images"]["cam_high"][slice_id] + # b = data_copy["observations"]["images"]["cam_high"] + # print(a, b.shape) + # delta = a-b[slice_id] + # print(np.linalg.norm(delta)) + + visualize_data_dict(data, output_path) + if "robot_meta" in data.keys(): + log_warning("Simulation data.") + robot_meta = data["robot_meta"] + arm_dofs = robot_meta["arm_dofs"][()] + actions = f[Modality.ACTIONS.value][()] + else: + from embodichain.data.data_engine.datasets.sim_real_unified_dict_dataset import ( + RobotRealDataRouter, + ) + + log_warning("Real data.") + actions, _, _, _, _ = RobotRealDataRouter( + robot_name=args.robot_name + ).realdata2simdata(f) + indices_generator = ActionIndicesGenerator(arm_dofs) + + key_names = indices_generator.global_mapping.mapping_from_name_to_indices.keys() + log_info(f"Arm dofs: {arm_dofs}", color="green") + indices_dict = {} + for key_name in key_names: + indices_dict[key_name] = indices_generator.get([key_name]) + draw_action_distribution(actions, indices_dict, output_path, smooth=args.smooth) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, help="Path to the data file.") + parser.add_argument( + "--output_path", + type=str, + help="Path to the output video file.", + default="./outputs", + ) + parser.add_argument( + "--smooth", + action="store_true", + default=False, + help="whether smooth joints.", + ) + parser.add_argument("--robot_name", default="DexforceW1", type=str) + args = parser.parse_args() + + main(args) diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py new file mode 100644 index 00000000..94a8ce95 --- /dev/null +++ b/embodichain/lab/scripts/run_agent.py @@ -0,0 +1,273 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import gymnasium +import numpy as np +import argparse +import os +import torch + +from threading import Thread +from tqdm import tqdm +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.lab.scripts.generate_video import visualize_data_dict +from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, +) +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.lab.sim.cfg import MarkerCfg + + +def generate_function( + env, + time_id: int = 0, + online_training: bool = False, + save_path: str = "", + save_video: bool = False, + debug_mode: bool = False, + regenerate: bool = True, + **kwargs, +): + """ + Generate and execute a sequence of actions in the environment. + + This function resets the environment, generates and executes action trajectories, + collects data, and optionally saves videos of the episodes. It supports both online + and offline data generation modes. + + Args: + env: The environment instance. + time_id (int, optional): Identifier for the current time step or episode. + online_training (bool, optional): Whether to use online data generation. + save_path (str, optional): Path to save generated videos. + save_video (bool, optional): Whether to save episode videos. + debug_mode (bool, optional): Enable debug mode for visualization and logging. + regenerate (bool, optional): Whether enable regenerating if existed. + **kwargs: Additional keyword arguments for data generation. + + Returns: + list or bool: Returns a list of data dicts if online_training is True, + otherwise returns True if generation is successful. + """ + + def wait_for_threads(threads): + for t in threads: + t.join() + + vis_threads = [] + + while True: # repeat until success + env.reset() + + ret = [] + trajectory_idx = 0 + + # Access the wrapped environment's method + env.get_wrapper_attr("create_demo_action_list")(regenerate=regenerate) + + # --------------------------------------------------------- + # SUCCESS CASE + # --------------------------------------------------------- + if not debug_mode and env.get_wrapper_attr("is_task_success")().item(): + + dataset_id = f"time_{time_id}_trajectory_{trajectory_idx}" + + # online training: dataset may not be saved every iteration + if online_training: + dataset_id += "_online_generated" + num_samples = kwargs.get("num_samples", 0) + is_save_dataset = time_id < num_samples + + data_dict = env.get_wrapper_attr("to_dataset")( + id=dataset_id if is_save_dataset else None + ) + ret.append(data_dict) + else: + data_dict = env.get_wrapper_attr("to_dataset")(id=dataset_id) + + # episode id + try: + episode = env.get_wrapper_attr("get_current_episode")() + except AttributeError: + episode = time_id + + # video saving + if save_video: + video_path = os.path.join(save_path, f"episode_{episode}") + if online_training: + vis_thread = Thread( + target=visualize_data_dict, + args=(data_dict["data"], video_path), + daemon=True, + ) + vis_thread.start() + vis_threads.append(vis_thread) + else: + visualize_data_dict(data_dict["data"], video_path) + + break # success + + # --------------------------------------------------------- + # FAILURE CASE + # --------------------------------------------------------- + else: + log_warning("Task fail, Skip to next generation and retry.") + continue # retry until success + + wait_for_threads(vis_threads) + return ret if online_training else True + + +def main(args, env, gym_config): + is_online_training = os.path.exists(args.online_config) + if is_online_training: + + log_info("Start online data generation.", color="green") + assert os.path.exists(args.online_config), "{} does not exist.".format( + args.online_config + ) + + online_config = load_json(args.online_config) + online_callback = OnlineGenerator(**online_config) + + generator_func = lambda time_id, **kwargs: generate_function( + env, + time_id, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + regenerate=args.regenerate, + **kwargs, + ) + online_callback.generator(generator_func, **online_config) + else: + log_info("Start offline data generation.", color="green") + for i in range(gym_config["max_episodes"]): + generate_function( + env, + i, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + debug_mode=args.debug_mode, + regenerate=args.regenerate, + ) + + if args.headless: + env.reset(options={"final": True}) + + +if __name__ == "__main__": + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--render_backend", + help="The rendering backend to use for the simulation.", + default="egl", + type=str, + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--save_video", + help="Whether to save data as video.", + default=False, + action="store_true", + ) + parser.add_argument( + "--save_path", help="path", default="./outputs/thirdviewvideo", type=str + ) + parser.add_argument( + "--debug_mode", + help="Enable debug mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + parser.add_argument("--online_config", type=str, help="online_config", default="") + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument( + "--task_name", type=str, help="Name of the task.", required=True + ) + + # Agent related configs + parser.add_argument( + "--agent_config", type=str, help="agent_config", default=None, required=True + ) + parser.add_argument( + "--regenerate", + type=bool, + help="Whether regenerate code if already existed.", + default=False, + ) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + agent_config = load_json(args.agent_config) + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make( + id=gym_config["id"], + cfg=cfg, + agent_config=agent_config, + task_name=args.task_name, + ) + main(args, env, gym_config) diff --git a/embodichain/toolkits/code_generation.py b/embodichain/toolkits/code_generation.py new file mode 100644 index 00000000..c222b8aa --- /dev/null +++ b/embodichain/toolkits/code_generation.py @@ -0,0 +1,80 @@ +from typing import List, Dict, Tuple, Any +from langchain_core.output_parsers import BaseOutputParser +from pygments import highlight +from pygments.lexers import PythonLexer +from pygments.formatters import TerminalFormatter + +import numpy as np + + +def merge_dicts(dicts: Dict): + return {k: v for d in dicts for k, v in d.items()} + + +def get_executable_code_str(input_string, language="python"): + start_marker = f"```{language}" + end_marker = f"```" + if input_string.find(start_marker) >= 0: + + start_index = input_string.find(start_marker) + len(start_marker) + end_index = input_string.rfind(end_marker) + + code_string = input_string[start_index:end_index].strip() + else: + code_string = input_string + + return code_string + + +class OutputFormatting: + @staticmethod + def flatten_dict(output: Dict[str, Dict]) -> Dict[str, np.ndarray]: + ret = {} + for _, val in output.items(): + ret.update(val) + return ret + + +class ExecutableOutputParser(BaseOutputParser): + # https://python.langchain.com/v0.1/docs/modules/model_io/output_parsers/custom/ + + _fixed_vars = {"np": np} + variable_vars = {} + + def update_vars(self, variable_vars: Dict): + self.variable_vars = variable_vars + + def parse(self, text: str) -> Tuple[str, Dict, Dict]: + code_str = get_executable_code_str(text) + # if self._cfg["include_context"] and context != "": + # to_exec = f"{context}\n{code_str}" + # to_log = f"{context}\n{use_query}\n{code_str}" + # else: + # to_exec = code_str + # to_log = f"{use_query}\n{to_exec}" + + # to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter()) + # print( + # f"\033[34m====================================================\nLMP {self._name} exec:\033[0m\n\n{to_log_pretty}\n\n\033[34m====================================================\n\033[0m" + # ) + + # generate new functions + # new_fs = self._lmp_fgen.create_new_fs_from_code(code_str) + # self._variable_vars.update(new_fs) + + gvars = merge_dicts([self._fixed_vars, self.variable_vars]) + lvars = None + + # + # banned_phrases = ["import", "__"] + # for phrase in banned_phrases: + # assert phrase not in code_str + + if gvars is None: + gvars = {} + if lvars is None: + lvars = {} + empty_fn = lambda *args, **kwargs: None + custom_gvars = merge_dicts([gvars, {"exec": empty_fn, "eval": empty_fn}]) + + return code_str, custom_gvars, lvars diff --git a/embodichain/toolkits/interfaces.py b/embodichain/toolkits/interfaces.py new file mode 100644 index 00000000..b91a008d --- /dev/null +++ b/embodichain/toolkits/interfaces.py @@ -0,0 +1,1241 @@ +from typing import List, Dict +from embodichain.lab.gym.structs import Object +import numpy as np +from embodichain.toolkits.toolkits import ToolkitsBase +from embodichain.utils.logger import log_info, log_warning, log_error +from copy import deepcopy +from embodichain.lab.gym.utils.misc import ( + mul_linear_expand, + is_qpos_flip, + get_rotation_replaced_pose, +) +from embodichain.toolkits.graspkit.pg_grasp.antipodal import GraspSelectMethod +from matplotlib import pyplot as plt +import torch +from tqdm import tqdm +from embodichain.lab.gym.motion_generation.action.arm_action import ArmAction +from embodichain.data.enum import ControlParts, EndEffector, JointType +from scipy.spatial.transform import Rotation as R +from embodichain.utils.utility import encode_image +import ast + +""" +--------------------------------------------Some useful functions---------------------------------------------------- +--------------------------------------------Some useful functions---------------------------------------------------- +--------------------------------------------Some useful functions---------------------------------------------------- +""" + + +def draw_axis(env, pose): + from embodichain.lab.sim.cfg import MarkerCfg + + marker_cfg = MarkerCfg( + name="test", + marker_type="axis", + axis_xpos=pose, + axis_size=0.01, + axis_len=0.2, + arena_index=-1, # All arenas + ) + env.sim.draw_marker(cfg=marker_cfg) + env.sim.update() + + +def get_arm_states(env, robot_name): + + left_arm_current_qpos, right_arm_current_qpos = env.get_current_qpos_agent() + left_arm_current_pose, right_arm_current_pose = env.get_current_xpos_agent() + left_arm_current_gripper_state, right_arm_current_gripper_state = ( + env.get_current_gripper_state_agent() + ) + + side = "right" if "right" in robot_name else "left" + is_left = True if side == "left" else False + select_arm = "left_arm" if is_left else "right_arm" + + arms = { + "left": ( + left_arm_current_qpos, + left_arm_current_pose, + left_arm_current_gripper_state, + ), + "right": ( + right_arm_current_qpos, + right_arm_current_pose, + right_arm_current_gripper_state, + ), + } + ( + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = arms[side] + + return ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) + + +def find_nearest_valid_pose(env, select_arm, pose, xpos_resolution=0.1): + # use the validator to choose the nearest valid pose + # delete the cache every time + if isinstance(pose, torch.Tensor): + pose = pose.detach().cpu().numpy() + ret, _ = env.robot.compute_xpos_reachability( + select_arm, + pose, + xpos_resolution=xpos_resolution, + qpos_resolution=np.radians(60), + cache_mode="disk", + use_cached=False, + visualize=False, + ) + ret = np.stack(ret, axis=0) + # find the nearest valid pose + xyz = pose[:3, 3] + ts = np.stack([M[:3, 3] for M in ret], axis=0) # shape (N,3) + dists = np.linalg.norm(ts - xyz[None, :], axis=1) + best_idx = np.argmin(dists) + nearest_valid_pose = ret[best_idx] + return torch.from_numpy(nearest_valid_pose) + + +def get_qpos(env, is_left, select_arm, pose, qpos_seed, force_valid=False, name=""): + if force_valid: + try: + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) + if not ret: + log_error(f"Generate {name} qpos failed.\n") + except Exception as e: + log_warning( + f"Original {name} pose invalid, using nearest valid pose. ({e})\n" + ) + pose = find_nearest_valid_pose(env, select_arm, pose) + + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) + else: + ret, qpos = env.get_arm_ik(pose, is_left=is_left, qpos_seed=qpos_seed) + if not ret: + log_error(f"Generate {name} qpos failed.\n") + + return qpos + + +def get_offset_pose( + pose_to_change: torch.Tensor, + offset_value: float, + direction: str = "z", + mode: str = "intrinsic", +) -> torch.Tensor: + + device = pose_to_change.device + dtype = pose_to_change.dtype + + if isinstance(direction, str): + if direction == "x": + direction_vec = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=dtype) + elif direction == "y": + direction_vec = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=dtype) + elif direction == "z": + direction_vec = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=dtype) + else: + log_error(f"Invalid direction '{direction}'. Must be 'x', 'y', or 'z'.") + return pose_to_change + else: + direction_vec = torch.as_tensor(direction, device=device, dtype=dtype) + + direction_vec = direction_vec / torch.linalg.norm(direction_vec) + offset_matrix = torch.eye(4, device=device, dtype=dtype) + offset_matrix[:3, 3] = offset_value * direction_vec + + if mode == "extrinsic": + offset_pose = offset_matrix @ pose_to_change + elif mode == "intrinsic": + offset_pose = pose_to_change @ offset_matrix + else: + log_error(f"Invalid mode '{mode}'. Must be 'extrinsic' or 'intrinsic'.") + return pose_to_change + + return offset_pose + + +def plan_trajectory( + env, + select_arm, + qpos_list, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, +): + traj_list, _ = ArmAction.create_discrete_trajectory( + agent=env.robot, + uid=select_arm, + qpos_list=qpos_list, + sample_num=sample_num, + qpos_seed=qpos_list[0], + is_use_current_qpos=False, + **getattr(env, "planning_config", {}), + ) + + select_qpos_traj.extend(traj_list) + ee_state_list_select.extend([select_arm_current_gripper_state] * len(traj_list)) + + +def plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select, +): + open_state = env.open_state + close_state = env.close_state + + if execute_open: + ee_state_expand_select = np.array([close_state, open_state]) + env.set_current_gripper_state_agent(open_state, is_left=is_left) + else: + ee_state_expand_select = np.array([open_state, close_state]) + env.set_current_gripper_state_agent(close_state, is_left=is_left) + + ee_state_expand_select = mul_linear_expand(ee_state_expand_select, [sample_num]) + + select_qpos_traj.extend([select_arm_current_qpos] * sample_num) + ee_state_list_select.extend(ee_state_expand_select) + + +def finalize_actions(select_qpos_traj, ee_state_list_select): + # mimic eef state + actions = np.concatenate( + [ + np.array(select_qpos_traj), + np.array(ee_state_list_select), + np.array(ee_state_list_select), + ], + axis=-1, + ) + return actions + + +def extract_drive_calls(code_str: str) -> list[str]: + tree = ast.parse(code_str) + lines = code_str.splitlines() + + drive_blocks = [] + + for node in tree.body: + # Match: drive(...) + if ( + isinstance(node, ast.Expr) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == "drive" + ): + # AST line numbers are 1-based + start = node.lineno - 1 + end = node.end_lineno + block = "\n".join(lines[start:end]) + drive_blocks.append(block) + + return drive_blocks + + +""" +--------------------------------------------Atom action functions---------------------------------------------------- +--------------------------------------------Atom action functions---------------------------------------------------- +--------------------------------------------Atom action functions---------------------------------------------------- +""" + + +# TODO: write a move_to_pose atom action, the use this action to form other atom actions +def grasp( + robot_name: str, + obj_name: str, + pre_grasp_dis: float = 0.05, + env=None, + force_valid=False, + **kwargs, +): + # Get target object + obj_uids = env.sim.get_rigid_object_uid_list() + if obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + else: + log_error(f"No matched object {obj_uids}.") + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0) + + # Open the gripper if currently closed + actions = None + select_arm_current_gripper_state = ( + env.left_arm_current_gripper_state + if "left" in robot_name + else env.right_arm_current_gripper_state + ) + if select_arm_current_gripper_state <= env.open_state - 0.01: + actions = open_gripper(robot_name, env, **kwargs) + + # Retract the end-effector to avoid collision + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + select_arm_base_pose = ( + env.left_arm_base_pose if is_left else env.right_arm_base_pose + ) + base_to_eef_xy_dis = torch.norm( + select_arm_base_pose[:2, 3] - select_arm_current_pose[:2, 3] + ) + base_to_obj_xy_dis = torch.norm( + select_arm_base_pose[:2, 3] - target_obj_pose[:2, 3] + ) + dis_eps = kwargs.get("dis_eps", 0.05) + select_arm_init_pose = ( + env.left_arm_init_xpos if is_left else env.right_arm_init_xpos + ) + if base_to_eef_xy_dis > base_to_obj_xy_dis and not torch.allclose( + select_arm_current_pose, select_arm_init_pose, rtol=1e-5, atol=1e-8 + ): + delta = base_to_eef_xy_dis - (base_to_obj_xy_dis - dis_eps) + back_actions = move_by_relative_offset( + robot_name=robot_name, + dx=0.0, + dy=0.0, + dz=-delta, + env=env, + force_valid=force_valid, + mode="intrinsic", + sample_num=15, + **kwargs, + ) + actions = ( + np.concatenate([actions, back_actions], axis=0) + if actions is not None + else back_actions + ) + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Move the end-effector to a good place for starting grasping to avoid bad poses + select_arm_retract_pose = deepcopy( + env.left_arm_init_xpos if is_left else env.right_arm_init_xpos + ) + select_arm_retract_pose = get_offset_pose( + select_arm_retract_pose, 0.15, "z", "intrinsic" + ) + select_arm_retract_qpos = get_qpos( + env, + is_left, + select_arm, + select_arm_retract_pose, + env.left_arm_init_qpos if is_left else env.right_arm_init_qpos, + force_valid=force_valid, + name="retract_to_good_pose", + ) + qpos_list_back_to_retract = [select_arm_current_qpos, select_arm_retract_qpos] + sample_num = 30 + + plan_trajectory( + env, + select_arm, + qpos_list_back_to_retract, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + select_arm_current_qpos = select_arm_retract_qpos + select_arm_current_pose = select_arm_retract_pose + + # Rotate the arm base to face the object for better grasping + delta_xy = target_obj_pose[:2, 3] - select_arm_base_pose[:2, 3] + dx, dy = delta_xy[0], delta_xy[1] + aim_horizontal_angle = np.arctan2(dy, dx) + delta_angle = abs(select_arm_current_qpos[0] - aim_horizontal_angle) + select_arm_aim_qpos = deepcopy(select_arm_current_qpos) + select_arm_aim_qpos[0] = aim_horizontal_angle + + # Get best grasp pose from affordance data + grasp_pose_object = env.init_obj_info.get(obj_name)["grasp_pose_obj"] + if ( + grasp_pose_object[0, 2] > 0.5 + ): # whether towards x direction TODO: make it robust + # Align the object pose's z-axis with the arm's aiming direction + target_obj_pose = torch.tensor( + get_rotation_replaced_pose( + np.array(target_obj_pose), + float(select_arm_aim_qpos[0]), + "z", + "intrinsic", + ) + ) + best_pickpose = target_obj_pose @ grasp_pose_object + grasp_pose = deepcopy(best_pickpose) + grasp_pose_pre1 = deepcopy(grasp_pose) + grasp_pose_pre1 = get_offset_pose(grasp_pose_pre1, -pre_grasp_dis, "z", "intrinsic") + + # Solve IK for pre-grasp and grasp poses + grasp_qpos_pre1 = get_qpos( + env, + is_left, + select_arm, + grasp_pose_pre1, + select_arm_aim_qpos, + force_valid=force_valid, + name="grasp pre1", + ) + grasp_qpos = get_qpos( + env, + is_left, + select_arm, + grasp_pose, + grasp_qpos_pre1, + force_valid=force_valid, + name="grasp", + ) + + # Update env state to final grasp pose + env.set_current_qpos_agent(grasp_qpos, is_left=is_left) + env.set_current_xpos_agent(grasp_pose, is_left=is_left) + + # ------------------------------------ Traj 0: init → aim ------------------------------------ + qpos_list_init_to_aim = [select_arm_current_qpos, select_arm_aim_qpos] + # base_sample_num = 10 + # base_angle = 0.08 + # sample_num = max(int(delta_angle / base_angle * base_sample_num), 2) + + sample_num = 10 + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_aim, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ------------------------------------ Traj 1: aim → pre-grasp ------------------------------------ + qpos_list_aim_to_pre1 = [select_arm_aim_qpos, grasp_qpos_pre1] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_aim_to_pre1, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ------------------------------------ Traj 2: pre-grasp → grasp ------------------------------------ + qpos_list_pre1_to_grasp = [grasp_qpos_pre1, grasp_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_pre1_to_grasp, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + traj_actions = finalize_actions(select_qpos_traj, ee_state_list_select) + actions = ( + traj_actions + if actions is None + else np.concatenate([actions, traj_actions], axis=0) + ) + + # ------------------------------------ Close gripper ------------------------------------ + close_gripper_actions = close_gripper(robot_name, env, **kwargs) + actions = np.concatenate([actions, close_gripper_actions], axis=0) + + log_info( + f"Total generated trajectory number for grasp: {len(actions)}.", color="green" + ) + + return actions + + +# def place_on_table( +# robot_name: str, +# obj_name: str, +# x: float = None, +# y: float = None, +# pre_place_dis: float = 0.08, +# env=None, +# force_valid=False, +# **kwargs +# ): +# +# # ---------------------------------------- Prepare ---------------------------------------- +# select_qpos_traj = [] +# ee_state_list_select = [] +# +# is_left, select_arm, select_arm_current_qpos, select_arm_current_pose, \ +# select_arm_current_gripper_state = get_arm_states(env, robot_name) +# +# grasp_pose_object = env.init_obj_info.get(obj_name).get('grasp_pose_obj') +# init_obj_pose = env.init_obj_info.get(obj_name).get('pose') +# +# select_arm_base_pose = env.left_arm_base_pose if is_left else env.right_arm_base_pose +# delta_xy = init_obj_pose[:2, 3] - select_arm_base_pose[:2, 3] +# aim_horizontal_angle = np.arctan2(delta_xy[1], delta_xy[0]) +# +# # Align the object pose's z-axis with the arm's aiming direction +# init_obj_pose = torch.tensor(get_rotation_replaced_pose(np.array(init_obj_pose), float(aim_horizontal_angle), "z","intrinsic")) +# +# place_pose = init_obj_pose @ grasp_pose_object +# place_pose[0, 3] = x +# place_pose[1, 3] = y +# place_pose[2, 3] += kwargs.get('eps', 0.02) +# +# pre_place_pose = deepcopy(place_pose) +# pre_place_pose[2, 3] += pre_place_dis +# +# # Solve IK for pre-place and place poses +# place_qpos_pre1 = get_qpos(env, is_left, select_arm, pre_place_pose, select_arm_current_qpos, force_valid=force_valid, name='place pre1') +# place_qpos = get_qpos(env, is_left, select_arm, place_pose, place_qpos_pre1, force_valid=force_valid, name='place') +# +# # Update env states +# env.set_current_qpos_agent(place_qpos, is_left=is_left) +# env.set_current_xpos_agent(place_pose, is_left=is_left) +# +# # ------------------------------------ Traj 0: current → pre-place ------------------------------------ +# qpos_list_current_to_preplace = [select_arm_current_qpos, place_qpos_pre1] +# sample_num = 30 +# +# plan_trajectory( +# env, +# select_arm, +# qpos_list_current_to_preplace, +# sample_num, +# select_arm_current_gripper_state, +# select_qpos_traj, +# ee_state_list_select +# ) +# +# # ------------------------------------ Traj 1: pre-place → place ------------------------------------ +# qpos_list_preplace_to_place = [place_qpos_pre1, place_qpos] +# sample_num = 20 +# +# plan_trajectory( +# env, +# select_arm, +# qpos_list_preplace_to_place, +# sample_num, +# select_arm_current_gripper_state, +# select_qpos_traj, +# ee_state_list_select +# ) +# +# # ---------------------------------------- Final ---------------------------------------- +# traj_actions = finalize_actions(select_qpos_traj, ee_state_list_select) +# +# open_actions = open_gripper(robot_name, env, **kwargs) +# +# actions = np.concatenate([traj_actions, open_actions], axis=0) +# +# log_info(f"Total generated trajectory number for place on table: {len(actions)}.", color="green") +# +# return actions + + +def place_on_table( + robot_name: str, + obj_name: str, + x: float = None, + y: float = None, + pre_place_dis: float = 0.08, + env=None, + force_valid=False, + **kwargs, +): + + init_obj_height = env.init_obj_info.get(obj_name).get("height") + height = init_obj_height + kwargs.get("eps", 0.03) + + traj_actions = move_to_absolute_position( + robot_name, x=x, y=y, z=height, env=env, force_valid=force_valid, **kwargs + ) + open_actions = open_gripper(robot_name, env, **kwargs) + + actions = np.concatenate([traj_actions, open_actions], axis=0) + + log_info( + f"Total generated trajectory number for place on table: {len(actions)}.", + color="green", + ) + + return actions + + +def move_relative_to_object( + robot_name: str, + obj_name: str, + x_offset: float = 0, + y_offset: float = 0, + z_offset: float = 0, + env=None, + force_valid=False, + **kwargs, +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Resolve target object + obj_uids = env.sim.get_rigid_object_uid_list() + if obj_name in obj_uids: + target_obj = env.sim.get_rigid_object(obj_name) + else: + log_error("No matched object.") + + # Get object base pose (4x4 matrix) + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0) + + # Construct target pose (preserve orientation) + move_target_pose = deepcopy(select_arm_current_pose) + move_target_pose[:3, 3] = target_obj_pose[:3, 3] + move_target_pose[0, 3] += x_offset + move_target_pose[1, 3] += y_offset + move_target_pose[2, 3] += z_offset + + # Solve IK for target pose + move_target_qpos = get_qpos( + env, + is_left, + select_arm, + move_target_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move relative to object", + ) + + # Update env states + env.set_current_qpos_agent(move_target_qpos, is_left=is_left) + env.set_current_xpos_agent(move_target_pose, is_left=is_left) + + # ------------------------------------ Traj 1: init → target ------------------------------------ + qpos_list_init_to_target = [select_arm_current_qpos, move_target_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_target, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for move relative to object: {len(actions)}.", + color="green", + ) + + return actions + + +def move_to_absolute_position( + robot_name: str, + x: float = None, + y: float = None, + z: float = None, + env=None, + force_valid=False, + **kwargs, +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Start from current pose, then selectively update xyz + move_pose = deepcopy(select_arm_current_pose) + + current_xyz = move_pose[:3, 3].clone() + + target_xyz = current_xyz.clone() + if x is not None: + target_xyz[0] = x + if y is not None: + target_xyz[1] = y + if z is not None: + target_xyz[2] = z + + move_pose[:3, 3] = target_xyz + + # Try IK on target pose + move_qpos = get_qpos( + env, + is_left, + select_arm, + move_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move to absolute position", + ) + + # Update env states + env.set_current_qpos_agent(move_qpos, is_left=is_left) + env.set_current_xpos_agent(move_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_move = [select_arm_current_qpos, move_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_move, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for move to absolute position: {len(actions)}.", + color="green", + ) + + return actions + + +def move_by_relative_offset( + robot_name: str, + dx: float = 0.0, + dy: float = 0.0, + dz: float = 0.0, + mode: str = "extrinsic", + env=None, + force_valid=False, + **kwargs, +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + move_pose = deepcopy(select_arm_current_pose) + + # Apply relative offsets (dx, dy, dz always floats) + move_pose = get_offset_pose(move_pose, dx, "x", mode) + move_pose = get_offset_pose(move_pose, dy, "y", mode) + move_pose = get_offset_pose(move_pose, dz, "z", mode) + + # Solve IK + move_qpos = get_qpos( + env, + is_left, + select_arm, + move_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="move by relative offset", + ) + + # Update environment states + env.set_current_qpos_agent(move_qpos, is_left=is_left) + env.set_current_xpos_agent(move_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_move = [select_arm_current_qpos, move_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_move, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for move by relative offset: {len(actions)}.", + color="green", + ) + + return actions + + +def back_to_initial_pose(robot_name: str, env=None, **kwargs): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + # Get arm states + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # Retrieve the initial joint configuration of this arm + target_qpos = env.left_arm_init_qpos if is_left else env.right_arm_init_qpos + target_qpos = torch.as_tensor(target_qpos, dtype=select_arm_current_qpos.dtype) + + # ---------------------------------------- Pose ---------------------------------------- + # Pre-back pose: move along tool z by a small offset (use intrinsic frame) + pre_back_pose = deepcopy(select_arm_current_pose) + pre_back_pose = get_offset_pose(pre_back_pose, -0.08, "z", "intrinsic") + + # IK for pre-back + pre_back_qpos = get_qpos( + env, + is_left, + select_arm, + pre_back_pose, + select_arm_current_qpos, + force_valid=kwargs.get("force_valid", False), + name="pre back pose", + ) + + # Update env states (move to target pose) + target_pose = env.get_arm_fk(qpos=target_qpos, is_left=is_left) + env.set_current_qpos_agent(target_qpos, is_left=is_left) + env.set_current_xpos_agent(target_pose, is_left=is_left) + + # ------------------------------------ Traj: init → pre back_pose ------------------------------------ + qpos_list_init_to_preback = [select_arm_current_qpos, pre_back_qpos] + sample_num = 20 + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_preback, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ------------------------------------ Traj: init → initial_pose ------------------------------------ + qpos_list_preback_to_target = [pre_back_qpos, target_qpos] + sample_num = kwargs.get("sample_num", 30) + + plan_trajectory( + env, + select_arm, + qpos_list_preback_to_target, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for back to initial pose: {len(actions)}.", + color="green", + ) + + return actions + + +def rotate_eef(robot_name: str, degree: float = 0, env=None, **kwargs): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Compute new joint positions + rotated_qpos = deepcopy(select_arm_current_qpos) + rotated_qpos[5] += np.deg2rad(degree) + + # Optional: limit checking (commented out by default) + # joint5_limit = env.get_joint_limits(select_arm)[5] + # if rotated_qpos[5] < joint5_limit[0] or rotated_qpos[5] > joint5_limit[1]: + # log_warning("Rotated qpos exceeds joint limits.\n") + + # Compute FK for new pose + rotated_pose = env.get_arm_fk( + qpos=rotated_qpos, + is_left=is_left, + ) + + # Update environment state + env.set_current_qpos_agent(rotated_qpos, is_left=is_left) + env.set_current_xpos_agent(rotated_pose, is_left=is_left) + + # ------------------------------------ Traj 1: init → rotated ------------------------------------ + qpos_list_init_to_rotated = [select_arm_current_qpos, rotated_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_rotated, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for rotate eef: {len(actions)}.", + color="green", + ) + + return actions + + +def orient_eef( + robot_name: str, + direction: str = "front", # 'front' or 'down' + env=None, + force_valid=False, + **kwargs, +): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + # Get arm state + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Pose ---------------------------------------- + # Generate replacement rotation matrix + replaced_rotation_matrix = np.eye(4) + if direction == "front": + rotation_matrix = R.from_euler("xyz", [180, -90, 0], degrees=True).as_matrix() + replaced_rotation_matrix[:3, :3] = ( + rotation_matrix @ replaced_rotation_matrix[:3, :3] + ) + elif direction == "down": + rotation_matrix = R.from_euler("x", 180, degrees=True).as_matrix() + replaced_rotation_matrix[:3, :3] = ( + rotation_matrix @ replaced_rotation_matrix[:3, :3] + ) + else: + log_error("Rotation direction must be 'front' or 'down'.") + + rotation_replaced_pose = deepcopy(select_arm_current_pose) + rot_torch = torch.as_tensor( + replaced_rotation_matrix[:3, :3], + dtype=rotation_replaced_pose.dtype, + device=rotation_replaced_pose.device, + ) + rotation_replaced_pose[:3, :3] = rot_torch + + # Solve IK for the new pose + replace_target_qpos = get_qpos( + env, + is_left, + select_arm, + rotation_replaced_pose, + select_arm_current_qpos, + force_valid=force_valid, + name="replaced-rotation", + ) + + # ---------------------------------------- Update env ---------------------------------------- + env.set_current_qpos_agent(replace_target_qpos, is_left=is_left) + env.set_current_xpos_agent(rotation_replaced_pose, is_left=is_left) + + # ------------------------------------ Traj: init → target ------------------------------------ + qpos_list_init_to_rotated = [select_arm_current_qpos, replace_target_qpos] + sample_num = kwargs.get("sample_num", 20) + + plan_trajectory( + env, + select_arm, + qpos_list_init_to_rotated, + sample_num, + select_arm_current_gripper_state, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for orient eef: {len(actions)}.", + color="green", + ) + + return actions + + +def close_gripper(robot_name: str, env=None, **kwargs): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Traj ---------------------------------------- + sample_num = kwargs.get("sample_num", 15) + execute_open = False # False → closing motion + + plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for close gripper: {len(actions)}.", + color="green", + ) + + return actions + + +def open_gripper(robot_name: str, env=None, **kwargs): + + # ---------------------------------------- Prepare ---------------------------------------- + select_qpos_traj = [] + ee_state_list_select = [] + + ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = get_arm_states(env, robot_name) + + # ---------------------------------------- Traj ---------------------------------------- + sample_num = kwargs.get("sample_num", 15) + execute_open = True # True → opening motion + + plan_gripper_trajectory( + env, + is_left, + sample_num, + execute_open, + select_arm_current_qpos, + select_qpos_traj, + ee_state_list_select, + ) + + # ---------------------------------------- Final ---------------------------------------- + actions = finalize_actions(select_qpos_traj, ee_state_list_select) + + log_info( + f"Total generated trajectory number for open gripper: {len(actions)}.", + color="green", + ) + + return actions + + +def drive( + left_arm_action=None, + right_arm_action=None, + env=None, + **kwargs, +): + + if left_arm_action is not None and right_arm_action is not None: + len_left = len(left_arm_action) + len_right = len(right_arm_action) + + if len_left < len_right: + diff = len_right - len_left + padding = np.repeat(left_arm_action[-1:], diff, axis=0) + left_arm_action = np.concatenate([left_arm_action, padding], axis=0) + elif len_right < len_left: + diff = len_left - len_right + padding = np.repeat(right_arm_action[-1:], diff, axis=0) + right_arm_action = np.concatenate([right_arm_action, padding], axis=0) + + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + actions = np.zeros((len(right_arm_action), len(env.init_qpos))) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + elif left_arm_action is None and right_arm_action is not None: + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + left_arm_action = finalize_actions( + env.left_arm_current_qpos, env.left_arm_current_gripper_state + ) + left_arm_action = np.repeat( + left_arm_action[None, :], len(right_arm_action), axis=0 + ) + + actions = np.zeros( + (len(right_arm_action), len(env.robot.get_qpos().squeeze(0))), + dtype=np.float32, + ) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + elif right_arm_action is None and left_arm_action is not None: + left_arm_index = env.left_arm_joints + env.left_eef_joints + right_arm_index = env.right_arm_joints + env.right_eef_joints + right_arm_action = finalize_actions( + env.right_arm_current_qpos, env.right_arm_current_gripper_state + ) + right_arm_action = np.repeat( + right_arm_action[None, :], len(left_arm_action), axis=0 + ) + + actions = np.zeros( + (len(left_arm_action), len(env.robot.get_qpos().squeeze(0))), + dtype=np.float32, + ) + actions[:, left_arm_index] = left_arm_action + actions[:, right_arm_index] = right_arm_action + + else: + log_error("At least one arm action should be provided.") + + actions = torch.from_numpy(actions).to(dtype=torch.float32).unsqueeze(1) + actions = list(actions.unbind(dim=0)) + for i in tqdm(range(len(actions))): + action = actions[i] + obs, reward, terminated, truncated, info = env.step(action) + return actions + + +def save_observations( + step_id: int = 0, + step_name: str = None, + env=None, + **kwargs, +): + # When using feedback script + log_dir = kwargs.get("log_dir") + if log_dir: + save_dir = log_dir / "camera_images" + + # Prepare subfolder: {id}_generate_num/episode{current_check_num} + gen_id = kwargs.get("id", "unknown_id") + episode_id = kwargs.get("current_check_num", 0) + + sub_dir = save_dir / f"{gen_id}_generate_num" / f"episode{episode_id}" + sub_dir.mkdir(parents=True, exist_ok=True) + + # Encode image to Base64 + base64_image = encode_image(env.get_obs_for_agent()["rgb"]) + + # Decode Base64 back to raw image bytes + import base64 + + img_bytes = base64.b64decode(base64_image) + + # Ensure step_name is not None + step_name = step_name if step_name is not None else "unnamed_step" + + # Save the decoded image + output_path = sub_dir / f"step{step_id}_{step_name}.png" + with open(output_path, "wb") as f: + f.write(img_bytes) + + # Print save info + log_info(f"[save_observations] Saved image to: {output_path}") + + # When only running the script (no feedback script) + else: + pass diff --git a/embodichain/toolkits/processor/component/__init__.py b/embodichain/toolkits/processor/component/__init__.py new file mode 100644 index 00000000..774cdb9f --- /dev/null +++ b/embodichain/toolkits/processor/component/__init__.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .component import * diff --git a/embodichain/toolkits/processor/component/component.py b/embodichain/toolkits/processor/component/component.py new file mode 100644 index 00000000..040aa23f --- /dev/null +++ b/embodichain/toolkits/processor/component/component.py @@ -0,0 +1,311 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import abc +import hashlib +import numpy as np +import open3d as o3d +from typing import List, Dict, Any +from dataclasses import dataclass, field, fields, is_dataclass, asdict +from scipy.spatial.transform import Rotation + +from embodichain.utils.cfg import CfgNode +import dexsim.utility as dexutils +from embodichain.toolkits.processor.types import CFG_DEF_TYPE_KEYS +from dexsim.kit.meshproc.generate_thicker_acd import get_pc_thickness + + +class BaseMetaClass(type): + register_class = {} + + def __new__(cls, name: str, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "EntityComponent": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class EntityComponent(metaclass=BaseMetaClass): + """An abstract class for entity components. + + EntityComponent 只是单纯的数据类,没有其他功能行为。 因此我们可以用 `@dataclass` 来装饰他们。 + + 在使用 `@dataclass` 装饰器时,有几点需要注意: + 1. 使用 eq=False 避免使用 dataclass 默认的 __eq__ 方法,而是使用基类的 __eq__ 方法, + 特别是当数据类中有 np.ndarray 类型的变量 比较时会出现错误 + 2. 使用 fronzen=True 避免多个 entity 共享同一个组件实例的时候, + 修改某个组件的属性会影响到其他 entity 的问题。这样做的话,一旦一个组件被初始化,其属性就不能再被修改了,只能通过 new 方法创建新的组件实例。 + 当一个 Component 可能会被多个 entity 共享时,就应该使用 frozen=True 来避免这个问题。 + """ + + def __eq__(self, other): + """ + 用于检查两个相同 EntityComponent 是否相等。 + + 重写 dataclass 的 __eq__ 方法,是增加 EntityComponent 中有 np.ndarray 类型变量的比较。 + + Args: + other (object): The object to compare with. + Returns: + bool: True if the two instances are equal, False otherwise. + """ + if isinstance(other, self.__class__): + for k, v in self.__dict__.items(): + if k not in other.__dict__: + return False + if isinstance(v, np.ndarray): + if not np.array_equal(v, other.__dict__[k]): + return False + elif v != other.__dict__[k]: + return False + return True + return False + + def new(self, **kwargs): + """ + 根据传入的参数更新组件的属性,并返回一个新的组件实例。 + + 一个组件可能会被多个 entity 共享,所以在更新组件的属性时,如果需要对不同的 entity 有不同的属性值, + 使用该接口进行参数更新,而不是直接修改组件的属性。 + + Args: + **kwargs: Keyword arguments to set the attributes of the new instance. + + Raises: + ValueError: If any of the keyword arguments are not valid fields of the class. + + Returns: + An instance of the class with the given keyword arguments set as attributes. + + """ + for k, v in kwargs.items(): + if not hasattr(self, k): + raise ValueError(f"{k} is not a valid field") + # TODO check the type of value + # setattr(self, k, v) + for field in fields(self): + if field.name not in kwargs: + kwargs[field.name] = getattr(self, field.name) + return type(self)(**kwargs) + + def save(self) -> Dict: + """ + 保存当前组件实例的数据。 + + 该函数首先检查当前实例是否为 dataclass,如果是,就使用 `asdict` 函数来保存数据。如果不是,就抛出一个 `NotImplementedError` 异常。 + 即当前我们仅支持 dataclass 类型的组件实例。 + + Returns: + Dict: The data of the current instance as a dictionary. + + Raises: + NotImplementedError: If the current instance is not a dataclass. + """ + if not is_dataclass(self): + raise NotImplementedError + return asdict(self) + + @classmethod + def from_config(cls, cfg: CfgNode): + if not is_dataclass(cls): + raise NotImplementedError + data_fields = fields(cls) + if not isinstance(cfg, dict): + if len(data_fields) > 1: + raise ValueError(f"Config should be a dict, but got {cfg}.") + return cls(cfg) + else: + params = {} + for key, val in cfg.items(): + params[key.lower()] = val + return cls(**params) + + +def build_component_from_config(cfg: CfgNode, name: str = None) -> EntityComponent: + type_key = None + if name is None: + for tk in CFG_DEF_TYPE_KEYS: + if tk in cfg: + if type_key is not None: + raise ValueError( + f"Config should only contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = tk + if type_key is None: + raise ValueError( + f"Config should contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = cfg.pop(type_key) + else: + type_key = name + register_class = EntityComponent.register_class + if isinstance(type_key, str): + type_key = type_key.upper() + if type_key not in register_class: + raise ValueError(f"Class {type_key} is not registered") + else: + raise TypeError(f"Class type {type_key} is not a string") + try: + return register_class[type_key].from_config(cfg) + except Exception as e: + raise ValueError(f"Failed to build component {type_key}, {e}") + + +@dataclass(eq=False) +class AxisAlignedBoundingBox(EntityComponent): + min_bound: np.ndarray + max_bound: np.ndarray + + def is_close( + self, other: "AxisAlignedBoundingBox", threshold: float = 1e-3 + ) -> bool: + return np.allclose( + self.min_bound, other.min_bound, atol=threshold + ) and np.allclose(self.max_bound, other.max_bound, atol=threshold) + + +@dataclass(eq=False) +class OrientedBoundingBox(EntityComponent): + center: np.ndarray + extent: np.ndarray + R: Rotation + + +@dataclass(eq=False, frozen=True) +class TriangleComponent(EntityComponent): + vertices: np.ndarray + triangles: np.ndarray + triangle_uvs: np.ndarray = np.empty((0, 3, 2)) + vertex_uvs: np.ndarray = np.empty((0, 2)) + vertex_colors: np.ndarray = np.empty((0, 3)) # or 4 + vertex_normals: np.ndarray = np.empty((0, 3)) + texture: np.ndarray = np.empty((0, 0, 3)) # hwc + mesh_fpath: str = None + optional_params: Dict[str, Any] = field(default_factory=dict) + + def md5_hash(self) -> str: + md5 = hashlib.md5() + hash_attr_keys = ["vertices", "triangles", "mesh_fpath"] + for key in hash_attr_keys: + val = getattr(self, key) + if val is None: + continue + if isinstance(val, np.ndarray): + md5.update(val.tobytes()) + elif isinstance(val, str): + md5.update(val.encode()) + + return md5.hexdigest() + + @classmethod + def from_config(cls, cfg: CfgNode): + if "MESH_FPATH" not in cfg: + raise ValueError(f"Config should contains key MESH_FPATH, but got {cfg}.") + mesh_fpath = cfg.MESH_FPATH + mesh = cls.__from_fpath(mesh_fpath) + + optional_params = {} + if "OPTIONAL_PARAMS" in cfg and cfg.OPTIONAL_PARAMS is not None: + for key, val in cfg.OPTIONAL_PARAMS.items(): + optional_params[key.lower()] = val + return cls( + vertices=mesh.vertices, + triangles=mesh.triangles, + triangle_uvs=mesh.triangle_uvs, + vertex_uvs=mesh.vertex_uvs, + vertex_colors=mesh.vertex_colors, + vertex_normals=mesh.vertex_normals, + texture=mesh.texture, + mesh_fpath=mesh_fpath, + optional_params=optional_params, + ) + + @classmethod + def from_fpath(cls, mesh_fpath: str): + mesh = cls.__from_fpath(mesh_fpath) + + return cls( + vertices=mesh.vertices, + triangles=mesh.triangles, + triangle_uvs=mesh.triangle_uvs, + vertex_uvs=mesh.vertex_uvs, + vertex_colors=mesh.vertex_colors, + vertex_normals=mesh.vertex_normals, + texture=mesh.texture, + mesh_fpath=mesh_fpath, + ) + + @staticmethod + def __from_fpath(mesh_fpath: str): + from dexsim.kit.meshproc.mesh_io import load_mesh + + mesh = load_mesh(mesh_fpath) + if isinstance(mesh, List): + msg = f"Mesh file {mesh_fpath} contains multiple meshes. Only the first mesh will be used." + dexutils.log_warning(msg) + mesh = mesh[0] + # vertices = mesh.vertices + # triangles = mesh.triangles + if mesh.vertex_uvs is None: + mesh.vertex_uvs = np.empty((0, 2)) + if mesh.triangle_uvs is None: + mesh.triangle_uvs = np.empty((0, 3, 2)) + if mesh.vertex_colors is None: + mesh.vertex_colors = np.empty((0, 3)) + if mesh.vertex_normals is None: + mesh.vertex_normals = np.empty((0, 3)) + if mesh.texture is None: + mesh.texture = np.empty((0, 0, 3)) + return mesh + + def get_thickness(self): + mesh_o3d = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(self.vertices), + o3d.utility.Vector3iVector(self.triangles), + ) + surface_pc_o3d = mesh_o3d.sample_points_uniformly(number_of_points=3000) + surface_pc = np.array(surface_pc_o3d.points) + thickness, standard_pose = get_pc_thickness(surface_pc) + return thickness + + +@dataclass() +class VisualComponent(EntityComponent): + is_visual: bool = True + + +@dataclass(eq=False) +class ScaleComponent(EntityComponent): + scale: np.ndarray = np.array([1.0, 1.0, 1.0]) + + def __post_init__(self): + if self.scale is not None: + self.scale = np.array(self.scale) + + +@dataclass(eq=False) +class SpatializationComponenet(EntityComponent): + location: np.ndarray = np.array([0, 0, 0], dtype=np.float32) + rotation: Rotation = Rotation.from_matrix(np.eye(3)) + + def __post_init__(self): + if self.location is not None: + self.location = np.array(self.location, dtype=np.float32) + if self.rotation is not None and not isinstance(self.rotation, Rotation): + self.rotation = Rotation.from_euler("xyz", self.rotation, degrees=True) + + def save(self) -> Dict: + return_dict = asdict(self) + return_dict["rotation"] = self.rotation.as_euler("xyz", degrees=True) + return return_dict + + def get_pose(self) -> np.ndarray: + pose = np.eye(4) + if self.location is not None: + pose[:3, 3] = self.location + if self.rotation is not None: + pose[:3, :3] = self.rotation.as_matrix() + return pose diff --git a/embodichain/toolkits/processor/entity/__init__.py b/embodichain/toolkits/processor/entity/__init__.py new file mode 100644 index 00000000..a3a7bea9 --- /dev/null +++ b/embodichain/toolkits/processor/entity/__init__.py @@ -0,0 +1,8 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .entity_base import EntityBase, build_entity_from_config +from .meshobject import MeshEntity diff --git a/embodichain/toolkits/processor/entity/entity_base.py b/embodichain/toolkits/processor/entity/entity_base.py new file mode 100644 index 00000000..a6a09511 --- /dev/null +++ b/embodichain/toolkits/processor/entity/entity_base.py @@ -0,0 +1,217 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from typing import Union, Dict, Any + +# from dataclasses import asdict, is_dataclass + +from embodichain.utils.cfg import CfgNode + +from embodichain.toolkits.processor.component import EntityComponent +from embodichain.toolkits.processor.types import CFG_DEF_TYPE_KEYS + + +class EntityMetaClass(type): + register_class = {} + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "EntityBase": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class EntityBase(metaclass=EntityMetaClass): + """ + The base class for all entities in the scene. + + Args: + name (str): The name of the entity. + + Attributes: + name (str): The name of the entity. + _components (Dict[type, EntityComponent]): A dictionary of components of type `EntityComponent`. + """ + + def __init__(self, name: str) -> None: + self.name = name + + self._components: Dict[type, EntityComponent] = {} + self._custom_properties: Dict[str, Any] = {} + + @classmethod + def kind(cls) -> str: + """ + The kind of the entity. + + Returns: + str: A string representing the name of the class. + """ + return cls.__name__ + + def get_name(self) -> str: + """ + Get the name of the entity. + + Returns: + str: the name of the entity + """ + return self.name + + def set_custom_property(self, **kwargs): + for key, value in kwargs.items(): + self._custom_properties[key] = value + + def has_custom_property(self, *key: str): + for k in key: + if k not in self._custom_properties: + return False + return True + + def get_custom_property(self, key: str): + if key not in self._custom_properties: + return None + return self._custom_properties[key] + + def get_custom_properties(self) -> Dict[str, Any]: + return self._custom_properties + + def remove_custom_property(self, *key: str): + for k in key: + if k in self._custom_properties: + del self._custom_properties[k] + + def add_component(self, *component: EntityComponent): + """ + Adds one or more components to the entity. + + Args: + *component (EntityComponent): One or more components to be added. + + Raises: + AssertionError: If any of the components are not instances of EntityComponent or not direct subclasses of EntityComponent. + + Description: + This method adds one or more components to the entity. If a component of the same type already exists, it will be replaced. + + Note: + Currently, the method only accepts components that are instances of EntityComponent and direct subclasses of EntityComponent. + + Example: + >>> entity = Entity() + >>> entity.add_component(ScaleComponent(), SpatializationComponenet()) + """ + # The existing component of the same type will be replaced. + for comp in component: + assert isinstance( + comp, EntityComponent + ), f"Component {type(comp)} must be an instance of EntityComponent" + assert ( + EntityComponent in comp.__class__.__bases__ + ), f"Component {type(comp)} must be a direct subclass of EntityComponent" + # we only accept dataclass type component + # assert is_dataclass(comp) + self._components[type(comp)] = comp + + def has_component(self, comp_type: type) -> bool: + """ + Check if the entity has a component of a specific type. + + Args: + comp_type (type): The type of component to check for. + + Returns: + bool: True if the entity has a component of the specified type, False otherwise. + """ + return comp_type in self._components + + def get_component(self, comp_type: type) -> Union[None, EntityComponent]: + """ + Get the component of the specified type from the entity. + + Args: + comp_type (type): The type of component to retrieve. + + Returns: + Union[None, EntityComponent]: The component of the specified type, or None if it does not exist. + """ + return self._components.get(comp_type, None) + + def remove_component(self, comp_type: type): + """ + Remove a component of the specified type from the entity. + + Parameters: + comp_type (type): The type of component to remove. + + Raises: + ValueError: If the component of the specified type does not exist. + + Returns: + None + """ + if comp_type in self._components: + del self._components[comp_type] + else: + raise ValueError(f"{comp_type} component does not exist") + + def save(self) -> Dict: + """ + Saves the components of the entity to a dictionary. + + Returns: + Dict: A dictionary containing the names of the component types as keys and the saved components as values. + """ + results = {} + for comp_type, comp in self._components.items(): + # # only save the components that are decorated by dataclass + # if is_dataclass(comp): + results[comp_type.__name__] = comp.save() + return results + + @classmethod + def from_config(cls, cfg: Union[str, CfgNode]) -> "EntityBase": + if isinstance(cfg, CfgNode): + if "NAME" not in cfg: + raise ValueError(f"Config should contains key NAME, but got {cfg}.") + name = cfg.pop("NAME") + elif isinstance(cfg, str): + name = cfg + else: + raise TypeError(f"Config should be a string or CfgNode, but got {cfg}.") + return cls(name) + + +def build_entity_from_config( + cfg: Union[str, CfgNode], entity_type: str = None +) -> EntityBase: + if isinstance(cfg, str) and entity_type is None: + err_msg = f"'entity_type' is required when 'cfg' is a string." + raise ValueError(err_msg) + type_key = None + if entity_type is None: + for tk in CFG_DEF_TYPE_KEYS: + if tk in cfg: + if type_key is not None: + raise ValueError( + f"Config should only contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = tk + if type_key is None: + raise ValueError( + f"Config should contains one of keys {CFG_DEF_TYPE_KEYS}, but got {cfg}." + ) + type_key = cfg.pop(type_key) + else: + type_key = entity_type + register_class = EntityBase.register_class + if isinstance(type_key, str): + type_key = type_key.upper() + if type_key not in register_class: + raise ValueError(f"Class {type_key} is not registered") + else: + raise TypeError(f"Class type {type_key} is not a string") + return register_class[type_key].from_config(cfg) diff --git a/embodichain/toolkits/processor/entity/meshobject.py b/embodichain/toolkits/processor/entity/meshobject.py new file mode 100644 index 00000000..6552646f --- /dev/null +++ b/embodichain/toolkits/processor/entity/meshobject.py @@ -0,0 +1,122 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +import numpy as np +import open3d as o3d + +from embodichain.toolkits.processor.component import EntityComponent +from embodichain.toolkits.processor.component import ( + OrientedBoundingBox, + ScaleComponent, + SpatializationComponenet, + TriangleComponent, + AxisAlignedBoundingBox, + TriangleComponent, + VisualComponent, +) +from .entity_base import EntityBase + + +class MeshEntity(EntityBase): + def __init__( + self, + name: str, + triangle_comp: TriangleComponent = None, + spatialization_comp: SpatializationComponenet = None, + visual_comp: VisualComponent = None, + ) -> None: + super().__init__(name) + if triangle_comp is not None: + self.add_component(triangle_comp) + if visual_comp is None: + visual_comp = VisualComponent() + self.add_component(visual_comp) + + # init with default component + if spatialization_comp is None: + spatialization_comp = SpatializationComponenet() + # add the initial component + self.add_component(spatialization_comp) + + def is_visible(self) -> bool: + # if not self.has_component(TriangleComponent): + # return False + visual_comp = self.get_component(VisualComponent) + if visual_comp: + visual = visual_comp.is_visual + else: + visual = False + return visual + + def add_component(self, *component: EntityComponent): + for comp in component: + if isinstance(comp, TriangleComponent): + self.triangle_comp = comp + # remove the old bounding box component + if self.has_component(AxisAlignedBoundingBox): + self.remove_componenet(AxisAlignedBoundingBox) + if self.has_component(OrientedBoundingBox): + self.remove_componenet(OrientedBoundingBox) + # add default visual component + if not self.has_component(VisualComponent): + self.add_component(VisualComponent()) + super().add_component(*component) + + def get_axis_aligned_bounding_box(self) -> o3d.geometry.AxisAlignedBoundingBox: + triangle_comp: TriangleComponent = self.get_component(TriangleComponent) + scale_comp = self.get_component(ScaleComponent) + vertices = triangle_comp.vertices + scale = np.array([1, 1, 1]) + if scale_comp is not None: + scale = scale_comp.scale + o3d_mesh = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(vertices * scale), + o3d.utility.Vector3iVector(triangle_comp.triangles), + ) + spatial_comp: SpatializationComponenet = self.get_component( + SpatializationComponenet + ) + o3d_mesh.transform(spatial_comp.get_pose()) + aabbox_o3d = o3d_mesh.get_axis_aligned_bounding_box() + return aabbox_o3d + + # def get_oriented_bounding_box(self) -> o3d.geometry.OrientedBoundingBox: + # if not self.has_component(OrientedBoundingBox): + # # self.add_component(OrientedBoundingBox(self.triangle_comp.vertices)) + # o3d_mesh = o3d.geometry.TriangleMesh(self.triangle_comp.vertices, + # self.triangle_comp.triangles) + # bbox = o3d_mesh.get_oriented_bounding_box() + # obb = OrientedBoundingBox(bbox.get_center(), bbox.get_extent(), + # bbox.get_rotation()) + # self.add_component(obb) + # obbox = self.get_component(OrientedBoundingBox) + # return obbox + + def get_o3d_mesh( + self, add_scale: bool = True, add_transform: bool = True + ) -> o3d.geometry.TriangleMesh: + triangle_comp: TriangleComponent = self.get_component(TriangleComponent) + scale_comp = self.get_component(ScaleComponent) + vertices = triangle_comp.vertices + scale = np.array([1, 1, 1]) + if add_scale and scale_comp is not None: + scale = scale_comp.scale + o3d_mesh = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(vertices * scale), + o3d.utility.Vector3iVector(triangle_comp.triangles), + ) + if add_transform: + spatial_comp: SpatializationComponenet = self.get_component( + SpatializationComponenet + ) + o3d_mesh.transform(spatial_comp.get_pose()) + return o3d_mesh + + def save_mesh(self, file_path: str): + from dexsim.kit.meshproc.mesh_io import save_mesh + + tri_comp: TriangleComponent = self.get_component(TriangleComponent) + save_mesh(file_path, **tri_comp.save()) diff --git a/embodichain/toolkits/processor/function/mesh_processor/__init__.py b/embodichain/toolkits/processor/function/mesh_processor/__init__.py new file mode 100644 index 00000000..145396a6 --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/__init__.py @@ -0,0 +1,2 @@ +from .base import MeshProcessor, build_mesh_processors +from .processor import * diff --git a/embodichain/toolkits/processor/function/mesh_processor/base.py b/embodichain/toolkits/processor/function/mesh_processor/base.py new file mode 100644 index 00000000..2b4d3314 --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/base.py @@ -0,0 +1,42 @@ +import abc +from typing import List, Dict + +from embodichain.utils.cfg import CfgNode + +from embodichain.toolkits.processor.entity import MeshEntity + + +class MeshProcessorMetaClass(type): + register_class = {} + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if name != "MeshProcessor": + cls.register_class[name.upper()] = new_cls + return new_cls + + +class MeshProcessor(metaclass=MeshProcessorMetaClass): + def __init__(self, **kwargs): + pass + + @abc.abstractmethod + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + pass + + +class MeshProcessorList(List[MeshProcessor]): + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for processor in self: + meshes = processor.apply(meshes) + return meshes + + +def build_mesh_processors(config: CfgNode) -> MeshProcessorList: + processors = MeshProcessorList() + for name, cfg in config.items(): + if cfg is None: + cfg = {} + processor = MeshProcessor.register_class[name.upper()](**cfg) + processors.append(processor) + return processors diff --git a/embodichain/toolkits/processor/function/mesh_processor/processor.py b/embodichain/toolkits/processor/function/mesh_processor/processor.py new file mode 100644 index 00000000..dc851f12 --- /dev/null +++ b/embodichain/toolkits/processor/function/mesh_processor/processor.py @@ -0,0 +1,163 @@ +import open3d as o3d +import numpy as np +from typing import List, Tuple + +from dexsim.kit.meshproc import face_uv_to_vert_uv +from dexsim.kit.meshproc.compute_uv import get_mesh_auto_uv +from dexsim.kit.meshproc import ( + simplification_decimation, + remesh_isotropic_explicit, +) +import dexsim.utility as dexsutil + +from embodichain.toolkits.processor.entity import MeshEntity +from embodichain.toolkits.processor.component import TriangleComponent + +from .base import MeshProcessor + + +class ComputeUV(MeshProcessor): + def __init__( + self, + compute_vertex_uvs: bool = True, + max_triangle_count: int = 10000, + remesh: bool = False, + ): + """Compute UV coordinates. + + Args: + compute_vertex_uvs (bool, optional): Compute texture uvs or triangle uvs, if True, compute texture uvs. Defaults to True. + max_triangle_count (int, optional): If the number of faces is larger than this value and there is no uvs, simplification will be applied. + It will cost more time to compute uvs if the number of faces is large .Defaults to 10000. + remesh (bool, optional): If set to True, remesh will be applied, and the uvs will be re-computed. Defaults to False. + """ + self.compute_vertex_uvs = compute_vertex_uvs + self.max_triangle_count = max_triangle_count + self.remesh = remesh + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for mesh in meshes: + tri_comp: TriangleComponent = mesh.get_component(TriangleComponent) + has_uvs = tri_comp.vertex_uvs.size > 0 or tri_comp.triangle_uvs.size > 0 + + mesh_o3d = mesh.get_o3d_mesh(add_scale=False, add_transform=False) + mesh_o3dt = o3d.t.geometry.TriangleMesh.from_legacy(mesh_o3d) + # if the number of faces is larger than max_triangle_count and there is no uvs, simplification will be applied + if not has_uvs: + if tri_comp.triangles.shape[0] > self.max_triangle_count: + # simplification + is_success, mesh_o3dt = simplification_decimation( + mesh_o3dt, sample_triangle_num=self.max_triangle_count + ) + if not is_success: + dexsutil.log_warning("failed to do simplification.") + # remesh need to apply after simplification + if self.remesh: + is_success, mesh_o3dt = remesh_isotropic_explicit( + mesh_o3dt, is_visual=False + ) + # has_uvs = False # need to recompute uvs + if self.compute_vertex_uvs: + if tri_comp.vertex_uvs.size == 0 or self.remesh: + if tri_comp.triangle_uvs.size > 0: + vertex_uvs = face_uv_to_vert_uv( + tri_comp.triangles, + tri_comp.triangle_uvs, + len(tri_comp.vertices), + ) + else: + _, vertex_uvs = get_mesh_auto_uv(mesh_o3dt) + tri_comp = tri_comp.new(vertex_uvs=vertex_uvs) + mesh.add_component(tri_comp) + else: + dexsutil.log_error("Not implemented for compute triangle uvs.") + return meshes + + +class MeshNormalize(MeshProcessor): + def __init__( + self, + set_origin: str = "center", + scale: float = 1.0, + unify_longest_side: bool = False, + ): + """Normalize the mesh to a standard size and origin. + + Args: + set_origin (str, optional): Set the origin location of the mesh to it's center or it's bottom center. + Choices=["center", "bottom"]. Defaults to 'center'. + scale (float, optional): Scale factor for the mesh . Defaults to 1.0. + unify_longest_side (float, optional): If True, the longest side of the mesh will be scaled to the scale factor. + Defaults to False. + """ + assert set_origin in [ + "center", + "bottom", + ], f"Invalid value for set_origin: {set_origin}" + self.set_origin = set_origin + self.scale = scale + self.unify_longest_side = unify_longest_side + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + for mesh in meshes: + tri_comp: TriangleComponent = mesh.get_component(TriangleComponent) + vertices = tri_comp.vertices + # set center of the mesh to the origin + if self.set_origin == "center": + center = np.mean(vertices, axis=0) + elif self.set_origin == "bottom": + center_xy = np.mean(vertices[:, :2], axis=0) + center = np.array([center_xy[0], center_xy[1], np.min(vertices[:, 2])]) + else: + raise ValueError(f"Invalid value for set_origin: {self.set_origin}") + vertices -= center # in-place operation + + # scale the mesh + if self.unify_longest_side: + max_length = np.max(vertices, axis=0) - np.min(vertices, axis=0) + scale = self.scale / np.max(max_length) + else: + scale = self.scale + vertices *= scale + return meshes + + +class MeshAlign(MeshProcessor): + def __init__( + self, + method: str = "obb", + symmetry_axis: int = 0, + is_larger_positive: bool = True, + ): + assert method in ["obb", "svd"], f"Invalid value for method: {method}" + self.method = method + self.symmetry_axis = symmetry_axis + self.is_larger_positive = is_larger_positive + + def apply(self, meshes: List[MeshEntity]) -> List[MeshEntity]: + from dexsim.kit.meshproc import cad_standardlize_svd, cad_standardlize_obb + + for mesh in meshes: + mesh_o3d = mesh.get_o3d_mesh(add_scale=False, add_transform=False) + if self.method == "obb": + is_success, mesh_o3dt = cad_standardlize_obb( + mesh_o3d, + is_use_mesh_clean=False, + is_cad_eliminate_symmetry=True, + symmetry_axis=self.symmetry_axis, + is_larger_positive=self.is_larger_positive, + ) + elif self.method == "svd": + is_success, mesh_o3dt = cad_standardlize_obb( + mesh_o3d, + is_use_mesh_clean=False, + is_cad_eliminate_symmetry=True, + symmetry_axis=self.symmetry_axis, + is_larger_positive=self.is_larger_positive, + ) + vertices = mesh_o3dt.vertex.positions.numpy() + triangles = mesh_o3dt.triangle.indices.numpy() + tri_comp = mesh.get_component(TriangleComponent) + tri_comp = tri_comp.new(vertices=vertices, triangles=triangles) + mesh.add_component(tri_comp) + return meshes diff --git a/embodichain/toolkits/processor/types/__init__.py b/embodichain/toolkits/processor/types/__init__.py new file mode 100644 index 00000000..b6b342b8 --- /dev/null +++ b/embodichain/toolkits/processor/types/__init__.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +from .types import * diff --git a/embodichain/toolkits/processor/types/types.py b/embodichain/toolkits/processor/types/types.py new file mode 100644 index 00000000..ed48ce23 --- /dev/null +++ b/embodichain/toolkits/processor/types/types.py @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# All rights reserved. +# ---------------------------------------------------------------------------- + +CFG_DEF_TYPE_KEYS = ["type", "_TYPE", "NAME", "name", "TYPE"] diff --git a/embodichain/toolkits/toolkits.py b/embodichain/toolkits/toolkits.py new file mode 100644 index 00000000..9b2ca344 --- /dev/null +++ b/embodichain/toolkits/toolkits.py @@ -0,0 +1,18 @@ +from abc import ABCMeta, abstractmethod +import os +import cv2 +from embodichain.utils.utility import load_json + + +class ToolkitsBase(metaclass=ABCMeta): + @classmethod + def from_config(cls, path: str): + assert ( + os.path.basename(path).split(".")[-1] == "json" + ), "only json file is supported." + config = load_json(path) + return config["ToolKits"][cls.__name__] + + @abstractmethod + def call(self, **kwargs): + pass diff --git a/pyproject.toml b/pyproject.toml index 50fef254..a770d55c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ dependencies = [ "dexsim_engine==0.3.9", "setuptools>=78.1.1", "gymnasium>=0.29.1", + "langchain==0.2.14", + "langchain-openai==0.1.22", "toppra==0.6.3", "pin", "pin-pink", @@ -49,6 +51,8 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "zmq==0.0.0", + "pycocotools", ] [project.optional-dependencies]