数据提取
# Fixed sizes used when turning a scenario into padded arrays.
# Four of these are forwarded to data_process_scenario() as
# max_num_objects / max_polylines / current_index / num_points_polyline.
MAX_NUM_OBJECTS = 64  # number of agent slots kept per scenario
MAX_POLYLINES = 256  # number of map polylines kept per scenario
MAX_TRAFFIC_LIGHTS = 16  # NOTE(review): not referenced by the snippets shown here
CURRENT_INDEX = 10  # timestep that splits history ([:CURRENT_INDEX + 1]) from future ([CURRENT_INDEX:])
NUM_POINTS_POLYLINE = 30  # number of points every polyline is resampled to
1 2 3 4 5 6 7 8 tf_dataset = dataloader.tf_examples_dataset( path=data_dir, data_format=DataFormat.TFRECORD, preprocess_fn=tf_preprocess, repeat=1 , deterministic=True , )
tf_preprocess 用于 TensorFlow 数据集的预处理,tf_postprocess 用于后处理,二者均在 data_utils.py 中定义
def tf_preprocess(serialized: bytes) -> dict[str, tf.Tensor]:
    """Parse one serialized WOMD example into a dictionary of tensors.

    The scenario id is parsed alongside the standard WOMD features and
    re-encoded as raw ``uint8`` bytes before the example is handed to
    ``dataloader.preprocess_womd_example``.

    Args:
        serialized (bytes): The serialized example.

    Returns:
        dict[str, tf.Tensor]: The preprocessed example, including a
        ``'scenario/id'`` entry holding the id as ``uint8`` bytes.
    """
    womd_features = dataloader.womd_utils.get_features_description(
        include_sdc_paths=False,
        max_num_rg_points=30000,
        num_paths=None,
        num_points_per_path=None,
    )
    # The scenario id is not part of the default feature description, so it
    # is registered explicitly as a single string feature.
    womd_features['scenario/id'] = tf.io.FixedLenFeature([1], tf.string)

    deserialized = tf.io.parse_example(serialized, womd_features)

    # Swap the string id for its raw bytes; the consumer later recovers the
    # text with ``.tobytes().decode('utf-8')``.
    parsed_id = deserialized.pop('scenario/id')
    deserialized['scenario/id'] = tf.io.decode_raw(parsed_id, tf.uint8)

    return dataloader.preprocess_womd_example(
        deserialized,
        aggregate_timesteps=True,
        max_num_objects=None,
    )
将原始的序列化数据集样本转换为模型可用的格式化张量字典,为后续的轨迹预测或运动分析模型准备输入数据
def tf_postprocess(example: dict[str, tf.Tensor]):
    """Split a preprocessed example into its scenario id and scenario state.

    Args:
        example (dict[str, tf.Tensor]): The example to be postprocessed.
            It must contain a ``'scenario/id'`` entry.

    Returns:
        tuple: ``(scenario_id, scenario)`` where ``scenario_id`` is the
        ``example['scenario/id']`` value unchanged and ``scenario`` is the
        result of ``dataloader.simulator_state_from_womd_dict(example)``.
    """
    scenario = dataloader.simulator_state_from_womd_dict(example)
    scenario_id = example['scenario/id']
    return scenario_id, scenario
后处理分离出场景 ID 和 原始场景数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 for example in tf_dataset_iter: scenario_id_binary, scenario = tf_postprocess(example) scenario_id = scenario_id_binary.tobytes().decode('utf-8' ) scenario_filename = os.path.join(save_dir, 'scenario_' +scenario_id+'.pkl' ) if os.path.exists(scenario_filename): continue if only_raw: data_dict = {'scenario_raw' : scenario} else : data_dict = data_process_scenario( scenario, max_num_objects=MAX_NUM_OBJECTS, max_polylines=MAX_POLYLINES, current_index=CURRENT_INDEX, num_points_polyline=NUM_POINTS_POLYLINE, ) if save_raw: data_dict['scenario_raw' ] = scenario data_dict['scenario_id' ] = scenario_id with open (scenario_filename, 'wb' ) as f: pickle.dump(data_dict, f)
data_process_scenario——核心场景处理函数
智能体
1 2 3 4 5 6 7 8 (agents_history, agents_future, agents_interested, agents_type, agents_id) = data_process_agent ( scenario, max_num_objects = max_num_objects, current_index = current_index, use_log = use_log, selected_agents = selected_agents, remove_history=remove_history, )
1 2 3 4 5 6 7 8 9 10 11 agents_history[i] = np.column_stack([ log_trajectory.xy[a][:current_index+1 , 0 ], log_trajectory.xy[a][:current_index+1 , 1 ], log_trajectory.yaw[a][:current_index+1 ], log_trajectory.vel_x[a][:current_index+1 ], log_trajectory.vel_y[a][:current_index+1 ], log_trajectory.length[a][:current_index+1 ], log_trajectory.width[a][:current_index+1 ], log_trajectory.height[a][:current_index+1 ], ]) agents_history[i][~log_trajectory.valid[a, :current_index+1 ]] = 0
历史轨迹(max_objects, history_length, 8)的8个特征:x, y, yaw, vel_x, vel_y, length, width, height
1 2 3 4 5 6 7 agents_future[i] = np.column_stack([ log_trajectory.xy[a][current_index:, 0 ], log_trajectory.xy[a][current_index:, 1 ], log_trajectory.yaw[a][current_index:], log_trajectory.vel_x[a][current_index:], log_trajectory.vel_y[a][current_index:] ])
未来轨迹 (max_objects, future_length, 5)的5个特征x, y, yaw, vel_x, vel_y
agents_interested: 智能体关注度(模型化对象=10,其他=1,无效=0)
agents_type: 智能体类型(车辆、行人等)
agent_ids: 实际处理的智能体ID列表
有效性掩码 : 使用valid标志过滤无效时间步的数据
零填充 : 对无效位置用0填充
历史清除 : remove_history参数可清除除当前时刻外的所有历史
最后返回
1 return (agents_history, agents_future, agents_interested, agents_type, agent_ids)
1 2 3 4 5 6 agents_history.shape agents_future.shape agents_interested.shape agents_type.shape agent_ids
如果智能体比较少,agents_history,agents_future会形成很大的稀疏矩阵
交通灯
1 2 3 4 (traffic_light_points, traffic_lane_ids, traffic_light_states) = data_process_traffic_light ( scenario, current_index = current_index, )
def data_process_traffic_light(
    scenario,
    current_index=10,
):
    """Extract the traffic light data of a scenario at a single timestep.

    Args:
        scenario (datatypes.SimulatorState): The simulator state containing
            traffic light information in ``scenario.log_traffic_light``.
        current_index (int): Index along the time axis (axis 1) at which the
            traffic light arrays are sampled.

    Returns:
        tuple: ``(traffic_light_points, traffic_lane_ids,
        traffic_light_states)``.

        * ``traffic_light_points``: float32 array with one row per traffic
          light holding ``(x, y, state)``; rows whose light is not valid at
          ``current_index`` are all zeros.
        * ``traffic_lane_ids``: lane id of each traffic light at
          ``current_index``.
        * ``traffic_light_states``: state of each traffic light at
          ``current_index`` (not masked by validity).
    """
    traffic_lights = scenario.log_traffic_light

    # Slice every per-light, per-timestep array at the requested timestep.
    traffic_lane_ids = np.asarray(traffic_lights.lane_ids)[:, current_index]
    traffic_light_states = np.asarray(traffic_lights.state)[:, current_index]
    traffic_stop_points = np.asarray(traffic_lights.xy)[:, current_index]
    traffic_light_valid = np.asarray(traffic_lights.valid)[:, current_index]

    # Append the state as a third column next to the (x, y) stop point.
    traffic_light_points = np.concatenate(
        [traffic_stop_points, traffic_light_states[:, None]],
        axis=1,
    ).astype(np.float32)

    # Zero out the rows of lights that are invalid at this timestep.
    traffic_light_points = np.where(
        traffic_light_valid[:, None],
        traffic_light_points,
        0.0,
    )

    return traffic_light_points, traffic_lane_ids, traffic_light_states
输出三项:traffic_light_points(每个交通灯的 x、y 坐标和状态,当前时刻无效的交通灯整行置 0)、车道 ID(traffic_lane_ids)、交通灯状态(traffic_light_states)
航点
1 roadgraph_points = scenario.roadgraph_points
不作处理
道路图
从完整的 Waymo 道路图中,只提取与当前场景中活跃智能体相关的局部地图。
1 2 3 4 5 6 7 for a in range (agents_history.shape[0 ]): if not current_valid[a]: continue agent_position = agents_history[a, -1 , :2 ] nearby_roadgraph_points = filter_topk_roadgraph_points(roadgraph_points, agent_position, 3000 ) map_ids.append(nearby_roadgraph_points.ids.tolist())
只为有效的智能体筛选附近的路网点:找到距离智能体最近的 $K$ 个道路图点(这里 $K = 3000$),并收集这些点所属的地图元素 ID(即车道、人行横道等的 ID)
1 2 3 4 5 sorted_map_ids = [] for i in range (nearby_roadgraph_points.shape[0 ]): for j in range (len (map_ids)): if map_ids[j][i] != -1 and map_ids[j][i] not in sorted_map_ids: sorted_map_ids.append(map_ids[j][i])
去重
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 for id in sorted_map_ids: p_x = roadgraph_points_x[roadgraph_points.ids == id ] p_y = roadgraph_points_y[roadgraph_points.ids == id ] dir_x = roadgraph_points_dir_x[roadgraph_points.ids == id ] dir_y = roadgraph_points_dir_y[roadgraph_points.ids == id ] heading = np.arctan2(dir_y, dir_x) lane_type = roadgraph_points_types[roadgraph_points.ids == id ] traffic_light_state = traffic_light_states[traffic_lane_ids == id ] if id in traffic_lane_ids else 0 traffic_light_state = np.repeat(traffic_light_state, len (p_x)) polyline = np.stack([p_x, p_y, heading, traffic_light_state, lane_type], axis=1 ) polyline_len = polyline.shape[0 ] sampled_points = np.linspace(0 , polyline_len-1 , num_points_polyline, dtype=np.int32) cur_polyline = np.take(polyline, sampled_points, axis=0 ) polylines.append(cur_polyline)
每条多段线由 5 个特征组成:(x, y) 坐标、航向角 (heading)、交通灯状态、车道类型 ,由于地图元素的实际点数不固定,使用 np.linspace 和 np.take 对其进行均匀采样,以保证每条多段线最终包含固定的 num_points_polyline(例如 30)个点。
1 2 3 4 5 6 7 8 9 10 11 12 13 if len (polylines) > 0 : polylines = np.stack(polylines, axis=0 ) polylines_valid = np.ones((polylines.shape[0 ],), dtype=np.int32) else : polylines = np.zeros((1 , num_points_polyline, 5 ), dtype=np.float32) polylines_valid = np.zeros((1 ,), dtype=np.int32) if polylines.shape[0 ] >= max_polylines: polylines = polylines[:max_polylines] polylines_valid = polylines_valid[:max_polylines] else : polylines = np.pad(polylines, ((0 , max_polylines-polylines.shape[0 ]), (0 , 0 ), (0 , 0 ))) polylines_valid = np.pad(polylines_valid, (0 , max_polylines-polylines_valid.shape[0 ]))
截断和填充,将polylines固定在[max_polylines, num_points_polyline, 5]
相对关系
1 2 relations = calculate_relations(agents_history, polylines, traffic_light_points) relations = np.asarray(relations)
$N = n_{\text{agents}} + n_{\text{polylines}} + n_{\text{traffic\_lights}}$,其中 $n_{\text{agents}} = 64$,$n_{\text{polylines}} = 256$
输出的 $[N, N, 3]$ 关系特征数组,编码了节点 $j$ 相对于节点 $i$ 的局部几何位置:
local_pos_x($\Delta x'$):元素 $j$ 在元素 $i$ 视野中的前后距离。
local_pos_y($\Delta y'$):元素 $j$ 在元素 $i$ 视野中的左右距离。
theta_diff($\Delta \theta$):元素 $i$ 相对于元素 $j$ 的相对航向角。
感觉可以根据对称压缩一半,还有智能体只与附近的地图元素有关,这会是个比较大的稀疏矩阵吧
最终输出的数据结构
1 2 3 4 5 6 7 8 9 10 11 data_dict = { 'agents_history' : (64 , 11 , 8 ) 'agents_interested' : (64 ,) 'agents_type' : (64 ,) 'agents_future' : (64 , 81 , 5 ) 'traffic_light_points' : (n_traffic_lights, 3 ) 'polylines' : (256 , 30 , 5 ) 'polylines_valid' : (256 ,) 'relations' : (N,N,3 ) 'agents_id' : (64 ,) }
感觉压缩空间还是蛮大的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 data_dict = data_process_scenario( scenario, max_num_objects=MAX_NUM_OBJECTS, max_polylines=MAX_POLYLINES, current_index=CURRENT_INDEX, num_points_polyline=NUM_POINTS_POLYLINE, ) if save_raw: data_dict['scenario_raw' ] = scenario data_dict['scenario_id' ] = scenario_id with open (scenario_filename, 'wb' ) as f: pickle.dump(data_dict, f)
最后加上id,存入pkl文件
模型