diff --git a/model_examples/DETR3D/README.md b/model_examples/DETR3D/README.md index b39d719cdd0c4da53a825307cf7b209a4e67b95c..410e219fd566f3bf9d139ff85efb868e6f315e4f 100644 --- a/model_examples/DETR3D/README.md +++ b/model_examples/DETR3D/README.md @@ -225,6 +225,7 @@ cd model_examples/DETR3D/detr3d 2024.12.30:首次发布 2025.1.13: 性能优化 +2025.9.1: 更改训练脚本配置 # FAQ diff --git a/model_examples/DETR3D/test/train_8p_full.sh b/model_examples/DETR3D/test/train_8p_full.sh index c1382fcda2d58cbfe2557ad0027112ff5597e5c2..2abb42c0e7be6cc77e1b183914453831e4dd592f 100644 --- a/model_examples/DETR3D/test/train_8p_full.sh +++ b/model_examples/DETR3D/test/train_8p_full.sh @@ -1,6 +1,6 @@ # 网络名称,同目录名称,需要模型审视修改 Network="DETR3D" -batch_size=1 +batch_size=3 world_size=8 py_config="projects/configs/detr3d/detr3d_res101_gridmask.py" diff --git a/model_examples/DETR3D/test/train_8p_performance.sh b/model_examples/DETR3D/test/train_8p_performance.sh index bced6cb5bbef338e1e3d8ace2444387e026c428c..a1a88c89d27f804195faa6a89b6cc9eafb7dd0ce 100644 --- a/model_examples/DETR3D/test/train_8p_performance.sh +++ b/model_examples/DETR3D/test/train_8p_performance.sh @@ -1,6 +1,6 @@ # 网络名称,同目录名称,需要模型审视修改 Network="DETR3D" -batch_size=1 +batch_size=3 world_size=8 py_config="projects/configs/detr3d/detr3d_res101_gridmask.py" @@ -41,7 +41,7 @@ echo "end_time=$(date -d @${end_time} "+%Y-%m-%d %H:%M:%S")" e2e_time=$(( $end_time - $start_time )) # 模型单epoch的step数量 -total_step=3517 +total_step=1173 # 单迭代训练时长 avg_time=$(echo "scale=3; $e2e_time / $total_step" | bc)