"""
地图可视化脚本 - 用于显示Waymo场景的地图和轨迹
Map visualization script: renders the map and agent trajectories of Waymo scenarios.
"""
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2
# 使用相对导入
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from src.utils.config import DATA_CONFIG
# Waymo RoadLine.type values that should be drawn in yellow / as dashes.
# NOTE(review): values follow the RoadLine enum in the Waymo Open Dataset map proto
# (1=broken single white, 2=solid single white, 3=solid double white,
# 4=broken single yellow, 5=broken double yellow, 6=solid single yellow,
# 7=solid double yellow, 8=passing double yellow) — confirm against the dataset version used.
_YELLOW_ROAD_LINE_TYPES = frozenset({4, 5, 6, 7, 8})
_BROKEN_ROAD_LINE_TYPES = frozenset({1, 4, 5})
# Dashes are drawn as the first _DASH_POINTS points of every _DASH_PERIOD-point window.
_DASH_PERIOD = 7
_DASH_POINTS = 5


def _polyline_xy(points):
    """Return ``(xs, ys)`` lists built from an iterable of points exposing ``.x``/``.y``."""
    points = list(points)
    return [point.x for point in points], [point.y for point in points]


def _map_sub_feature(map_feature, name):
    """Return ``map_feature.<name>`` if that attribute exists and is truthy, else ``None``.

    Same test as ``hasattr(map_feature, name) and getattr(map_feature, name)``.
    """
    sub_feature = getattr(map_feature, name, None)
    return sub_feature if sub_feature else None


def _draw_road_edge(ax, road_edge):
    """Scatter one road-edge polyline on ``ax`` (purple for type 3, black otherwise)."""
    xs, ys = _polyline_xy(road_edge.polyline)
    if not xs:
        return
    # NOTE(review): the Waymo RoadEdge enum defines 0=UNKNOWN, 1=BOUNDARY, 2=MEDIAN;
    # type 3 does not appear there, so the purple branch may never trigger — confirm intent.
    color = 'purple' if road_edge.type == 3 else 'k'
    ax.scatter(xs, ys, c=color)


def _draw_road_line(ax, road_line):
    """Plot one road-line polyline on ``ax``, yellow/white and solid/dashed by its type."""
    xs, ys = _polyline_xy(road_line.polyline)
    if not xs:
        return
    color = 'y' if road_line.type in _YELLOW_ROAD_LINE_TYPES else 'w'
    if road_line.type in _BROKEN_ROAD_LINE_TYPES:
        # Step through every window, including a trailing partial one, so broken
        # lines shorter than one period are still drawn.
        for start in range(0, len(xs), _DASH_PERIOD):
            stop = start + _DASH_POINTS
            ax.plot(xs[start:stop], ys[start:stop], color=color)
    else:
        ax.plot(xs, ys, c=color)


def _draw_map_features(ax, map_features):
    """Draw the lane, road-edge and road-line polylines of every map feature on ``ax``."""
    for map_feature in map_features:
        lane = _map_sub_feature(map_feature, 'lane')
        if lane is not None:
            xs, ys = _polyline_xy(lane.polyline)
            if xs:
                ax.scatter(xs, ys, c='g', s=5)
        road_edge = _map_sub_feature(map_feature, 'road_edge')
        if road_edge is not None:
            _draw_road_edge(ax, road_edge)
        road_line = _map_sub_feature(map_feature, 'road_line')
        if road_line is not None:
            _draw_road_line(ax, road_line)


def _draw_tracks(ax, tracks, sdc_track_index):
    """Scatter the valid states of every track; the track at ``sdc_track_index`` is red."""
    for index, track in enumerate(tracks):
        valid_states = [state for state in track.states if state.valid]
        if not valid_states:
            continue
        xs = [state.center_x for state in valid_states]
        ys = [state.center_y for state in valid_states]
        # sdc_track_index is a position in `tracks`, not a track.id value.
        if index == sdc_track_index:
            start_color, path_color = 'r', 'r'
        else:
            start_color, path_color = 'k', 'b'
        ax.scatter(xs[0], ys[0], s=140, c=start_color, marker='s')  # start point
        ax.scatter(xs, ys, s=14, c=path_color)  # trajectory


def visualize_scenario(scenario_id: str, split: str = 'train') -> None:
    """Render the map and trajectories of one processed scenario to a PNG file.

    Reads ``<waymo_path>/processed/<split>/sample_<scenario_id>.pkl`` and writes
    ``<waymo_path>/visualizations/<split>/scenario_<scenario_id>.png``. If the
    pickle does not exist, a message is printed and nothing is written.

    Args:
        scenario_id: Scenario ID (the file-name part after ``sample_``).
        split: Dataset split ('train', 'valid', 'test').
    """
    processed_file = os.path.join(DATA_CONFIG['waymo_path'], 'processed', split, f'sample_{scenario_id}.pkl')
    # Check before creating a figure so this early return cannot leak one.
    if not os.path.exists(processed_file):
        print(f"场景 {scenario_id} 不存在")
        return
    # NOTE: pickle.load can execute arbitrary code; only load files produced by our own preprocessing.
    with open(processed_file, 'rb') as f:
        scenario_data = pickle.load(f)

    fig, ax = plt.subplots(figsize=(30, 30))
    try:
        # Per-axes background instead of mutating the global rcParams.
        ax.set_facecolor('grey')
        _draw_map_features(ax, scenario_data['map_features'])
        # A missing key means no track is highlighted rather than a crash.
        _draw_tracks(ax, scenario_data['tracks'], scenario_data.get('sdc_track_index'))

        ax.set_title(f'Scenario {scenario_id}', fontsize=20)
        ax.set_xlabel('X (meters)', fontsize=16)
        ax.set_ylabel('Y (meters)', fontsize=16)

        output_dir = os.path.join(DATA_CONFIG['waymo_path'], 'visualizations', split)
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, f'scenario_{scenario_id}.png'))
    finally:
        # Always release the figure; batch runs would otherwise accumulate open figures.
        plt.close(fig)
def visualize_random_scenarios(split: str = 'train', num_scenarios: int = 5):
    """Render a random sample of processed scenarios from one split.

    Picks (without replacement) among ``sample_*.pkl`` files in
    ``<waymo_path>/processed/<split>`` using the global ``np.random`` state, and
    calls ``visualize_scenario`` for each one.

    Args:
        split: Dataset split ('train', 'valid', 'test').
        num_scenarios: Number of scenarios to render; capped at the number of files available.
    """
    prefix, suffix = 'sample_', '.pkl'
    processed_dir = os.path.join(DATA_CONFIG['waymo_path'], 'processed', split)
    # Only keep files matching the naming used by visualize_scenario. Sorting makes a
    # seeded np.random pick the same scenarios on every filesystem (listdir order is arbitrary).
    scenario_files = sorted(
        name for name in os.listdir(processed_dir)
        if name.startswith(prefix) and name.endswith(suffix)
    )
    # Random selection without replacement.
    selected_scenarios = np.random.choice(scenario_files, min(num_scenarios, len(scenario_files)), replace=False)
    for scenario_file in selected_scenarios:
        # Strip exactly the prefix and suffix; str.replace would also remove matches inside the ID.
        scenario_id = scenario_file[len(prefix):-len(suffix)]
        print(f"正在可视化场景: {scenario_id}")
        visualize_scenario(scenario_id, split)
if __name__ == "__main__":
    # Script entry point: render 20 randomly chosen scenarios from the training split.
    visualize_random_scenarios(split='train', num_scenarios=20)
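# 数据集配置 (dataset configuration): expected contents of DATA_CONFIG, which this
# script imports from src.utils.config. Only 'waymo_path' is read by the code above.
# DATA_CONFIG = {
#     'waymo_path': '/mnt/f/waymoMotion/motionData/waymo_data',        # Waymo数据集根目录 (dataset root)
#     'train_path': '/mnt/f/waymoMotion/motionData/waymo_data/train',  # 训练集路径 (training split)
#     'valid_path': '/mnt/f/waymoMotion/motionData/waymo_data/valid',  # 验证集路径 (validation split)
#     'test_path': '/mnt/f/waymoMotion/motionData/waymo_data/test',    # 测试集路径 (test split)
# }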
# 相关信息 (About)
# 本篇代码受 https://blog.csdn.net/weixin_50232758/article/details/132260047 启发,如有侵权请联系删除。
# 本文作者:丰墨
# 本文链接:
# 版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC-SA 许可协议。转载请注明出处!