数据提取

1
2
3
4
5
# Fixed sizes used when packing one scenario into padded arrays.
MAX_NUM_OBJECTS = 64  # max agents kept per scenario (all agent types, not only vehicles)
MAX_POLYLINES = 256  # max map polylines kept per scenario
MAX_TRAFFIC_LIGHTS = 16  # max traffic lights
CURRENT_INDEX = 10  # "current" timestep; history covers steps 0..CURRENT_INDEX
NUM_POINTS_POLYLINE = 30  # points each polyline is resampled to
1
2
3
4
5
6
7
8
# TFRecord reader over data_dir; every record is mapped through tf_preprocess.
# repeat=1 presumably means a single pass over the files (confirm in waymax).
# num_shards is not passed (num_shards=16 is an option if sharding is wanted).
tf_dataset_kwargs = dict(
    path=data_dir,
    data_format=DataFormat.TFRECORD,
    preprocess_fn=tf_preprocess,
    repeat=1,
    deterministic=True,
)
tf_dataset = dataloader.tf_examples_dataset(**tf_dataset_kwargs)

tf_preprocess用于 TensorFlow 数据集的预处理,tf_postprocess用于后处理,在data_utils.py定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

def tf_preprocess(serialized: bytes) -> dict[str, tf.Tensor]:
    """Parse serialized WOMD examples and run Waymax's WOMD preprocessing.

    Args:
        serialized (bytes): Serialized tf.Example data, passed straight to
            ``tf.io.parse_example``.

    Returns:
        dict[str, tf.Tensor]: Result of ``dataloader.preprocess_womd_example``
        with ``aggregate_timesteps=True`` and ``max_num_objects=None``. Before
        that call, the ``'scenario/id'`` string is replaced by its raw bytes
        as a uint8 tensor.
    """
    # WOMD feature spec: no SDC paths, max_num_rg_points=30000.
    feature_spec = dataloader.womd_utils.get_features_description(
        include_sdc_paths=False,
        max_num_rg_points=30000,
        num_paths=None,
        num_points_per_path=None,
    )
    # Also request the scenario id, as a single string per example.
    feature_spec['scenario/id'] = tf.io.FixedLenFeature([1], tf.string)

    parsed = tf.io.parse_example(serialized, feature_spec)
    # Swap the id string for its raw uint8 bytes (pop first, then re-insert).
    parsed['scenario/id'] = tf.io.decode_raw(parsed.pop('scenario/id'), tf.uint8)

    return dataloader.preprocess_womd_example(
        parsed, aggregate_timesteps=True, max_num_objects=None
    )

将原始的序列化数据集样本转换为模型可用的格式化张量字典,为后续的轨迹预测或运动分析模型准备输入数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def tf_postprocess(example: dict[str, tf.Tensor]):
    """Split a preprocessed example into its scenario id and simulator state.

    Args:
        example (dict[str, tf.Tensor]): Feature dict containing
            ``'scenario/id'`` (presumably the output of ``tf_preprocess``).

    Returns:
        tuple: ``(scenario_id, scenario)``. ``scenario_id`` is
        ``example['scenario/id']`` untouched; ``scenario`` is built by
        ``dataloader.simulator_state_from_womd_dict``.
    """
    state = dataloader.simulator_state_from_womd_dict(example)
    return example['scenario/id'], state

后处理:分离出场景 ID 和原始场景数据(由 simulator_state_from_womd_dict 构造的模拟器状态)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
for example in tf_dataset_iter:
    scenario_id_binary, scenario = tf_postprocess(example)
    # tf_preprocess stored the id as raw uint8 bytes; turn it back into text.
    scenario_id = scenario_id_binary.tobytes().decode('utf-8')

    scenario_filename = os.path.join(save_dir, f'scenario_{scenario_id}.pkl')

    # Skip scenarios that were already exported, so a run can be resumed.
    if os.path.exists(scenario_filename):
        continue

    if only_raw:
        data_dict = {'scenario_raw': scenario}
    else:
        data_dict = data_process_scenario(
            scenario,
            max_num_objects=MAX_NUM_OBJECTS,
            max_polylines=MAX_POLYLINES,
            current_index=CURRENT_INDEX,
            num_points_polyline=NUM_POINTS_POLYLINE,
        )
        if save_raw:
            data_dict['scenario_raw'] = scenario

    data_dict['scenario_id'] = scenario_id

    # Write to a temporary file, then atomically rename it into place. If the
    # process dies mid-dump, no truncated .pkl is left at scenario_filename for
    # the exists() check above to wrongly skip on the next run.
    tmp_filename = scenario_filename + '.tmp'
    with open(tmp_filename, 'wb') as f:
        pickle.dump(data_dict, f)
    os.replace(tmp_filename, scenario_filename)

data_process_scenario——核心场景处理函数

智能体

1
2
3
4
5
6
7
8
# Padded per-agent history/future arrays plus interest level, type and ids.
agents_history, agents_future, agents_interested, agents_type, agents_id = (
    data_process_agent(
        scenario,
        max_num_objects=max_num_objects,
        current_index=current_index,
        use_log=use_log,
        selected_agents=selected_agents,
        remove_history=remove_history,
    )
)
1
2
3
4
5
6
7
8
9
10
11
# History window: timesteps 0..current_index inclusive.
history_steps = slice(None, current_index + 1)
history_xy = log_trajectory.xy[a][history_steps]
# Column order: x, y, yaw, vel_x, vel_y, length, width, height.
agents_history[i] = np.column_stack([
    history_xy[:, 0],
    history_xy[:, 1],
    *(
        getattr(log_trajectory, field)[a][history_steps]
        for field in ('yaw', 'vel_x', 'vel_y', 'length', 'width', 'height')
    ),
])
# Zero every timestep at which agent `a` is not valid.
agents_history[i][~log_trajectory.valid[a, history_steps]] = 0

历史轨迹(max_objects, history_length, 8)的8个特征:x, y, yaw, vel_x, vel_y, length, width, height

1
2
3
4
5
6
7
# Future window: timesteps current_index..end. The current step is included,
# so it overlaps the last history step.
# Column order: x, y, yaw, vel_x, vel_y.
agents_future[i] = np.column_stack([
    log_trajectory.xy[a][current_index:, 0],
    log_trajectory.xy[a][current_index:, 1],
    log_trajectory.yaw[a][current_index:],
    log_trajectory.vel_x[a][current_index:],
    log_trajectory.vel_y[a][current_index:],
])
# Zero invalid future timesteps, mirroring the history masking. Otherwise
# whatever values the log holds for unobserved steps end up in the labels.
agents_future[i][~log_trajectory.valid[a, current_index:]] = 0

未来轨迹 (max_objects, future_length, 5) 的 5 个特征:x, y, yaw, vel_x, vel_y(从 current_index 开始截取,包含当前时刻)

  • agents_interested: 智能体关注度(模型化对象=10,其他=1,无效=0)

  • agents_type: 智能体类型(车辆、行人等)

  • agent_ids: 实际处理的智能体ID列表

  • 有效性掩码: 使用valid标志过滤无效时间步的数据

  • 零填充: 对无效位置用0填充

  • 历史清除: remove_history参数可清除除当前时刻外的所有历史

最后返回

1
# Final statement of data_process_agent: the fixed-size per-agent arrays plus the
# list of ids of the agents that were actually processed.
return (agents_history, agents_future, agents_interested, agents_type, agent_ids)
1
2
3
4
5
6
# Example: a scenario with 23 valid agents, current_index=10, 91 timesteps in total.
agents_history.shape # (64, 11, 8) - first 23 rows hold data, remaining 41 rows are all zeros
agents_future.shape # (64, 81, 5) - same as above
agents_interested.shape # (64,) - [10,1,1,10,0,0,...]
agents_type.shape # (64,) - [1,1,2,1,0,0,...]
agent_ids # [23, 45, 12, 67, ...] - list of the 23 actual agent ids

如果智能体比较少,agents_history,agents_future会形成很大的稀疏矩阵

交通灯

1
2
3
4
# Traffic-light stop points/states and controlled lane ids at current_index.
traffic_light_points, traffic_lane_ids, traffic_light_states = (
    data_process_traffic_light(scenario, current_index=current_index)
)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def data_process_traffic_light(
    scenario,
    current_index=10,
):
    """Extract traffic-light features at a single timestep.

    Args:
        scenario (datatypes.SimulatorState): State whose ``log_traffic_light``
            exposes ``lane_ids``, ``state``, ``xy`` and ``valid`` arrays whose
            first two axes are (light, timestep).
        current_index (int): Timestep to sample.

    Returns:
        tuple: ``(traffic_light_points, traffic_lane_ids, traffic_light_states)``.
        ``traffic_light_points`` is float32 with the ``xy`` columns followed by
        the state; rows of lights that are invalid at ``current_index`` are
        zeroed. Lane ids and states are returned without masking.
    """
    lights = scenario.log_traffic_light
    at_step = (slice(None), current_index)  # equivalent to [:, current_index]

    lane_ids = np.asarray(lights.lane_ids)[at_step]
    states = np.asarray(lights.state)[at_step]
    stop_xy = np.asarray(lights.xy)[at_step]
    is_valid = np.asarray(lights.valid)[at_step]

    # Append the state as an extra column after the stop-point coordinates.
    features = np.concatenate([stop_xy, states[:, None]], axis=1).astype(np.float32)
    features = np.where(is_valid[:, None], features, 0.0)

    return features, lane_ids, states

输出:交通灯停止点的 x, y 坐标与状态拼接成的 traffic_light_points(当前时刻无效的交通灯整行置零)、受控车道 ID(traffic_lane_ids)、交通灯状态(traffic_light_states)

道路图点(roadgraph points)

1
# Full roadgraph point set of the scenario, taken as-is (no processing here).
roadgraph_points = scenario.roadgraph_points

不作处理

道路图

从完整的 Waymo 道路图中,只提取与当前场景中活跃智能体相关的局部地图。

1
2
3
4
5
6
7
num_agents = agents_history.shape[0]
for a in range(num_agents):
    if current_valid[a]:
        # The last history row is the agent's state at current_index; keep (x, y).
        agent_position = agents_history[a, -1, :2]
        # Presumably the 3000 roadgraph points nearest to agent_position.
        nearby_roadgraph_points = filter_topk_roadgraph_points(
            roadgraph_points, agent_position, 3000
        )
        # Map-element id of each of those points.
        map_ids.append(nearby_roadgraph_points.ids.tolist())

只为当前时刻有效的智能体筛选附近的路网点:以智能体当前位置为中心,找到距离最近的 $K$ 个道路图点(这里 $K=3000$),收集这些点所属的地图元素 ID(即车道、人行横道等的 ID)

1
2
3
4
5
# Merge the per-agent id lists rank by rank: every agent's 1st point, then every
# agent's 2nd point, and so on. If each list is ordered by distance, elements
# close to any agent therefore come first. Padding ids (-1) and duplicates are
# dropped.
# - The rank count comes from map_ids itself, so no valid agent (empty map_ids)
#   yields an empty list instead of depending on a leftover loop variable.
# - A set keeps the duplicate check O(1).
sorted_map_ids = []
seen_map_ids = set()
num_ranks = max((len(per_agent) for per_agent in map_ids), default=0)
for i in range(num_ranks):
    for agent_map_ids in map_ids:
        if i >= len(agent_map_ids):
            continue
        map_id = agent_map_ids[i]
        if map_id != -1 and map_id not in seen_map_ids:
            seen_map_ids.add(map_id)
            sorted_map_ids.append(map_id)

去重:按距离排名交错合并各智能体的结果(各智能体排名靠前的地图元素先加入),同时丢弃填充值 -1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for map_id in sorted_map_ids:
    # Points belonging to this map element; the mask is computed once and reused.
    in_element = roadgraph_points.ids == map_id
    p_x = roadgraph_points_x[in_element]
    p_y = roadgraph_points_y[in_element]
    dir_x = roadgraph_points_dir_x[in_element]
    dir_y = roadgraph_points_dir_y[in_element]
    heading = np.arctan2(dir_y, dir_x)
    lane_type = roadgraph_points_types[in_element]

    # Traffic-light state for this lane (0 when no light entry references it).
    # Several entries may carry the same lane id. Taking the first keeps exactly
    # one value per point; repeating a k-element array would give
    # k * len(p_x) values and make np.stack below fail.
    matching_states = traffic_light_states[traffic_lane_ids == map_id]
    traffic_light_state = matching_states[0] if len(matching_states) > 0 else 0
    traffic_light_state = np.repeat(traffic_light_state, len(p_x))
    # Per-point features: x, y, heading, traffic-light state, lane type.
    polyline = np.stack([p_x, p_y, heading, traffic_light_state, lane_type], axis=1)

    # Resample to num_points_polyline evenly spaced indices over the polyline
    # (indices repeat when the element has fewer points).
    polyline_len = polyline.shape[0]
    sampled_points = np.linspace(0, polyline_len - 1, num_points_polyline, dtype=np.int32)
    cur_polyline = np.take(polyline, sampled_points, axis=0)
    polylines.append(cur_polyline)

多段线上的每个点由 5 个特征组成:(x, y) 坐标、航向角 (heading)、交通灯状态、车道类型。由于地图元素的实际点数不固定,使用 np.linspace 和 np.take 对其进行均匀采样,以保证每条多段线最终包含固定的 num_points_polyline(例如 30)个点。

1
2
3
4
5
6
7
8
9
10
11
12
13
if len(polylines) == 0:
    # No map elements collected: one all-zero placeholder row, marked invalid.
    polylines = np.zeros((1, num_points_polyline, 5), dtype=np.float32)
    polylines_valid = np.zeros((1,), dtype=np.int32)
else:
    polylines = np.stack(polylines, axis=0)
    polylines_valid = np.ones((polylines.shape[0],), dtype=np.int32)

# Truncate to max_polylines, then zero-pad any shortfall (padded rows get valid=0).
polylines = polylines[:max_polylines]
polylines_valid = polylines_valid[:max_polylines]
shortfall = max_polylines - polylines.shape[0]
if shortfall > 0:
    polylines = np.pad(polylines, ((0, shortfall), (0, 0), (0, 0)))
    polylines_valid = np.pad(polylines_valid, (0, shortfall))

截断和填充:将 polylines 固定为 [max_polylines, num_points_polyline, 5],并用 polylines_valid 标记真实折线(1)与填充折线(0)

相对关系

1
2
# Pairwise relation features over agents, polylines and traffic lights
# (per the notes, shape [N, N, 3]), converted to a NumPy array.
relations = np.asarray(
    calculate_relations(agents_history, polylines, traffic_light_points)
)

$$N = n_{\text{agents}}\,(64) + n_{\text{polylines}}\,(256) + n_{\text{traffic\_lights}}$$

输出的 $[N, N, 3]$ 关系特征数组,编码了节点 $j$ 相对于节点 $i$ 的局部几何位置:

  1. local_pos_x ($\Delta x'$): 元素 $j$ 在元素 $i$ 视野(局部坐标系)中的前后距离
  2. local_pos_y ($\Delta y'$): 元素 $j$ 在元素 $i$ 视野(局部坐标系)中的左右距离
  3. theta_diff ($\Delta \theta$): 元素 $i$ 相对于元素 $j$ 的相对航向角

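感觉可以利用对称性压缩一半。不过该关系并非严格对称:$\Delta\theta$ 取反即可,但 $i$ 在 $j$ 局部坐标系中的位置需要把 $(\Delta x', \Delta y')$ 取反并按相对航向角旋转才能得到,所以只能存一半并在使用时推导。另外,智能体只与附近的地图元素有关,这应该会是一个比较大的稀疏矩阵。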

最终输出的数据结构

1
2
3
4
5
6
7
8
9
10
11
# Schema of the dict returned by data_process_scenario: key -> array shape.
# Symbolic dimensions are written as strings so the schema is valid Python;
# N = num agents (64) + num polylines (256) + num traffic lights.
data_dict = {
    'agents_history': (64, 11, 8),  # agent history trajectories
    'agents_interested': (64,),  # agent interest level
    'agents_type': (64,),  # agent type
    'agents_future': (64, 81, 5),  # agent future trajectories (labels)
    'traffic_light_points': ('n_traffic_lights', 3),  # traffic lights: x, y, state
    'polylines': (256, 30, 5),  # road polylines
    'polylines_valid': (256,),  # polyline validity mask
    'relations': ('N', 'N', 3),  # pairwise spatial relations
    # Original agent ids. NOTE(review): data_process_agent's example returns a
    # list of only the valid ids, so the length may be < 64; confirm.
    'agents_id': (64,),
}

感觉压缩空间还是蛮大的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
    # Excerpt of the export loop (full-processing branch): build the feature
    # dict, optionally attach the raw scenario, add the id, then pickle it.
    data_dict = data_process_scenario(
scenario,
max_num_objects=MAX_NUM_OBJECTS,
max_polylines=MAX_POLYLINES,
current_index=CURRENT_INDEX,
num_points_polyline=NUM_POINTS_POLYLINE,
)
if save_raw:
data_dict['scenario_raw'] = scenario

data_dict['scenario_id'] = scenario_id

# NOTE(review): pickle.dump writes straight to the final path. If the process is
# interrupted mid-write, a truncated file remains. The os.path.exists() skip check
# in the export loop would then treat that scenario as done on a rerun. Writing
# to a temporary file and os.replace()-ing it into place avoids this.
with open(scenario_filename, 'wb') as f:
pickle.dump(data_dict, f)

最后加上id,存入pkl文件

模型