Browse Source

fix se-resnet performance and ctpn standalone scripts

pull/16107/head
qujianwei 4 years ago
parent
commit
ca895d2d6e
2 changed files with 11 additions and 4 deletions
  1. +10
    -3
      model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh
  2. +1
    -1
      model_zoo/official/cv/resnet/train.py

+ 10
- 3
model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh View File

@@ -13,9 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]"
echo "for example: sh run_standalone_train.sh Pretraining /path/vgg16_backbone.ckpt 0"
echo "when device id is occupied, choose for another one"
echo "It is better to use absolute path."
echo "=============================================================================================================="
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] [DEVICE_ID]"
exit 1
fi

@@ -38,7 +45,7 @@ fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export DEVICE_ID=$3
export RANK_ID=0
export RANK_SIZE=1



+ 1
- 1
model_zoo/official/cv/resnet/train.py View File

@@ -218,7 +218,7 @@ if __name__ == '__main__':
metrics = {"acc"}
if args_opt.run_distribute:
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)}
if (args_opt.net not in ("resnet18", "resnet50", "resnet101")) or \
if (args_opt.net not in ("resnet18", "resnet50", "resnet101", "se-resnet50")) or \
args_opt.parameter_server or target == "CPU":
## fp32 training
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network)


Loading…
Cancel
Save