diff --git a/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh index 1649590a04..394ea84531 100644 --- a/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh @@ -13,9 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -if [ $# -ne 2 ] +echo "==============================================================================================================" +echo "Please run the script as: " +echo "sh run_standalone_train.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]" +echo "for example: sh run_standalone_train.sh Pretraining /path/vgg16_backbone.ckpt 0" +echo "when device id is occupied, choose for another one" +echo "It is better to use absolute path." +echo "==============================================================================================================" +if [ $# -ne 3 ] then - echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]" + echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]" exit 1 fi @@ -38,7 +45,7 @@ fi ulimit -u unlimited export DEVICE_NUM=1 -export DEVICE_ID=0 +export DEVICE_ID=$3 export RANK_ID=0 export RANK_SIZE=1 diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index f03a7e2a9b..4298331aaf 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -218,7 +218,7 @@ if __name__ == '__main__': metrics = {"acc"} if args_opt.run_distribute: metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)} - if (args_opt.net not in ("resnet18", "resnet50", "resnet101")) or \ + if (args_opt.net not in ("resnet18", "resnet50", "resnet101", "se-resnet50")) or \ args_opt.parameter_server or target == "CPU": ## fp32 training model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network)