-
Trajectory-Transformer代码:一款用于轨迹预测的变压器网
资源介绍
用于轨迹预测的 Transformer 网络
这是论文的代码
要求
pytorch 1.0+
麻木
西比
熊猫
张量板
(项目中包含的是修改版)
用法
数据设置
数据集文件夹必须具有以下结构:
- dataset
- dataset_name
- train_folder
- test_folder
- validation_folder (optional)
- clusters.mat (For quantizedTF)
个人变压器
要训练,只需运行具有不同参数的train_individual.py
示例:训练 eth 的数据
CUDA_VISIBLE_DEVICES=0 python train_individualTF.py --dataset_name eth --name eth --max_epoch 240 --bat