From 642761c2b1ec84f8f6f2858c249a0491186b2bcb Mon Sep 17 00:00:00 2001 From: jjfeing Date: Mon, 25 May 2020 15:09:18 +0800 Subject: [PATCH] adapte Second order optimization ops for thor ops for impl of 2nd-order and format for format for pylint 2nd for pylint 3rd for pylint 4th for pylint 5th for pylint nth for comments for debug for DEBUG for DEBUG for DEBUG for DEBUG for well performance for pylint for te chip for pylint for pylint nth for modification of comments --- example/resnet50_imagenet2012_THOR/config.py | 12 +- example/resnet50_imagenet2012_THOR/eval.py | 60 + .../resnet50_imagenet2012_THOR/model/thor.py | 13 +- .../model/thor_layer.py | 38 +- .../run_distribute_train.sh | 3 +- .../resnet50_imagenet2012_THOR/run_infer.sh | 64 + example/resnet50_imagenet2012_THOR/train.py | 2 +- mindspore/ops/_op_impl/__init__.py | 1 + mindspore/ops/_op_impl/_custom_op/__init__.py | 16 + .../_op_impl/_custom_op/batch_matmul_impl.py | 257 ++++ .../_op_impl/_custom_op/cholesky_trsm_impl.py | 111 ++ .../_custom_op/fused_abs_max1_impl.py | 1082 ++++++++++++++++ .../ops/_op_impl/_custom_op/img2col_impl.py | 1151 +++++++++++++++++ .../_custom_op/matmul_cube_dense_left_impl.py | 468 +++++++ .../matmul_cube_dense_right_impl.py | 172 +++ .../matmul_cube_fracz_left_cast_impl.py | 526 ++++++++ .../matmul_cube_fracz_right_mul_impl.py | 247 ++++ .../_op_impl/_custom_op/matmul_cube_impl.py | 397 ++++++ .../_custom_op/matrix_combine_impl.py | 81 ++ .../_custom_op/transpose02314_impl.py | 289 +++++ .../_op_impl/custom_op/batch_matmul_impl.py | 76 -- .../ops/_op_impl/custom_op/cholesky_trsm.py | 64 - .../ops/_op_impl/custom_op/fused_abs_max1.py | 69 - .../ops/_op_impl/custom_op/img2col_impl.py | 87 -- .../custom_op/matmul_cube_dense_left.py | 101 -- .../matmul_cube_fracz_left_cast_impl.py | 102 -- .../matmul_cube_fracz_right_mul_impl.py | 113 -- .../_op_impl/custom_op/matmul_cube_impl.py | 114 -- .../_op_impl/custom_op/matrix_combine_impl.py | 63 - .../_op_impl/custom_op/transpose02314_impl.py | 63 - mindspore/ops/op_info_register.py | 6 +- mindspore/ops/operations/__init__.py | 1 + mindspore/ops/operations/thor_ops.py | 291 ++++- 33 files changed, 5218 insertions(+), 922 deletions(-) create mode 100755 example/resnet50_imagenet2012_THOR/eval.py create mode 100755 example/resnet50_imagenet2012_THOR/run_infer.sh create mode 100644 mindspore/ops/_op_impl/_custom_op/__init__.py create mode 100644 mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/img2col_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py create mode 100644 mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/cholesky_trsm.py delete mode 100644 mindspore/ops/_op_impl/custom_op/fused_abs_max1.py delete mode 100644 mindspore/ops/_op_impl/custom_op/img2col_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py delete mode 100644 mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py delete mode 100644 mindspore/ops/_op_impl/custom_op/transpose02314_impl.py diff --git a/example/resnet50_imagenet2012_THOR/config.py b/example/resnet50_imagenet2012_THOR/config.py index 6c664891f7..cd0a81d5e6 100644 --- a/example/resnet50_imagenet2012_THOR/config.py +++ b/example/resnet50_imagenet2012_THOR/config.py @@ -23,7 +23,7 @@ config = ed({ "loss_scale": 128, "momentum": 0.9, "weight_decay": 5e-4, - "epoch_size": 50, + "epoch_size": 45, "buffer_size": 1000, "image_height": 224, "image_width": 224, @@ -31,15 +31,7 @@ config = ed({ "save_checkpoint_steps": 5004, "keep_checkpoint_max": 20, "save_checkpoint_path": "./", - "lr_init": 0.01, - "lr_end": 0.00001, - "lr_max": 0.1, - "warmup_epochs": 0, - "lr_decay_mode": "cosine", "label_smooth": 1, "label_smooth_factor": 0.1, - "lr": 0.1, - "T_max": 90, - "eta_min": 0, - "frequency": 278 + "frequency": 834 }) diff --git a/example/resnet50_imagenet2012_THOR/eval.py b/example/resnet50_imagenet2012_THOR/eval.py new file mode 100755 index 0000000000..db82b9fcac --- /dev/null +++ b/example/resnet50_imagenet2012_THOR/eval.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +eval. +""" +import os +import argparse +from dataset_imagenet import create_dataset +from config import config +from mindspore import context +from mindspore.model_zoo.resnet import resnet50 +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from crossentropy import CrossEntropy + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.') +parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) +context.set_context(device_id=device_id) + +if __name__ == '__main__': + + net = resnet50(class_num=config.class_num) + if not config.label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + + if args_opt.do_eval: + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + + if args_opt.checkpoint_path: + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + model = Model(net, loss_fn=loss, metrics={'acc'}) + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/example/resnet50_imagenet2012_THOR/model/thor.py b/example/resnet50_imagenet2012_THOR/model/thor.py index 44c0fd45db..0da1714fe6 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor.py +++ b/example/resnet50_imagenet2012_THOR/model/thor.py @@ -21,11 +21,6 @@ from mindspore.common.tensor import Tensor from mindspore.nn.optim.optimizer import Optimizer from mindspore.ops import functional as F, composite as C, operations as P from mindspore.parallel._utils import _get_device_num, _get_mirror_mean - -from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight -from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast -from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft -from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul from model.grad_reducer_thor import DistributedGradReducerThor momentum_opt = C.MultitypeFuncGraph("momentum_opt") @@ -68,10 +63,10 @@ class THOR(Optimizer): self.matrix_G = ParameterTuple(matrix_G) self.A_inv_max = ParameterTuple(A_inv_max) self.G_inv_max = ParameterTuple(G_inv_max) - self.cube_matmul_left = CusMatMulCubeFraczLeftCast() - self.cube_matmul_left_fc = CusMatMulCubeDenseLeft() - self.cube_matmul_right_fc = CusMatMulCubeDenseRight() - self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul() + self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast() + self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft() + self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight() + self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul() self.transpose = P.Transpose() self.shape = P.Shape() self.reshape = P.Reshape() diff --git a/example/resnet50_imagenet2012_THOR/model/thor_layer.py b/example/resnet50_imagenet2012_THOR/model/thor_layer.py index 8097d729ea..fea74605b6 100644 --- a/example/resnet50_imagenet2012_THOR/model/thor_layer.py +++ b/example/resnet50_imagenet2012_THOR/model/thor_layer.py @@ -23,19 +23,9 @@ from mindspore.common.tensor import Tensor from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation from mindspore.ops import operations as P - -from cus_ops.cus_batch_matmul import CusBatchMatMul -from cus_ops.cus_cholesky_trsm import CusCholeskyTrsm -from cus_ops.cus_fused_abs_max1 import CusFusedAbsMax1 -from cus_ops.cus_img2col import CusImg2Col -from cus_ops.cus_matmul_cube import CusMatMulCube -from cus_ops.cus_matrix_combine import CusMatrixCombine -from cus_ops.cus_transpose02314 import CusTranspose02314 - import numpy as np C0 = 16 - def caculate_device_shape(matrix_dim, channel, is_A): ll = (0) if is_A: @@ -153,11 +143,11 @@ class Conv2d_Thor(_Conv): group=self.group ) - self.img2col = CusImg2Col(ksizes=ksizes, strides=strides) - self.cube_matmul = CusMatMulCube(transpose_a=True) - self.matrix_combine = CusMatrixCombine() - self.cholesky = CusCholeskyTrsm() - self.transpose02314 = CusTranspose02314() + self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() + self.transpose02314 = P.CusTranspose02314() self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] self.matrix_G_dim = self.out_channels self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim, @@ -190,7 +180,7 @@ class Conv2d_Thor(_Conv): self.mul = P.Mul() self.cast = P.Cast() self.damping = Tensor(damping) - self.vector_matmul = CusBatchMatMul() + self.vector_matmul = P.CusBatchMatMul() self.diag_block_dim = 128 self.channels_slice_flag = False if self.in_channels % C0 != 0: @@ -221,8 +211,8 @@ class Conv2d_Thor(_Conv): self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32) self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32) - self.fused_abs_max1 = CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) - self.fused_abs_max2 = CusFusedAbsMax1() + self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) + self.fused_abs_max2 = P.CusFusedAbsMax1() self.log = P.Log() self.exp = P.Exp() self.sqrt = P.Sqrt() @@ -375,9 +365,9 @@ class Dense_Thor(Cell): self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)) self.matmul = P.MatMul(transpose_b=True) - self.cube_matmul = CusMatMulCube(transpose_a=True) - self.matrix_combine = CusMatrixCombine() - self.cholesky = CusCholeskyTrsm() + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() self.shape = P.Shape() self.reshape = P.Reshape() self.transpose = P.Transpose() @@ -386,7 +376,7 @@ class Dense_Thor(Cell): self.cast = P.Cast() self.damping = Tensor(damping) self.loss_scale = Tensor(1 / loss_scale, mstype.float16) - self.vector_matmul = CusBatchMatMul() + self.vector_matmul = P.CusBatchMatMul() self.pad = P.Pad(((0, 24), (0, 24))) self.pad1 = P.Pad(((0, 8), (0, 8))) self.slice = P.Slice() @@ -396,8 +386,8 @@ class Dense_Thor(Cell): self.axis = 0 self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) - self.fused_abs_max1 = CusFusedAbsMax1([1000, 1000]) - self.fused_abs_max2 = CusFusedAbsMax1() + self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000]) + self.fused_abs_max2 = P.CusFusedAbsMax1() self.log = P.Log() self.exp = P.Exp() self.dampingA = Tensor(np.identity(2048), mstype.float32) diff --git a/example/resnet50_imagenet2012_THOR/run_distribute_train.sh b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh index ae05c45dfe..e39034a912 100644 --- a/example/resnet50_imagenet2012_THOR/run_distribute_train.sh +++ b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh @@ -45,8 +45,7 @@ do mkdir ./train_parallel$i cp *.py ./train_parallel$i cp *.sh ./train_parallel$i - cp -r second_order ./train_parallel$i/second_order - cp -r test_ops ./train_parallel$i/test_ops + cp -r model ./train_parallel$i cd ./train_parallel$i || exit echo "start training for rank $RANK_ID, device $DEVICE_ID" diff --git a/example/resnet50_imagenet2012_THOR/run_infer.sh b/example/resnet50_imagenet2012_THOR/run_infer.sh new file mode 100755 index 0000000000..14d7faf981 --- /dev/null +++ b/example/resnet50_imagenet2012_THOR/run_infer.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "infer" ]; +then + rm -rf ./infer +fi +mkdir ./infer +cp *.py ./infer +cp *.sh ./infer +cd ./infer || exit +env > env.log +echo "start infering for device $DEVICE_ID" +python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & +cd .. diff --git a/example/resnet50_imagenet2012_THOR/train.py b/example/resnet50_imagenet2012_THOR/train.py index b98d13b8a0..15710bc66b 100644 --- a/example/resnet50_imagenet2012_THOR/train.py +++ b/example/resnet50_imagenet2012_THOR/train.py @@ -109,7 +109,7 @@ if __name__ == '__main__': step_size = dataset.get_dataset_size() loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004)) + lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004)) opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), diff --git a/mindspore/ops/_op_impl/__init__.py b/mindspore/ops/_op_impl/__init__.py index 65a12cd73c..725977877d 100644 --- a/mindspore/ops/_op_impl/__init__.py +++ b/mindspore/ops/_op_impl/__init__.py @@ -19,5 +19,6 @@ from .aicpu import * if "Windows" not in platform.system(): from .akg.gpu import * from .tbe import * + from ._custom_op import * __all__ = [] diff --git a/mindspore/ops/_op_impl/_custom_op/__init__.py b/mindspore/ops/_op_impl/_custom_op/__init__.py new file mode 100644 index 0000000000..5fe583a60f --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""custom ops""" diff --git a/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py new file mode 100644 index 0000000000..97982c53cf --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py @@ -0,0 +1,257 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""batch_matmul_impl""" + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batchmatmul.so") \ + .compute_cost(10) \ + .kernel_name("CusBatchMatMul") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +def _get_flattern_shape(shape): + """_get_flattern_shape""" + flattern_shape = 1 + for dim in shape: + flattern_shape *= dim + return (flattern_shape,) + + +def _inner_matmul_new(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): + """_inner_matmul_new""" + input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", scope=tik.scope_ubuf) + t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 16, 0, 0) + with tik_instance.for_range(0, 2) as vec_i: + tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, 64, 1, 1, 16, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + bisec_last_axis_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB_dst_tmp", + scope=tik.scope_ubuf) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) + tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 8192], 0, 1, 1024, 0, 0) + tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) + tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, + 16, 16, 16) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) + with tik_instance.for_range(0, 64) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], bisec_last_axis_local_UB[cc6 * 128], + 1, 1, 1, 8) + tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, + matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[res_index + thread_idx2 * 64], + matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) + + +def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, input2, input2_index, res, res_index): + """_inner_matmul_new_1_64_32_64""" + input_1_local_UB = tik_instance.Tensor(dtype, [64], name="input_1_local_UB", scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[input1_index], 0, 1, 8, 0, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [32 * 64], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [32], name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_2_local_UB, input2[input2_index + thread_idx2 * 2048], 0, 1, 256, 0, 0) + tik_instance.vmul(64, t_1_local_UB, input_1_local_UB, input_2_local_UB, 32, 1, 1, 1, 8, 0, 8) + with tik_instance.for_range(0, 32) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB[cc6], t_1_local_UB[cc6 * 64], + 1, 1, 1, 8) + tik_instance.data_move(res[res_index + thread_idx2 * 32], + matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0) + + +@op_info_register(cus_batchmatmul_op_info) +def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): + """CusBatchMatMul""" + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + x1_shape = input_x1.get("shape") + dtype = input_x1.get("dtype").lower() + x2_shape = input_x2.get("shape") + if dtype != input_x2.get("dtype").lower(): + raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % ( + dtype, input_x2.get("dtype").lower())) + input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b) + support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True), + ((36, 128, 128), (36, 128, 128), "float32", False, True), + ((5, 128, 128), (5, 128, 128), "float32", False, True), + ((18, 128, 128), (18, 128, 128), "float32", False, True), + ((16, 128, 128), (16, 128, 128), "float32", False, True), + ((9, 128, 128), (9, 128, 128), "float32", False, True), + ((1, 64, 64), (1, 64, 64), "float32", False, True), + ((1, 128, 128), (1, 128, 128), "float32", False, True), + ((4, 128, 128), (4, 128, 128), "float32", False, True), + ((2, 128, 128), (2, 128, 128), "float32", False, True)] + if input_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + # if not transpose_a and transpose_b: + batch, m, k = x1_shape + + input1_shape = _get_flattern_shape(x1_shape) + input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm) + input2_shape = _get_flattern_shape(x2_shape) + input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm) + + output_shape = x1_shape + res_shape = _get_flattern_shape(output_shape) + res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm) + + if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 18, block_num=18) as block_idx: + with tik_instance.for_range(0, 2) as cc0: + with tik_instance.for_range(0, 128, thread_num=2) as cc1: + input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 + input2_index = block_idx * 32768 + cc0 * 16384 + res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 30, block_num=30) as block_idx: + with tik_instance.for_range(0, 11) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx: + with tik_instance.if_scope(((((block_idx % 6) * 22) + (cc1_db * 2) + thread_idx) < 128)): + input_1_local_UB = tik_instance.Tensor(dtype, [128], name="input_1_local_UB", + scope=tik.scope_ubuf) + t_1_0_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="t_1_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input1[ + (block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + thread_idx * 128], 0, 1, + 16, 0, 0) + with tik_instance.for_range(0, 2) as vec_i: + tik_instance.vadds(64, t_1_0_local_UB[vec_i * 64], input_1_local_UB[vec_i * 64], 0, + 64, 1, 1, 16, 0) + with tik_instance.for_range(0, 2, thread_num=2) as thread_idx2: + input_2_local_UB = tik_instance.Tensor(dtype, [64 * 128], name="input_2_local_UB", + scope=tik.scope_ubuf) + t_1_local_UB = input_2_local_UB + bisec_last_axis_local_UB = input_2_local_UB + matmul_hybrid_f_t_local_UB = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB", + scope=tik.scope_ubuf) + matmul_hybrid_f_t_local_UB_dst_tmp = tik_instance.Tensor(dtype, [64], + name="matmul_hybrid_f_t_local_UB_dst_tmp", + scope=tik.scope_ubuf) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB, 0, 1, 1, 8) + tik_instance.data_move(input_2_local_UB, + input2[(block_idx // 6) * 16384 + thread_idx2 * 8192], 0, 1, + 1024, 0, 0) + tik_instance.vmul(64, t_1_local_UB, t_1_0_local_UB, input_2_local_UB, 128, 1, 1, 1, 8, 8, 8) + tik_instance.vadd(64, bisec_last_axis_local_UB, t_1_local_UB, t_1_local_UB[64], 64, 1, 1, 1, + 16, 16, 16) + tik_instance.vector_dup(64, matmul_hybrid_f_t_local_UB_dst_tmp, 0, 1, 1, 8) + with tik_instance.for_range(0, 64) as cc6: + tik_instance.vcadd(64, matmul_hybrid_f_t_local_UB_dst_tmp[cc6], + bisec_last_axis_local_UB[cc6 * 128], + 1, 1, 1, 8) + tik_instance.vadd(64, matmul_hybrid_f_t_local_UB, matmul_hybrid_f_t_local_UB_dst_tmp, + matmul_hybrid_f_t_local_UB, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move( + res[(block_idx // 6) * 16384 + (block_idx % 6) * 2816 + cc1_db * 256 + + thread_idx * 128 + thread_idx2 * 64], + matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0) + + if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 18, block_num=18) as block_idx: + with tik_instance.for_range(0, 128, thread_num=2) as cc0: + input1_index = block_idx * 16384 + cc0 * 128 + input2_index = block_idx * 16384 + res_index = block_idx * 16384 + cc0 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True): + with tik_instance.for_range(0, 27, block_num=27) as block_idx: + with tik_instance.for_range(0, 42, thread_num=2) as cc0: + input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 + input2_index = (block_idx // 3) * 16384 + res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + cc0 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + with tik_instance.if_scope((block_idx % 3) < 2): + input1_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 + input2_index = (block_idx // 3) * 16384 + res_index = (block_idx // 3) * 16384 + (block_idx % 3) * 5504 + 42 * 128 + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 2, thread_num=2) as cc0: + input1_index = block_idx * 128 + cc0 * 64 + input2_index = 0 + res_index = block_idx * 128 + cc0 * 64 + _inner_matmul_new_1_64_32_64(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True), + ((2, 128, 128), (2, 128, 128), "float32", False, True), + ((4, 128, 128), (4, 128, 128), "float32", False, True), + ((8, 128, 128), (8, 128, 128), "float32", False, True), + ((16, 128, 128), (16, 128, 128), "float32", False, True) + ] + if input_shape in input_shape_list: + block_num = 32 + input1_unit_size = 128 + input2_unint_size = 128 * 128 + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + block_process_ele_num = (batch * m * k) // block_num + loop_time = (batch * m * k) // block_num // input1_unit_size + thread_num = 2 + with tik_instance.for_range(0, loop_time, thread_num=thread_num) as cc0: + input1_index = block_idx * block_process_ele_num + cc0 * input1_unit_size + if batch > 1: + input2_index = block_idx // (block_num // batch) * input2_unint_size + else: + input2_index = 0 + res_index = block_idx * block_process_ele_num + cc0 * input1_unit_size + _inner_matmul_new(tik_instance, dtype, + input1, input1_index, + input2, input2_index, + res, res_index) + + tik_instance.BuildCCE(kernel_name, inputs=[input1, input2], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py new file mode 100644 index 0000000000..71dd1ccb2d --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusCholeskyTrsm""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_cholesky_trsm_op_info = TBERegOp("CusCholeskyTrsm") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("choleskytrsm.so") \ + .compute_cost(10) \ + .kernel_name("CusCholeskyTrsm") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_cholesky_trsm_op_info) +def CusCholeskyTrsm(input_x, output, kernel_name): + """CusCholeskyTrsm""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + split_dim = 128 + matrix_dim = input_x_shape[0] + split_dim = min(matrix_dim, split_dim) + vector_repeat_times = int(split_dim // 64) + blocks = int(matrix_dim // split_dim) + if blocks == 0: + blocks = 1 + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="input_x_ub", scope=tik.scope_ubuf) + temp_ub = tik_instance.Tensor("float32", (split_dim, split_dim), name="temp_ub", scope=tik.scope_ubuf) + assist_1_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_1_ub", scope=tik.scope_ubuf) + assist_2_ub = tik_instance.Tensor("float32", (split_dim,), name="assist_2_ub", scope=tik.scope_ubuf) + with tik_instance.for_range(0, split_dim) as i: + tik_instance.data_move(input_x_ub[i, 0], input_x[block_index * split_dim + i, block_index * split_dim], 0, + 1, vector_repeat_times * 8, 0, 0) + scalar1 = tik_instance.Scalar("float32", init_value=-0.5) + + with tik_instance.for_range(0, split_dim) as i: + scalar2 = tik_instance.Scalar("float32") + tik_instance.vln(64, assist_1_ub[0], input_x_ub[i, 0], vector_repeat_times, 1, 1, 8, 8) + tik_instance.vmuls(64, assist_2_ub[0], assist_1_ub[0], scalar1, vector_repeat_times, 1, 1, 8, 8) + tik_instance.vexp(64, assist_1_ub[0], assist_2_ub[0], vector_repeat_times, 1, 1, 8, 8) + scalar2.set_as(assist_1_ub[i]) + tik_instance.vmuls(64, input_x_ub[i, 0], input_x_ub[i, 0], scalar2, vector_repeat_times, 1, 1, 8, 8) + with tik_instance.for_range(i + 1, split_dim) as j: + scalar3 = tik_instance.Scalar("float32") + scalar3.set_as(input_x_ub[i, j]) + tik_instance.vmuls(64, temp_ub[j, 0], input_x_ub[i, 0], scalar3, vector_repeat_times, 1, 1, 8, 8) + tik_instance.vsub(64, input_x_ub[i + 1, 0], input_x_ub[i + 1, 0], temp_ub[i + 1, 0], + (split_dim - 1 - i) * vector_repeat_times, 1, 1, 1, 8, 8, 8) + + zero = tik_instance.Scalar("float32") + zero.set_as(0.0) + one = tik_instance.Scalar("float32") + one.set_as(1.0) + with tik_instance.for_range(0, split_dim) as i: + tik_instance.vector_dup(64, temp_ub[i, 0], zero, vector_repeat_times, 1, 8) + temp_ub.__setitem__(i * split_dim + i, one) + + chol_diag_element_final = tik_instance.Scalar("float32") + chol_diag_element_final.set_as(input_x_ub[split_dim * split_dim - 1]) + trsm_diag_element = tik_instance.Scalar("float32") + trsm_diag_element.set_as(1.0 / chol_diag_element_final) + temp_ub.__setitem__(split_dim * split_dim - 1, trsm_diag_element) + + with tik_instance.for_range(1, split_dim) as i: + index = split_dim - i - 1 + tik_instance.vector_dup(64, assist_1_ub, zero, vector_repeat_times, 1, 8) + with tik_instance.for_range(0, i) as j: + chol_diag_element_loop = tik_instance.Scalar("float32") + chol_diag_element_loop.set_as(input_x_ub[index, index + 1 + j]) + tik_instance.vmuls(64, assist_2_ub, temp_ub[j + index + 1, 0], chol_diag_element_loop, + vector_repeat_times, 1, 1, 8, 8) + tik_instance.vadd(64, assist_1_ub, assist_2_ub, assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, 8) + temp_scalar = tik_instance.Scalar("float32") + temp_scalar.set_as(input_x_ub[index, index]) + chol_diag_element = tik_instance.Scalar("float32") + chol_diag_element.set_as(1.0 / temp_scalar) + tik_instance.vsub(64, temp_ub[index, 0], temp_ub[index, 0], assist_1_ub, vector_repeat_times, 1, 1, 1, 8, 8, + 8) + tik_instance.vmuls(64, temp_ub[index, 0], temp_ub[index, 0], chol_diag_element, vector_repeat_times, 1, 1, + 8, 8) + + tik_instance.data_move(res[block_index, 0, 0], temp_ub, 0, 1, 8 * vector_repeat_times * split_dim, 0, 0) + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py new file mode 100644 index 0000000000..f4b8d44063 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py @@ -0,0 +1,1082 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusFusedAbsMax1""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fusedabsmax1.so") \ + .compute_cost(10) \ + .kernel_name("CusFusedAbsMax1") \ + .partial_flag(True) \ + .attr("origin_shape", "required", "listInt", "all") \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_fused_abs_max1_op_info) +def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): + """CusFusedAbsMax1""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + if len(input_x_shape) > 2: + if (input_x_shape[0] == 1 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 4 and input_x_shape[1] == 16) or (input_x_shape[0] == 16 and input_x_shape[1] == 4): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 8): + if origin_shape[0] == 147 and ( + input_x_shape[0] == 2 and input_x_shape[1] == 128 and input_x_shape[2] == 128): + assert origin_shape[0] == 147 + assert origin_shape[1] == 147 + phase_1 = 16384 + phase_2 = 1216 + blocks = 32 + each_block_element = phase_1 // blocks + 64 + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[512 * block_index], 0, 1, 512 // 8, 0, 0) + line_id = block_index % 19 + tik_instance.data_move(input_x_ub[512], input_x[16384 + 128 * line_id], 0, 1, 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(19, input_x_ub, input_x_ub, input_x_ub[512], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + else: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 4 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 8 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 8): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 32 and input_x_shape[1] == 16) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 32): + if (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ + 0] == 1000: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + blocks = 32 + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 8 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + vmask = 1000 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + + elif (input_x_shape[0] == 8 and input_x_shape[1] == 128 and input_x_shape[2] == 128) and origin_shape[ + 0] == 1001: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + blocks = 32 + each_block_element = 7 * 128 * 128 // 32 + 4 * 128 + phase_1 = 7 * 128 * 128 // 32 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, + 0) + move_idx = block_index % 9 + tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1, + 128 // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + vmask = 1001 - 7 * 128 - 64 + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx], + input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 4) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + else: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, + 1, 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, + 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 16 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 16 and input_x_shape[1] == 64) or ( + input_x_shape[0] == 64 and input_x_shape[1] == 16): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 5 and input_x_shape[1] == 128 and input_x_shape[2] == 128 and origin_shape[0] == 576: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 69632 + blocks = 32 + each_block_element = total_elements // blocks + phase_1 = 2048 + phase_2 = 128 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1], input_x[65536 + phase_2 * block_index * 2], 0, 1, 8, 0, 0) + tik_instance.data_move(input_x_ub[phase_1 + 64], input_x[65536 + 128 + phase_2 * block_index * 2], 0, 1, + 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub[2048], input_x_ub[2048], input_x_ub[2048 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 9 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 72 and input_x_shape[1] == 8): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 18 and input_x_shape[1] == 128 and input_x_shape[2] == 128: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 36 and input_x_shape[1] == 128 and input_x_shape[2] == 128) or ( + input_x_shape[0] == 144 and input_x_shape[1] == 16): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, + 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 1024], 16, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 512], 8, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 256], 4, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 128], 2, 1, 1, 1, 8, 8, + 8) + tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 128 and input_x_shape[1] == 63: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 * 3 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 3 * 64], input_x_ub[repeat_time_1 * 3 * 64], + repeat_time_2, 1, 1, 8, 8) + loop_size = each_block_element // 16384 + with tik_instance.for_range(0, loop_size) as loop_idx: + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, loop_size - 1) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, + 8) + tail_element = each_block_element - 16384 * loop_size + repeats = tail_element // 64 + with tik_instance.for_range(0, repeats) as i: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * loop_size + i * 64], 1, 1, 1, 1, 8, + 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, input_x_ub[64 + cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[2048 + 64], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[1024 + 64], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[512 + 64], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[256 + 64], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[128 + 64], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[64], input_x_ub[64], input_x_ub[64 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[block_index, 0], input_x_ub[64], 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 32 and input_x_shape[1] == 128) or ( + input_x_shape[0] == 128 and input_x_shape[1] == 32): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 * 2 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + loop_size = each_block_element // 16384 + with tik_instance.for_range(0, loop_size) as loop_idx: + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16384 * loop_idx], input_x_ub[16384 * loop_idx], + input_x_ub[16384 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, loop_size - 1) as loop_idx: + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384 * (loop_idx + 1)], 1, 1, 1, 1, 8, 8, + 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 288 and input_x_shape[1] == 32: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0) + tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) + input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + repeat_time_1 = 255 + repeat_time_2 = 32768 // 64 - 255 * 2 + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 65536], 0, 1, 1024, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, 128, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(assist_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 64 and input_x_shape[1] == 128: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + assist_ub = tik_instance.Tensor("float32", (64,), name="assist_ub", scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0) + tik_instance.vector_dup(64, assist_ub, zero, 1, 1, 8) + input_x_ub = tik_instance.Tensor("float32", (32768,), name="input_x_ub", scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + repeat_time_1 = 255 + repeat_time_2 = 32768 // 64 - 255 * 2 + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 0], 0, 1, 4096, 0, 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + tik_instance.data_move(input_x_ub[0], input_x[each_block_element * block_index + 32768], 0, 1, 4096, 0, + 0) + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_1, 1, + 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 2 * 64], input_x_ub[repeat_time_1 * 2 * 64], + repeat_time_2, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 255, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[16320], input_x_ub[16320], input_x_ub[32704], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, assist_ub, assist_ub, input_x_ub, 1, 1, 1, 1, 8, 8, 8) + + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(assist_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif (input_x_shape[0] == 64 and input_x_shape[1] == 32) or (input_x_shape[0] == 32 and input_x_shape[1] == 64): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time_1 = 255 + repeat_time_2 = each_block_element // 64 - 255 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time_1, 1, 1, 8, 8) + tik_instance.vabs(64, input_x_ub[repeat_time_1 * 64], input_x_ub[repeat_time_1 * 64], repeat_time_2, 1, + 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 128, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 36 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 4 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 49 and input_x_shape[1] == 4: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, 24, 1, 1, 8, 8) + tik_instance.vabs(32, input_x_ub[1536], input_x_ub[1536], 1, 1, 1, 8, 8) + tik_instance.vmax(32, input_x_ub[1504], input_x_ub[1504], input_x_ub[1536], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub[1024], input_x_ub[1024], input_x_ub[1024 + 64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + elif input_x_shape[0] == 1 and input_x_shape[1] == 64 and input_x_shape[2] == 64: + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + total_elements = 1 + for val in input_x_shape: + total_elements *= val + blocks = 32 + each_block_element = total_elements // blocks + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub", + scope=tik.scope_ubuf) + broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1, + each_block_element // 8, 0, 0) + repeat_time = each_block_element // 64 + tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + with tik_instance.for_range(0, 64) as cc0: + data_temp = tik_instance.Scalar("float32") + data_temp.set_as(input_x_ub[cc0]) + tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1, + 1, 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1, + 8, 8, 8) + tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1, + 8, 8, 8) + tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0) + + else: + raise RuntimeError("UnSupportedShape") + elif len(input_x_shape) == 2 and (input_x_shape[0] == 32 and input_x_shape[1] == 64): + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + input_x_ub = tik_instance.Tensor("float32", (32 * 64,), name="input_x_ub", scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x, 0, 1, 256, 0, 0) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8) + tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8) + tik_instance.data_move(res[0], input_x_ub, 0, 1, 1, 0, 0) + else: + raise RuntimeError("UnSupportedShape") + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/img2col_impl.py b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py new file mode 100644 index 0000000000..433e335565 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/img2col_impl.py @@ -0,0 +1,1151 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusImg2ColNC1HWC0""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_img2col_info = TBERegOp("CusImg2Col") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("img2col.so") \ + .compute_cost(10) \ + .kernel_name("CusImg2Col") \ + .partial_flag(True) \ + .attr("ksizes", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("dilates", "required", "listInt", "all") \ + .attr("mode", "required", "str", "all") \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_FracNZ) \ + .get_op_info() + + +@op_info_register(cus_img2col_info) +def CusImg2Col(input_x, output, ksizes, strides, dilates, mode, kernel_name="img2col"): + """CusImg2Col""" + input_x_shape = input_x.get("shape") + input_x_dtype = input_x.get("dtype") + N, C1, H, W, C0 = input_x_shape + C = C1 * C0 + padding = 'SAME' + _, filter_h, filter_w, _ = ksizes + _, stride_h, stride_w, _ = strides + _, dilation_filter_h, dilation_filter_w, _ = dilates + + input_shape = (tuple(input_x_shape), input_x_dtype, (filter_h, filter_w), (stride_h, stride_w)) + supported_shape = [((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2)), + ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2)), + ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1)), + ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2)), + ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1)), + ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2)), + ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1)), + ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1)), + ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1)), + ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1)), + ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2)), + ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2)), + ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2)), + ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)), + ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)), + ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)), + ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)),] + + if input_shape not in supported_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + output_tmp = [N * int(H // stride_h) * int(W // stride_w), filter_h * filter_w * C] + output_shape = [output_tmp[1] // 16, output_tmp[0] // 16, 16, 16] + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm) + + if input_shape == ((32, 1, 224, 224, 16), 'float16', (7, 7), (2, 2)): + pad = [3, 3, 3, 3] + l1_h = 56 + l1_w = 224 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53760,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + left_top_h = -3 + + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move(res[cc0 + eeb * 7, cc1 + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + with tik_instance.for_range(1, 3) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 56 * eeb0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep_prefix = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + rep = 196 + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + + left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2 + + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move( + res[cc0 + eeb * 7, cc1 + rep_prefix + (eeb0 - 1) * rep + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 56 * 3, 0, 0], 0, 1, 12544, 0, 0) + + with tik_instance.for_range(0, 7) as eeb: + with tik_instance.for_range(0, 7) as cc0: + temp = eeb % 2 + rep_prefix = ((55 - temp - (-3 + eeb)) // 2 + 1) * 7 + 196 * 2 + rep = 784 - rep_prefix + fetch_filter_w = cc0 + fetch_filter_h = eeb + left_top_w = -3 + left_top_h = 1 + ((55 - temp - (-3 + eeb)) // 2 - 29) * 2 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, rep) as cc1: + tik_instance.data_move(res[cc0 + eeb * 7, cc1 + rep_prefix + 784 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[cc1 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 4, 56, 56, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 56 + l1_w = 56 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 9) as eeb0: + rep = 196 + fetch_filter_w = eeb0 % 3 + fetch_filter_h = eeb0 // 3 + left_top_w = -1 + left_top_h = -1 + with tik_instance.for_range(0, 4) as eeb1: + c1_index = eeb1 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, rep) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb0, i + 196 * block_index, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 56, 56, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 56 + l1_w = 56 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (401408,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (112896,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 8) as eeb0: + with tik_instance.for_range(0, 9) as eeb1: + rep = 49 + fetch_filter_w = eeb1 % 3 + fetch_filter_h = eeb1 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = eeb0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[49 * 256 * eeb1], input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 9) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb1 + eeb0 * 9, 49 * block_index + i, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 28, 28, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (112896,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 6272, 0, 0) + with tik_instance.for_range(0, 8) as eeb0: + with tik_instance.for_range(0, 9) as eeb1: + rep = 49 + fetch_filter_w = eeb1 % 3 + fetch_filter_h = eeb1 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = eeb0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[49 * 256 * eeb1], input_1_1_local_L1, + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 9) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb1 + eeb0 * 9, 49 * block_index + i, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 28, 28, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (200704,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_1_local_L1[i * 12544], input_x[i + 16 * eeb0, eeb1, 0, 0, 0], 0, 1, 784, + 0, 0) + + with tik_instance.for_range(0, 9) as eeb3: + rep = 13 + fetch_filter_w = eeb3 % 3 + fetch_filter_h = eeb3 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 16) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + + with tik_instance.for_range(196 * eeb0, 196 * (eeb0 + 1)) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb3, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - 196 * eeb0)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 14, 14, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (50176,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i + 16 * eeb0, eeb1, 0, 0, 0], 0, 1, 196, + 0, 0) + + with tik_instance.for_range(0, 9) as eeb3: + rep = 13 + fetch_filter_w = eeb3 % 3 + fetch_filter_h = eeb3 // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 16) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + + with tik_instance.for_range(196 * eeb0, 196 * (eeb0 + 1)) as i: + tik_instance.data_move(res[eeb1 * 9 + eeb3, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - 196 * eeb0)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 14, 14, 16), 'float16', (3, 3), (2, 2)): + pad = [1, 1, 1, 1] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index, 0, 0, 0], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 9) as eeb: + rep = 4 + fetch_filter_w = eeb % 3 + fetch_filter_h = eeb // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (2, 2)): + pad = [0, 0, 0, 0] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + + with tik_instance.for_range(0, 2) as eeb0: + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index * 2 + eeb0, 0, 0, 0], 0, + 1, 196, 0, 0) + with tik_instance.for_range(0, 32) as i: + rep = 4 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb0 + block_index * 2, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 7, 7, 16), 'float16', (3, 3), (1, 1)): + pad = [1, 1, 1, 1] + l1_h = 7 + l1_w = 7 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 9) as eeb: + rep = 4 + fetch_filter_w = eeb % 3 + fetch_filter_h = eeb // 3 + left_top_w = -1 + left_top_h = -1 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb + block_index * 9, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 128, 7, 7, 16), 'float16', (1, 1), (1, 1)): + pad = [0, 0, 0, 0] + l1_h = 7 + l1_w = 7 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 4) as eeb0: + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, eeb0 + block_index * 4, 0, 0, 0], 0, + 1, 49, 0, 0) + with tik_instance.for_range(0, 32) as i: + rep = 4 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + with tik_instance.for_range(0, 32) as i: + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[eeb0 + block_index * 4, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], + 0, 1, 16, 0, 0) + + if input_shape == ((32, 64, 14, 14, 16), 'float16', (1, 1), (1, 1)): + pad = [0, 0, 0, 0] + l1_h = 14 + l1_w = 14 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_2_local_L1 = tik_instance.Tensor("float16", (100352,), scope=tik.scope_cbuf, + name="input_1_2_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, block_index * 2, 0, 0, 0], 0, 1, 196, 0, + 0) + tik_instance.data_move(input_1_2_local_L1[i * 3136], input_x[i, block_index * 2 + 1, 0, 0, 0], 0, 1, + 196, 0, 0) + with tik_instance.for_range(0, 2) as eeb1: + with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - eeb1 * 16)], + input_1_1_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(eeb1 * 196, (eeb1 + 1) * 196) as i: + tik_instance.data_move(res[block_index * 2, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - eeb1 * 196)], 0, 1, 16, 0, 0) + + with tik_instance.for_range(0, 2) as eeb1: + with tik_instance.for_range(eeb1 * 16, (eeb1 + 1) * 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - eeb1 * 16)], + input_1_2_local_L1[3136 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(eeb1 * 196, (eeb1 + 1) * 196) as i: + tik_instance.data_move(res[block_index * 2 + 1, i, 0, 0], + input_1_2_fractal_L1_local_UB[256 * (i - eeb1 * 196)], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (2, 2)): + pad = [0, 0, 0, 0] + l1_h = 28 + l1_w = 28 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (401408,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (53248,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (50176,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 12544], input_x[i, block_index, 0, 0, 0], 0, 1, 784, 0, 0) + with tik_instance.for_range(0, 16) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * i], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], 0, 1, 16, 0, + 0) + + with tik_instance.for_range(16, 32) as i: + rep = 13 + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_w = 0 + left_top_h = 0 + c1_index = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[3328 * (i - 16)], input_1_1_local_L1[12544 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + rep) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(196, 392) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[256 * (i - 196)], 0, 1, + 16, 0, 0) + + if input_shape == ((32, 32, 7, 7, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 7 + l1_w = 7 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (32768,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (25088,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 784], input_x[i, block_index, 0, 0, 0], 0, 1, 49, 0, 0) + + with tik_instance.for_range(0, 32) as i: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[1024 * i], input_1_1_local_L1[784 * i], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 4) + + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 49 * 16], + input_1_1_fractal_L1_local_UB[i * 1024], 0, 1, 49, 0, 0) + with tik_instance.for_range(0, 98) as i: + tik_instance.data_move(res[block_index, i, 0, 0], input_1_2_fractal_L1_local_UB[i * 256], 0, 1, 16, 0, + 0) + + if input_shape == ((32, 4, 56, 56, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (12544 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (100352 // 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 12544, 0, 0) + with tik_instance.for_range(0, 4) as eeb: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB, input_1_1_local_L1[eeb * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 196) + with tik_instance.for_range(0, 196) as rep: + tik_instance.data_move(res[eeb, rep + block_index * 196, 0, 0], + input_1_1_fractal_L1_local_UB[rep * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 8, 28, 28, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 28 + l1_w = 28 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (6272 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 6272, 0, 0) + with tik_instance.for_range(0, 1) as eeb0: + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[(eeb1 + eeb0 * 8) * 28 * 28 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 32, 28, 28, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 28 + l1_w = 28 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, 0, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 4) as eeb0: + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[(eeb1 + eeb0 * 8) * 28 * 28 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[i * 256 + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + + if input_shape == ((32, 16, 14, 14, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 14 + l1_w = 14 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + eeb0 = block_index % 2 + eeb1 = block_index // 2 + input_1_1_local_L1 = tik_instance.Tensor("float16", (196 * 32 * 16,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (106496 // 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + input_1_2_fractal_L1_local_UB = tik_instance.Tensor("float16", (196 * 16 * 16,), scope=tik.scope_ubuf, + name="input_1_2_fractal_L1_local_UB") + with tik_instance.for_range(0, 32) as i: + tik_instance.data_move(input_1_1_local_L1[i * 3136], input_x[i, eeb1, 0, 0, 0], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 16) as i: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[i * 3328], + input_1_1_local_L1[i * 3136 + eeb0 * 16 * 3136], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 13) + with tik_instance.for_range(0, 16) as i: + tik_instance.data_move(input_1_2_fractal_L1_local_UB[i * 196 * 16], + input_1_1_fractal_L1_local_UB[i * 3328], 0, 1, 196, 0, 0) + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[eeb1, i + 196 * eeb0, 0, 0], input_1_2_fractal_L1_local_UB[256 * i], 0, 1, + 16, 0, 0) + + if input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (1, 1)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (196 * 256 * 2,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + with tik_instance.for_range(0, 2) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, eeb0 * 8, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 4) as eeb1: + with tik_instance.for_range(0, 2) as eeb2: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb2 * 196 * 256], + input_1_1_local_L1[(eeb2 + eeb1 * 2) * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 196) + with tik_instance.for_range(0, 2) as eeb2: + with tik_instance.for_range(0, 196) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1 * 2 + eeb2, i + block_index * 196, 0, 0], + input_1_1_fractal_L1_local_UB[256 * i + eeb2 * 196 * 256], 0, 1, 16, + 0, 0) + + if input_shape == ((32, 16, 56, 56, 16), 'float16', (1, 1), (2, 2)): + if padding == 'SAME': + padding_left = 0 + padding_right = 0 + padding_top = 0 + padding_bottom = 0 + pad = [padding_left, padding_right, padding_top, padding_bottom] + l1_h = 56 + l1_w = 56 + c1_index = 0 + jump_stride = 1 + repeat_mode = 1 + with tik_instance.for_range(0, 32, block_num=32) as block_index: + input_1_1_local_L1 = tik_instance.Tensor("float16", (25088 * 32 // 2,), scope=tik.scope_cbuf, + name="input_1_1_local_L1") + input_1_1_fractal_L1_local_UB = tik_instance.Tensor("float16", (49 * 256 * 8,), scope=tik.scope_ubuf, + name="input_1_1_fractal_L1_local_UB") + with tik_instance.for_range(0, 2) as eeb0: + tik_instance.data_move(input_1_1_local_L1, input_x[block_index, eeb0 * 8, 0, 0, 0], 0, 1, 25088, 0, 0) + with tik_instance.for_range(0, 8) as eeb1: + fetch_filter_w = 0 + fetch_filter_h = 0 + left_top_h = 0 + left_top_w = 0 + tik_instance.load3dv1(input_1_1_fractal_L1_local_UB[eeb1 * 49 * 256], + input_1_1_local_L1[eeb1 * 56 * 56 * 16], + pad, + l1_h, + l1_w, + c1_index, + fetch_filter_w, + fetch_filter_h, + left_top_w, + left_top_h, + stride_w, + stride_h, + filter_w, + filter_h, + dilation_filter_w, + dilation_filter_h, + jump_stride, + repeat_mode, + 49) + with tik_instance.for_range(0, 8) as eeb1: + with tik_instance.for_range(0, 49) as i: + tik_instance.data_move(res[eeb0 * 8 + eeb1, i + block_index * 49, 0, 0], + input_1_1_fractal_L1_local_UB[256 * i + eeb1 * 49 * 256], 0, 1, 16, 0, 0) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py new file mode 100644 index 0000000000..e5c380369d --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py @@ -0,0 +1,468 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.lang.cce +import te.platform.cce_params as cce +from te import tik +from te import tvm +from topi import generic +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubedenseleft.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeDenseLeft") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + print(km_shape, kn_shape) + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias) != 0: + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + shb = [] + if bias_length % 16 == 0: + shb = shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + shb = shape_bias + return shb + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +# @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) +@op_info_register(matmul_cube_dense_left_op_info) +def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + print("!!!!come into zzt~~~~~~~!!!!") + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + shape_output = output_y.get("ori_shape") + print("============") + print(input_x1.get("format"), input_x2.get("format")) + print(shape_a, shape_b) + print("============") + if input_x2.get("format") == "FRACTAL_Z": + n, c, h, w = shape_b + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_b = [n, c1 * h * w * c0] + shape_a = [n, n] + + if input_x1.get("format") == "FRACTAL_Z": + n, c, h, w = shape_a + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_a = [n, c1 * h * w * c0] + shape_b = [c1 * h * w * c0, c1 * h * w * c0] + + if input_x2.get("format") == "FRACTAL_NZ": + shape_a = [shape_b[0], shape_b[0]] + shape_b = shape_b + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = shape_a + shape_b = [shape_a[1], shape_a[1]] + + shape_a = list(shape_a) + shape_b = list(shape_b) + + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + dst_dtype = output_y.get("dtype").lower() + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "FRACTAL_NZ" + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "FRACTAL_NZ" + + print("=======================================") + print(shape_a_temp, shape_b_temp) + print(format_a, format_b) + print("=======================================") + tensor_bias = None + tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', + dtype=src_dtype) + tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', + dtype=src_dtype) + + if len(shape_bias) > 0: + tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', + dtype=dst_dtype) + + if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63: + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float16", shape_output, name="output", scope=tik.scope_gm) + with tik_instance.for_range(0, 32, block_num=32) as block_index: + resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_ubuf, + name="resMatmul_local_UB") + resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (128 * 256,), scope=tik.scope_cc, + name="resMatmul_local_UB") + input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_ca, + name="input_1_local_L1_local_L0A") + input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cbuf, + name="input_2_local_L1") + input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cbuf, + name="input_1_local_L1") + input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 256,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B") + core_m_idx = block_index % 8 + core_n_idx = block_index // 8 + with tik_instance.if_scope(core_m_idx != 7): + tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 8, 128, + 55 * 16, 0) + tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, + 32, 128, 55 * 16, 0) + with tik_instance.for_range(0, 8) as cc12: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc12 * 2048], input_1_local_L1[cc12 * 256], 0, 8, + 8, 0, False) + with tik_instance.for_range(0, 2) as cc6: + with tik_instance.for_range(0, 8) as cc121: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc121 * 4096], + input_2_local_L1[cc6 * 32768 + cc121 * 256], 0, 16, 8, 0, True) + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 128, 128, 256, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0, 1) + tik_instance.data_move(resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], + resMatmul_local_UB, 0, 16, 256 // 2, 0, 55 * 16 * 2 // 2) + with tik_instance.else_scope(): + tik_instance.data_move(input_1_local_L1, input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112, + 56 * 16, 0) + tik_instance.data_move(input_2_local_L1, input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], 0, + 32, 112, 56 * 16, 0) + with tik_instance.for_range(0, 7) as cc10: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc10 * 1792], input_1_local_L1[cc10 * 256], 0, 7, + 7, 0, False) + with tik_instance.for_range(0, 2) as cc5: + with tik_instance.for_range(0, 7) as cc101: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc101 * 4096], + input_2_local_L1[cc5 * 28672 + cc101 * 256], 0, 16, 7, 0, True) + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 112, 112, 256, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 112, 0, 0, 1) + tik_instance.data_move(resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 + core_n_idx * 512 * 1008], + resMatmul_local_UB, 0, 16, 224 // 2, 0, 56 * 16 * 2 // 2) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resMatmul]) + return tik_instance + else: + print("come into tbe, shape is error!") + result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, + format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) + + with tvm.target.cce(): + schedule = generic.auto_schedule(result) + + tensor_list = [tensor_a, tensor_b, result] + if len(shape_bias) > 0: + tensor_list = [tensor_a, tensor_b, tensor_bias, result] + + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(schedule, config) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py new file mode 100644 index 0000000000..4a1982738d --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +matmul_cube_dense_right_op_info = TBERegOp("CusMatMulCubeDenseRight") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubedenseright.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeDenseRight") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "required", "all") \ + .input(3, "x4", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, + DataType.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(matmul_cube_dense_right_op_info) +def CusMatMulCubeDenseRight(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeDenseRight""" + shape_a_temp = (128, 63, 16, 16) + shape_b_temp = (128, 128, 16, 16) + shape_output = output_y.get("shape") + matrix_max_shape = (1,) + support_shape = [(shape_a_temp, shape_b_temp, matrix_max_shape),] + shape_a_input = input_x1.get("shape") + shape_b_input = input_x2.get("shape") + matrix_max_input = input_x3.get("shape") + input_shape = (tuple(shape_a_input), tuple(shape_b_input), tuple(matrix_max_input)) + if input_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + if shape_a_temp[0] == 128 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 128: + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + input_x1 = tik_instance.Tensor("float16", shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", shape_b_temp, name="right_matrix", scope=tik.scope_gm) + input_x3 = tik_instance.Tensor("float32", [1,], name="matrix_max", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float32", shape_output, name="output", scope=tik.scope_gm) + with tik_instance.for_range(0, 32, block_num=32) as block_index: + core_m_idx = block_index // 16 + core_n_idx = block_index % 16 + matrix_max_scalar = tik_instance.Scalar("float32") + matrix_max_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="matrix_max_local_UB") + tik_instance.data_move(matrix_max_local_UB, input_x3, 0, 1, 1, 0, 0) + matrix_max_scalar.set_as(matrix_max_local_UB[0]) + + resMatmul_local_UB = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_ubuf, + name="resMatmul_local_UB") + resMatmul_local_UB1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_ubuf, + name="resMatmul_local_UB1") + + resMatmul_local_UB_local_L0C = tik_instance.Tensor("float32", (256 * 128,), scope=tik.scope_cc, + name="resMatmul_local_UB_local_L0C") + resMatmul_local_UB_local_L0C1 = tik_instance.Tensor("float32", (240 * 128,), scope=tik.scope_cc, + name="resMatmul_local_UB_local_L0C1") + + input_1_local_L1_local_L0A = tik_instance.Tensor("float16", (256 * 128,), scope=tik.scope_ca, + name="input_1_local_L1_local_L0A") + input_2_local_L1 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, + name="input_2_local_L1") + input_2_local_L11 = tik_instance.Tensor("float16", (8 * 128 * 16,), scope=tik.scope_cbuf, + name="input_2_local_L11") + + input_1_local_L1 = tik_instance.Tensor("float16", (8 * 256 * 16,), scope=tik.scope_cbuf, + name="input_1_local_L1") + input_1_local_L11 = tik_instance.Tensor("float16", (8 * 240 * 16,), scope=tik.scope_cbuf, + name="input_1_local_L11") + + input_2_local_L1_local_L0B = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B") + input_2_local_L1_local_L0B1 = tik_instance.Tensor("float16", (128 * 128,), scope=tik.scope_cb, + name="input_2_local_L1_local_L0B1") + + with tik_instance.if_scope(core_m_idx == 0): + with tik_instance.for_range(0, 2) as cc1: + tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, + 128, 1920, 0) + tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + cc1 * 4096], 0, 8, 256, 752, + 0) + with tik_instance.for_range(0, 8) as cc10: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, + 8, 8, 0, True) + with tik_instance.for_range(0, 16) as cc101: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], + 0, 8, 16, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B, 256, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) + tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], + matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], + matrix_max_scalar, 2, 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + cc1 * 4096], resMatmul_local_UB, 0, 8, 512, + 0, 1504) + with tik_instance.else_scope(): + tik_instance.data_move(input_2_local_L1, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, + 1920, 0) + tik_instance.data_move(input_1_local_L1, input_x1[core_n_idx * 129024 + 2 * 4096], 0, 8, 256, 752, 0) + with tik_instance.for_range(0, 8) as cc10: + tik_instance.load2dv1(input_2_local_L1_local_L0B[cc10 * 2048], input_2_local_L1[cc10 * 256], 0, 8, + 8, 0, True) + with tik_instance.for_range(0, 16) as cc101: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc101 * 2048], input_1_local_L1[cc101 * 256], 0, 8, + 16, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C, input_1_local_L1_local_L0A, input_2_local_L1_local_L0B, + 256, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB, resMatmul_local_UB_local_L0C, 0, 1, 128, 0, 0) + tik_instance.vmuls(64, resMatmul_local_UB, resMatmul_local_UB, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[255 * 64], resMatmul_local_UB[255 * 64], matrix_max_scalar, + 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB[510 * 64], resMatmul_local_UB[510 * 64], matrix_max_scalar, 2, + 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + 2 * 4096], resMatmul_local_UB, 0, 8, 512, 0, + 1504) + + tik_instance.data_move(input_2_local_L11, input_x2[core_n_idx * 262144 + core_n_idx * 2048], 0, 8, 128, + 1920, 0) + tik_instance.data_move(input_1_local_L11, input_x1[core_n_idx * 129024 + 12288], 0, 8, 240, 768, 0) + + with tik_instance.for_range(0, 8) as cc102: + tik_instance.load2dv1(input_2_local_L1_local_L0B1[cc102 * 2048], input_2_local_L11[cc102 * 256], 0, + 8, 8, 0, True) + with tik_instance.for_range(0, 16) as cc103: + tik_instance.load2dv1(input_1_local_L1_local_L0A[cc103 * 2048], input_1_local_L11[cc103 * 256], 0, + 8, 15, 0, False) + + tik_instance.mmad(resMatmul_local_UB_local_L0C1, input_1_local_L1_local_L0A, + input_2_local_L1_local_L0B1, 240, 128, 128, 0) + tik_instance.data_move(resMatmul_local_UB1, resMatmul_local_UB_local_L0C1, 0, 1, 120, 0, 0) + + tik_instance.vmuls(64, resMatmul_local_UB1, resMatmul_local_UB1, matrix_max_scalar, 255, 1, 1, 8, 8) + tik_instance.vmuls(64, resMatmul_local_UB1[255 * 64], resMatmul_local_UB1[255 * 64], matrix_max_scalar, + 225, 1, 1, 8, 8) + + tik_instance.data_move(resMatmul[core_n_idx * 129024 + 12288], resMatmul_local_UB1, 0, 8, 480, 0, 1536) + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py new file mode 100644 index 0000000000..11b668445e --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py @@ -0,0 +1,526 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.platform.cce_params as cce +from te import tik +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_fracz_left_cast_op_info = TBERegOp("CusMatMulCubeFraczLeftCast") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubefraczleftcast.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeFraczLeftCast") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F32_FracZ, DataType.F16_Default, DataType.F16_FracZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND +src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + print(km_shape, kn_shape) + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias): + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + if bias_length % 16 == 0: + return shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + return shape_bias + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@op_info_register(matmul_cube_fracz_left_cast_op_info) +def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="CusMatMulCubeFraczLeftCast"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + print("============") + print(input_x1.get("format"), input_x2.get("format")) + print(shape_a, shape_b) + print("============") + if input_x2.get("format") == "FRACTAL_Z": + n, c, h, w = shape_b + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_b = [n, c1 * h * w * c0] + shape_a = [n, n] + + if input_x1.get("format") == "FRACTAL_Z": + n, c, h, w = shape_a + c0 = 16 + c1 = c // c0 + if c1 == 0: + c1 = 1 + shape_a = [n, c1 * h * w * c0] + shape_b = [c1 * h * w * c0, c1 * h * w * c0] + + if input_x2.get("format") == "FRACTAL_NZ": + shape_a = [shape_b[0], shape_b[0]] + shape_b = shape_b + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = shape_a + shape_b = [shape_a[1], shape_a[1]] + + shape_a = list(shape_a) + shape_b = list(shape_b) + + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + input_x1 = tik_instance.Tensor(input_x1.get("dtype"), shape_a_temp, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor(input_x2.get("dtype"), shape_b_temp, name="right_matrix", scope=tik.scope_gm) + res_matmul = tik_instance.Tensor(output_y.get("dtype"), output_y.get("shape"), name="output", scope=tik.scope_gm) + DIAG_SIZE = 128 + mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(input_x1, input_x2, DIAG_SIZE) + cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, res_matmul, + mo_tile=mo_tile, ko_tile=ko_tile, no_tile=no_tile, + diag_opt=diag_opt, diag_size=DIAG_SIZE) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[res_matmul]) + return tik_instance + + +def get_cus_tile_info(input_x1, input_x2, diag_size): + """get_cus_tile_info""" + tile_map = { + ((32, 32, 16, 16), (128, 32, 16, 16)): (8, 8, 16), + ((8, 8, 16, 16), (72, 8, 16, 16)): (8, 8, 4), + ((32, 32, 16, 16), (288, 32, 16, 16)): (8, 8, 12), + ((128, 128, 16, 16), (32, 128, 16, 16)): (8, 8, 16), + ((16, 16, 16, 16), (144, 16, 16, 16)): (8, 8, 9), + ((64, 64, 16, 16), (16, 64, 16, 16)): (8, 8, 4), + ((16, 16, 16, 16), (64, 16, 16, 16)): (8, 8, 4), + ((32, 32, 16, 16), (8, 32, 16, 16)): (8, 8, 1), + ((128, 128, 16, 16), (64, 128, 16, 16)): (8, 8, 16), + ((16, 16, 16, 16), (4, 16, 16, 16)): (8, 8, 1), + ((16, 16, 16, 16), (32, 16, 16, 16)): (8, 8, 2), + ((64, 64, 16, 16), (32, 64, 16, 16)): (8, 8, 8), + ((32, 32, 16, 16), (64, 32, 16, 16)): (8, 8, 8), + ((32, 32, 16, 16), (16, 32, 16, 16)): (8, 8, 2), + ((8, 8, 16, 16), (32, 8, 16, 16)): (8, 8, 1), + ((8, 8, 16, 16), (16, 8, 16, 16)): (4, 8, 1), + ((4, 4, 16, 16), (16, 4, 16, 16)): (2, 4, 1), + ((4, 4, 16, 16), (4, 4, 16, 16)): (1, 4, 1), + ((4, 4, 16, 16), (36, 4, 16, 16)): (2, 4, 3), + ((4, 4, 16, 16), (49, 4, 16, 16)): (1, 4, 7) + } + shape_info = (tuple(input_x1.shape), tuple(input_x2.shape)) + diag_opt = False + if input_x1.shape[0] * input_x1.shape[3] > diag_size: + diag_opt = True + if shape_info not in tile_map: + raise ValueError("shape %s is not supported" % str(shape_info)) + mo_tile, ko_tile, no_tile = tile_map[shape_info] + return mo_tile, ko_tile, no_tile, diag_opt + + +def cus_cube_matmul_cast(tik_instance, input_x1, trans_a, input_x2, trans_b, + res, mo_tile, ko_tile, no_tile, diag_opt=False, diag_size=128): + """cus_cube_matmul_cast""" + ko, mo, _, _ = input_x1.shape + no, ko, _, _ = input_x2.shape + c0 = input_x1.shape[-1] + diag_outer = diag_size // c0 + maxblocknum = 32 + fp32_size = 4 + fp16_size = 2 + blocksize = 32 + vectorfp32_size = 64 + if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: + raise ValueError("shape of input_x1 or input_x2 is not supported!") + if not trans_a or not trans_b: + raise ValueError("only trans_a=False and trans_b=False be supported!") + + core_m_num = mo // mo_tile + loop_n_num = no // no_tile + if loop_n_num * core_m_num <= maxblocknum: + core_n_num = loop_n_num + else: + core_n_num = maxblocknum // core_m_num + if core_n_num > 0 and loop_n_num % core_n_num == 0: + loop_n_num = loop_n_num // core_n_num + else: + raise ValueError("Does not support this scenario!") + block_num = core_m_num * core_n_num + + loop_k_num = ko // ko_tile + if diag_opt: + loop_k_num = diag_outer // ko_tile + # double buffer: + thread_num_k = 2 + loop_k_num *= thread_num_k + ko_tile_inner = ko_tile // thread_num_k + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + core_m = block_idx // core_n_num + core_n = block_idx % core_n_num + with tik_instance.for_range(0, loop_n_num) as cc_n: + res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_L0C", scope=tik.scope_cc) + with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: + # input_x2 -> input_x2_ub -(fp322fp16)-> input_x2_cast_ub -> input_x2_L1 + input_x2_ub = tik_instance.Tensor("float32", [no_tile, ko_tile_inner, c0, c0], name="input_x2_ub", + scope=tik.scope_ubuf) + if diag_opt: + k_idx = core_m * mo_tile + thread_idx_k * ko_tile_inner + else: + k_idx = thread_idx_k * ko_tile_inner + tik_instance.data_move(input_x2_ub, + input_x2[(core_n * loop_n_num + cc_n) * no_tile, + k_idx, 0, 0], + 0, no_tile, ko_tile_inner * c0 * c0 * fp32_size // blocksize, + (ko - ko_tile_inner) * c0 * c0 * fp32_size // blocksize, 0) + input_x2_cast_ub = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_cast_ub", scope=tik.scope_ubuf) + repeate_num = no_tile * ko_tile_inner * c0 * c0 // vectorfp32_size + repeate_times_max = 255 + count = 0 + while repeate_num > repeate_times_max: + tik_instance.vconv(vectorfp32_size, 'none', + input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], + input_x2_ub[count * repeate_times_max * vectorfp32_size], + repeate_times_max, + 1, 1, 4, 8) + repeate_num -= repeate_times_max + count += 1 + tik_instance.vconv(vectorfp32_size, 'none', + input_x2_cast_ub[count * repeate_times_max * vectorfp32_size], + input_x2_ub[count * repeate_times_max * vectorfp32_size], repeate_num, + 1, 1, 4, 8) + input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x2_L1, input_x2_cast_ub, 0, 1, + no_tile * ko_tile_inner * c0 * c0 * fp16_size // blocksize, 0, 0) + # input_x1 -> input_x1_L1 + input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], + name="input_x1_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x1_L1, + input_x1[k_idx, + core_m * mo_tile, 0, 0], + 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) + # input_x2_L1 -> input_x2_L0B + input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], + name="input_x2_L0B", scope=tik.scope_cb) + with tik_instance.for_range(0, ko_tile_inner) as cc2: + tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, + ko_tile_inner, + 0, True) + # input_x1_L1 -> input_x1_L0A + input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], + name="input_x1_L0A", scope=tik.scope_ca) + with tik_instance.for_range(0, mo_tile) as cc1: + tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, + mo_tile, 0, False) + with tik_instance.if_scope(thread_idx_k == 0): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 0) + with tik_instance.else_scope(): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 1) + res_ub = tik_instance.Tensor(input_x1.dtype, [no_tile, mo_tile, c0, c0], + name="resMatmul_ub", scope=tik.scope_ubuf) + tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0, 1) + tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, core_m * mo_tile, 0, 0], + res_ub, 0, no_tile, + mo_tile * c0 * c0 * fp16_size // blocksize, 0, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py new file mode 100644 index 0000000000..79fab2c3cd --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +cus_matmul_cube_fracz_right_mul_op_info = TBERegOp("CusMatMulCubeFraczRightMul") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcubefraczrightmul.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCubeFraczRightMul") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "required", "all") \ + .input(3, "x4", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F32_Default, DataType.F16_Default, + DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(cus_matmul_cube_fracz_right_mul_op_info) +def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeFraczRightMul""" + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x1_shape = input_x1.get("shape") + input_x1_dtype = input_x1.get("dtype").lower() + input_x2_shape = input_x2.get("shape") + input_x2_dtype = input_x2.get("dtype").lower() + input_x3_shape = input_x3.get("shape") + input_x3_dtype = input_x3.get("dtype").lower() + output_shape = output_y.get("shape") + Supported = [((72, 8, 16, 16), "float16", (72, 72, 16, 16), "float16", (1,), "float32"), + ((32, 8, 16, 16), "float16", (32, 32, 16, 16), "float16", (1,), "float32"), + ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"), + ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"), + ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'), + ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'), + ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'), + ((64, 16, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((32, 64, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((32, 16, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((16, 32, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((16, 8, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((16, 4, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32'), + ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'), + ((144, 16, 16, 16), 'float16', (144, 144, 16, 16), 'float16', (1,), 'float32'), + ((128, 32, 16, 16), 'float16', (128, 128, 16, 16), 'float16', (1,), 'float32'), + ((64, 128, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((32, 128, 16, 16), 'float16', (32, 32, 16, 16), 'float16', (1,), 'float32'), + ((64, 32, 16, 16), 'float16', (64, 64, 16, 16), 'float16', (1,), 'float32'), + ((16, 64, 16, 16), 'float16', (16, 16, 16, 16), 'float16', (1,), 'float32')] + input_shape = ( + tuple(input_x1_shape), input_x1_dtype, tuple(input_x2_shape), input_x2_dtype, tuple(input_x3_shape), input_x3_dtype) + if input_shape not in Supported: + raise RuntimeError("input_shape %s is not supported" % str(input_shape)) + + input_x1 = tik_instance.Tensor("float16", input_x1_shape, name="left_matrix", scope=tik.scope_gm) + input_x2 = tik_instance.Tensor("float16", input_x2_shape, name="right_matrix", scope=tik.scope_gm) + input_x3 = tik_instance.Tensor("float32", input_x3_shape, name="matrix_max", scope=tik.scope_gm) + resMatmul = tik_instance.Tensor("float32", output_shape, name="output", scope=tik.scope_gm) + cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, resMatmul) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2, input_x3], outputs=[resMatmul]) + return tik_instance + + +def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3, + res): + """cus_cube_matmul_right_mul""" + diag_size = 128 + ko, mo, _, _ = input_x1.shape + no, ko, _, _ = input_x2.shape + c0 = input_x1.shape[-1] + diag_outer = diag_size // c0 + if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]: + raise ValueError("shape of input_x1 or input_x2 is not supported!") + + def get_cus_tile_info(input_x1, input_x2, input_x3): + """get_cus_tile_info""" + input_shape = (tuple(input_x1.shape), input_x1.dtype, tuple(input_x2.shape), input_x2.dtype, + tuple(input_x3.shape), input_x3.dtype) + tile_map = { + # no diag opt: + ((8, 32, 16, 16), "float16", (8, 8, 16, 16), "float16", (1,), "float32"): (4, 8, 2, 8, 4), + ((4, 4, 16, 16), "float16", (4, 4, 16, 16), "float16", (1,), "float32"): (1, 4, 1, 4, 4), + ((4, 16, 16, 16), 'float16', (4, 4, 16, 16), 'float16', (1,), 'float32'): (1, 4, 2, 16, 2), + ((49, 4, 16, 16), 'float16', (49, 49, 16, 16), 'float16', (1,), 'float32'): (1, 7, 7, 4, 7), + ((36, 4, 16, 16), 'float16', (36, 36, 16, 16), 'float16', (1,), 'float32'): (2, 6, 3, 2, 12), + # diag opt: + ((288, 32, 16, 16), 'float16', (288, 288, 16, 16), 'float16', (1,), 'float32'): (16, 8, 8, 2, 12), + } + maxblocknum = 32 + diag_opt = False + if input_x2.shape[0] * input_x2.shape[3] > diag_size and input_x2.shape[0] % diag_outer == 0: + diag_opt = True + if input_shape in tile_map: + mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_ = tile_map[input_shape] + elif diag_opt: + ko_tile_ = diag_outer + no_tile_ = ko_tile_ + core_n_num_ = no // no_tile_ + core_m_num_max = maxblocknum // core_n_num_ + mo_tile_ = -1 + core_m_num_ = -1 + for i in range(core_m_num_max, 0, -1): + if mo % i == 0: + core_m_num_ = i + mo_tile_ = mo // i + break + if mo_tile_ == -1: + raise ValueError("no valid tile be found!") + while mo_tile_ > 16: + mo_tile_ = mo_tile_ // 2 + else: + raise ValueError("please add tile config to the tile_map") + print("shape: %s, tile: %s" % (input_shape, str((mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, + diag_opt)))) + return mo_tile_, ko_tile_, no_tile_, core_m_num_, core_n_num_, diag_opt + + mo_tile, ko_tile, no_tile, core_m_num, core_n_num, diag_opt = get_cus_tile_info(input_x1, input_x2, input_x3) + fp32_size = 4 + fp16_size = 2 + blocksize = 32 + vectorfp32_size = 64 + loop_n_num_total = no // no_tile + loop_m_num_total = mo // mo_tile + if loop_n_num_total % core_n_num != 0 or loop_m_num_total % core_m_num != 0: + raise ValueError("Does not support this scenario!") + loop_n_num = loop_n_num_total // core_n_num + loop_m_num = loop_m_num_total // core_m_num + block_num = core_n_num * core_m_num + loop_k_num = ko // ko_tile + if diag_opt: + loop_k_num = diag_outer // ko_tile + # double buffer: + thread_num_k = 2 + if ko_tile % 2 == 0: + loop_k_num *= thread_num_k + ko_tile_inner = ko_tile // thread_num_k + else: + ko_tile_inner = ko_tile + ko_tile *= thread_num_k + with tik_instance.for_range(0, block_num, block_num=block_num) as block_idx: + core_m = block_idx // core_n_num + core_n = block_idx % core_n_num + with tik_instance.for_range(0, loop_m_num) as cc_m: + with tik_instance.for_range(0, loop_n_num) as cc_n: + res_L0C = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_L0C", scope=tik.scope_cc) + with tik_instance.for_range(0, loop_k_num, thread_num=thread_num_k) as thread_idx_k: + if diag_opt: + k_idx = (core_n * loop_n_num + cc_n) * no_tile + thread_idx_k * ko_tile_inner + else: + k_idx = thread_idx_k * ko_tile_inner + # input_x1 -> input_x1_L1 + input_x1_L1 = tik_instance.Tensor(input_x1.dtype, [ko_tile_inner, mo_tile, c0, c0], + name="input_x1_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x1_L1, + input_x1[k_idx, + (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], + 0, ko_tile_inner, mo_tile * c0 * c0 * fp16_size // blocksize, + (mo - mo_tile) * c0 * c0 * fp16_size // blocksize, 0) + # input_x2 -> input_x2_L1 + input_x2_L1 = tik_instance.Tensor("float16", [no_tile, ko_tile_inner, c0, c0], + name="input_x2_L1", scope=tik.scope_cbuf) + tik_instance.data_move(input_x2_L1, + input_x2[(core_n * loop_n_num + cc_n) * no_tile, + k_idx, 0, 0], + 0, no_tile, ko_tile_inner * c0 * c0 * fp16_size // blocksize, + (ko - ko_tile_inner) * c0 * c0 * fp16_size // blocksize, 0) + # input_x1_L1 -> input_x1_L0A + input_x1_L0A = tik_instance.Tensor(input_x1.dtype, [mo_tile, ko_tile_inner, c0, c0], + name="input_x1_L0A", scope=tik.scope_ca) + with tik_instance.for_range(0, mo_tile) as cc1: + tik_instance.load2dv1(input_x1_L0A[cc1, 0, 0, 0], input_x1_L1[0, cc1, 0, 0], 0, ko_tile_inner, + mo_tile, 0, False) + # input_x2_L1 -> input_x2_L0B + input_x2_L0B = tik_instance.Tensor("float16", [ko_tile_inner, no_tile, c0, c0], + name="input_x2_L0B", scope=tik.scope_cb) + with tik_instance.for_range(0, ko_tile_inner) as cc2: + tik_instance.load2dv1(input_x2_L0B[cc2, 0, 0, 0], input_x2_L1[0, cc2, 0, 0], 0, no_tile, + ko_tile_inner, + 0, True) + with tik_instance.if_scope(thread_idx_k == 0): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 0) + with tik_instance.else_scope(): + tik_instance.mmad(res_L0C, input_x1_L0A, input_x2_L0B, mo_tile * c0, + ko_tile_inner * c0, no_tile * c0, 1) + res_ub = tik_instance.Tensor("float32", [no_tile, mo_tile, c0, c0], + name="resMatmul_ub", scope=tik.scope_ubuf) + tik_instance.data_move(res_ub, res_L0C, 0, 1, no_tile * mo_tile, 0, 0) + + input_3_local_UB = tik_instance.Tensor("float32", (8,), scope=tik.scope_ubuf, name="input_3_local_UB") + tik_instance.data_move(input_3_local_UB, input_x3, 0, 1, 1, 0, 0) + matrix_max_scalar = tik_instance.Scalar("float32") + matrix_max_scalar.set_as(input_3_local_UB[0]) + repeate_num = no_tile * mo_tile * c0 * c0 // vectorfp32_size + repeate_times_max = 255 + count = 0 + while repeate_num > repeate_times_max: + tik_instance.vmuls(vectorfp32_size, + res_ub[count * repeate_times_max * vectorfp32_size], + res_ub[count * repeate_times_max * vectorfp32_size], + matrix_max_scalar, repeate_times_max, 1, 1, 8, 8) + repeate_num -= repeate_times_max + count += 1 + tik_instance.vmuls(vectorfp32_size, + res_ub[count * repeate_times_max * vectorfp32_size], + res_ub[count * repeate_times_max * vectorfp32_size], + matrix_max_scalar, repeate_num, 1, 1, 8, 8) + + tik_instance.data_move(res[(core_n * loop_n_num + cc_n) * no_tile, + (core_m * loop_m_num + cc_m) * mo_tile, 0, 0], + res_ub, 0, no_tile, + mo_tile * c0 * c0 * fp32_size // blocksize, 0, + (mo - mo_tile) * c0 * c0 * fp32_size // blocksize) diff --git a/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py new file mode 100644 index 0000000000..603ed287f6 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import +from impl.matmul_vector import matmul_vector_cce +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +import te.lang.cce +import te.platform.cce_params as cce +from te import tvm +from topi import generic +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + +matmul_cube_op_info = TBERegOp("CusMatMulCube") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmulcube.so") \ + .compute_cost(10) \ + .kernel_name("CusMatMulCube") \ + .partial_flag(True) \ + .attr("transpose_a", "required", "bool", "all") \ + .attr("transpose_b", "required", "bool", "all") \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .input(2, "x3", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_FracNZ) \ + .get_op_info() + + +# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals, +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + k_block_size = cce.BLOCK_REDUCE + + check_list = ("float16") + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions should use batch_matmul now!") + + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + + if m_shape == 1: + if n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + elif km_shape % k_block_size != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 should be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N should be 1 or multiple of %d" % cce.BLOCK_IN) + + if len(shape_bias): + if len(shape_bias) == 1: + if is_gevm or is_gemv: + if shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + else: + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == shape_len: + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("unsupport input shape now for batch bias case") + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + if bias_length % 16 == 0: + return shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + return shape_bias + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res + + +def check_supported(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """check_supported""" + shape_a = input_x1.get("shape") + shape_b = input_x2.get("shape") + print("shape_a: ", shape_a) + print("shape_b: ", shape_b) + src_dtype = input_x1.get("dtype") + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + try: + trans_a_f = bool(1 - trans_a) + if src_dtype == "float32" or src_dtype == "int32": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + if trans_b: + if shape_b[0] == 1: + return False + else: + if shape_b[1] == 1: + return False + if trans_a: + if trans_b: + if shape_a[0] != shape_b[1]: + return False + elif shape_a[0] != shape_b[0]: + return False + elif trans_b: + if shape_a[1] != shape_b[1]: + return False + elif shape_a[1] != shape_b[0]: + return False + + if trans_a_f and trans_b and shape_b[1] == 1: + return False + + if src_dtype == "float16": + if len(shape_a) != 2 and len(shape_b) != 2: + return False + + if trans_a: + m_shape = shape_a[1] + k_shape = shape_a[0] + else: + m_shape = shape_a[0] + k_shape = shape_a[1] + + if trans_b: + n_shape = shape_b[0] + k_b_shape = shape_b[1] + else: + n_shape = shape_b[1] + k_b_shape = shape_b[0] + + if k_shape != k_b_shape: + return False + + if m_shape == 1 or n_shape == 1: + if k_shape % 256 != 0: + return False + + except RuntimeError as e: + return False + + return True + + +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@op_info_register(matmul_cube_op_info) +def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """ + calculating matrix multiplication with bias, C = A*B + bias, support input + data with fractal format. + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + src_dtype: str + The data type of input, support "float32", "float16" + dst_dtype: str + The data type of output, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + is_fractal: bool + If True, the input data format of a and b must be fractal format + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + + Returns + ------- + None + """ + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + + if shape_a is not None: + if len(shape_a) < 2: + shape_a = input_x1.get("shape") + + if shape_b is not None: + if len(shape_b) < 2: + shape_b = input_x2.get("shape") + + shape_a = list(shape_a) + shape_b = list(shape_b) + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = _get_input_shape(shape_a) + shape_b = _get_input_shape(shape_b) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_a) + util.check_shape_rule(shape_b) + util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT) + util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT) + + if input_x1.get("format") == "FRACTAL_NZ": + shape_a = [shape_a[1], shape_a[0]] + trans_a = bool(1 - trans_a) + + if input_x2.get("format") == "FRACTAL_NZ": + shape_b = [shape_b[1], shape_b[0]] + trans_b = bool(1 - trans_b) + + shape_bias = () + if bias is not None and bool(bias): + shape_bias = bias.get("shape") + shape_bias = list(shape_bias) + shape_bias = _get_bias(shape_bias) + + src_dtype = input_x1.get("dtype").lower() + dst_dtype = output_y.get("dtype").lower() + if src_dtype == "float32" or src_dtype == "int32": + matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b, shape_bias, kernel_name) + return + _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b) + m_shape = shape_a[len(shape_a) - 2] + km_shape = shape_a[len(shape_a) - 1] + kn_shape = shape_b[len(shape_a) - 2] + n_shape = shape_b[len(shape_a) - 1] + + if src_dtype == "float16": + block_reduce = cce.BLOCK_REDUCE + + block_in = cce.BLOCK_IN + block_out = cce.BLOCK_OUT + + if trans_a and km_shape == 1: + block_in = cce.BLOCK_VECTOR + + if not trans_a and m_shape == 1: + block_in = cce.BLOCK_VECTOR + + if trans_b and kn_shape == 1: + block_out = cce.BLOCK_VECTOR + + if not trans_b and n_shape == 1: + block_out = cce.BLOCK_VECTOR + + if trans_a: + shape_a_temp = (m_shape // block_reduce, km_shape // block_in, block_reduce, block_in) + else: + shape_a_temp = (m_shape // block_in, km_shape // block_reduce, block_in, block_reduce) + + if trans_b: + shape_b_temp = (kn_shape // block_out, n_shape // block_reduce, block_reduce, block_out) + else: + shape_b_temp = (kn_shape // block_reduce, n_shape // block_out, block_out, block_reduce) + + if input_x1.get("format") == "FORMAT_FRACTAL_Z": + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "fractal" + elif input_x1.get("format") == "FRACTAL_NZ": + shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2], shape_a_temp[3]) + format_a = "FRACTAL_NZ" + else: + shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1]) + format_a = "ND" + + if input_x2.get("format") == "FORMAT_FRACTAL_Z": + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "fractal" + elif input_x2.get("format") == "FRACTAL_NZ": + shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3]) + format_b = "FRACTAL_NZ" + else: + shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1]) + format_b = "ND" + + tensor_bias = None + tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', + dtype=src_dtype) + tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', + dtype=src_dtype) + + if len(shape_bias) > 0: + tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias', + dtype=dst_dtype) + result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a, + format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias) + + with tvm.target.cce(): + schedule = generic.auto_schedule(result) + + tensor_list = [tensor_a, tensor_b, result] + if len(shape_bias) > 0: + tensor_list = [tensor_a, tensor_b, tensor_bias, result] + + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(schedule, config) diff --git a/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py b/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py new file mode 100644 index 0000000000..0a3f41386b --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusMatrixCombine""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_matrix_combine_op_info = TBERegOp("CusMatrixCombine") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matrixcombine.so") \ + .compute_cost(10) \ + .kernel_name("CusMatrixCombine") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(cus_matrix_combine_op_info) +def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): + """CusMatrixCombine""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + split_dim = 128 + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm) + + blocks = 32 + matrix_dim = input_x_shape[0] * input_x_shape[1] + if input_x_shape[0] == 1 and input_x_shape[1] == 64: + tiling_dim = 2 + bs = 1 + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[0, block_index * tiling_dim, 0], 0, 1, 16, 0, 0) + tik_instance.data_move(res[block_index * tiling_dim, 0], input_x_ub, 0, 1, 16, 0, 0) + else: + tiling_dim = 4 + bs = input_x_shape[0] + with tik_instance.for_range(0, blocks, block_num=blocks) as block_index: + input_x_ub = tik_instance.Tensor("float32", (tiling_dim, matrix_dim), name="input_x_ub", + scope=tik.scope_ubuf) + zero = tik_instance.Scalar("float32") + zero.set_as(0.0) + with tik_instance.for_range(0, bs) as i: + repeat_real = tiling_dim * matrix_dim // 64 + if repeat_real <= 255: + tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8) + else: + repeat_1 = 255 + repeat_2 = repeat_real - 255 + tik_instance.vector_dup(64, input_x_ub, zero, repeat_1, 1, 8) + tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_2, 1, 8) + with tik_instance.for_range(0, tiling_dim) as j: + tik_instance.data_move(input_x_ub[j, split_dim * i], input_x[i, block_index * tiling_dim + j, 0], 0, + 1, 16, 0, 0) + tik_instance.data_move(res[i * split_dim + block_index * tiling_dim, 0], input_x_ub, 0, 1, + tiling_dim * matrix_dim * 4 // 32, 0, 0) + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py new file mode 100644 index 0000000000..141e2c1d51 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py @@ -0,0 +1,289 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusTranspose02314""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +from te import tik +from topi.cce import util + +cus_transpose02314_op_info = TBERegOp("CusTranspose02314") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("transpose02314.so") \ + .compute_cost(10) \ + .kernel_name("CusTranspose02314") \ + .partial_flag(True) \ + .input(0, "x1", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(cus_transpose02314_op_info) +def CusTranspose02314(input_x, output, kernel_name="transpose021354"): + """CusTranspose02314""" + input_x_shape = input_x.get("shape") + output_shape = output.get("shape") + perm = (0, 2, 3, 1, 4) + input_x_shape = tuple(input_x_shape) + support_shape = [(32, 128, 7, 7, 16), + (32, 32, 7, 7, 16), + (32, 32, 14, 14, 16), + (32, 64, 14, 14, 16), + (32, 16, 14, 14, 16), + (32, 16, 28, 28, 16), + (32, 32, 28, 28, 16), + (32, 8, 28, 28, 16), + (32, 8, 56, 56, 16), + (32, 16, 56, 56, 16), + (32, 4, 56, 56, 16), + (32, 4, 112, 112, 16)] + if input_x_shape not in support_shape: + raise RuntimeError("input_shape %s is not supported" % str(input_x_shape)) + + if util.get_product_version() == util.VERSION_MINI: + tik_instance = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_instance = tik.Tik(tik.Dprofile("v100", "cloud")) + + input_x = tik_instance.Tensor("float16", input_x_shape, name="input_x", scope=tik.scope_gm) + res = tik_instance.Tensor("float16", output_shape, name="res", scope=tik.scope_gm) + + dtype = "float16" + if tuple(input_x_shape) == (32, 4, 112, 112, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 14) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + zero = tik_instance.Scalar(dtype="float16", init_value=0) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 802816 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, + 12096, 0) + with tik_instance.for_range(0, 448) as cc7: + with tik_instance.for_range(0, 4) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], + input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 4, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 3) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 200704 + cc1_db * 14336 + 7168 * db_idx], 0, 4, 448, + 2688, 0) + with tik_instance.for_range(0, 448) as cc7: + with tik_instance.for_range(0, 4) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 64 + cc8 * 16], + input_1_local_UB[7168 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + + input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) + T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 43008], 0, 4, 448, 2688, 0) + with tik_instance.for_range(0, 448) as cc72: + with tik_instance.for_range(0, 4) as cc82: + tik_instance.vadds(16, T_transpose_local_UB2[cc72 * 64 + cc82 * 16], + input_1_local_UB2[7168 * cc82 + cc72 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 16, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 14) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 802816 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, + 3024, 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 802816 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 8, 56, 56, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 7) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 401408 + cc1_db * 7168 + 3584 * db_idx], 0, 8, 224, 2912, + 0) + with tik_instance.for_range(0, 224) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], + input_1_local_UB[3584 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 8, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 2) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 100352 + cc1_db * 6272 + 3136 * db_idx], 0, 8, 196, 588, + 0) + with tik_instance.for_range(0, 196) as cc7: + with tik_instance.for_range(0, 8) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 128 + cc8 * 16], + input_1_local_UB[3136 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + cc1_db * 50176 + 25088 * db_idx], + T_transpose_local_UB, 0, 1, 1568, 0, 0) + elif tuple(input_x_shape) == (32, 32, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 7) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input_x[block_idx * 401408 + cc1_db * 1792 + 896 * db_idx], + 0, 32, 56, 728, 0) + with tik_instance.for_range(0, 56) as cc7: + with tik_instance.for_range(0, 32) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 512 + cc8 * 16], + input_1_local_UB[896 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 401408 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + elif tuple(input_x_shape) == (32, 16, 28, 28, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 3) as cc1_db: + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB", + scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, + input_x[block_idx * 200704 + cc1_db * 3584 + 1792 * db_idx], 0, 16, 112, 672, + 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + cc1_db * 57344 + 28672 * db_idx], + T_transpose_local_UB, 0, 1, 1792, 0, 0) + + input_1_local_UB2 = tik_instance.Tensor(dtype, [28672], name="input_1_local_UB2", scope=tik.scope_ubuf) + T_transpose_local_UB2 = tik_instance.Tensor(dtype, [28672], name="T_transpose_local_UB2", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB2, input_x[block_idx * 200704 + 10752], 0, 16, 112, 672, 0) + with tik_instance.for_range(0, 112) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB2[cc7 * 256 + cc8 * 16], + input_1_local_UB2[1792 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + 172032], T_transpose_local_UB2, 0, 1, 1792, 0, 0) + + elif tuple(input_x_shape) == (32, 16, 14, 14, 16): + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + zero = tik_instance.Scalar(dtype="float16", init_value=0) + with tik_instance.for_range(0, 2, thread_num=2) as db_idx: + input_1_local_UB = tik_instance.Tensor(dtype, [25088], name="input_1_local_UB", scope=tik.scope_ubuf) + T_transpose_local_UB = tik_instance.Tensor(dtype, [25088], name="T_transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_1_local_UB, input_x[block_idx * 50176 + 1568 * db_idx], 0, 16, 98, 98, 0) + with tik_instance.for_range(0, 98) as cc7: + with tik_instance.for_range(0, 16) as cc8: + tik_instance.vadds(16, T_transpose_local_UB[cc7 * 256 + cc8 * 16], + input_1_local_UB[1568 * cc8 + cc7 * 16], zero, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 50176 + 25088 * db_idx], T_transpose_local_UB, 0, 1, 1568, 0, 0) + elif tuple(input_x_shape) == (32, 128, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 7, thread_num=2) as cc1: + input_x_ub = tik_instance.Tensor(dtype, [1, 128, 1, 7, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 1, 7, 128, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, cc1, 0, 0], 0, 128, 7, 42, 0) + with tik_instance.for_range(0, 7) as cc7: + with tik_instance.for_range(0, 128) as cc8: + tik_instance.vadds(16, transpose_ub[0, 0, cc7, cc8, 0], input_x_ub[0, cc8, 0, cc7, 0], 0, + 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + 14336 * cc1], transpose_ub, 0, 1, 896, 0, 0) + + elif tuple(input_x_shape) == (32, 32, 7, 7, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + input_x_ub = tik_instance.Tensor(dtype, [1, 32, 7, 7, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 7, 7, 32, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, 0, 0, 0], 0, 1, 1568, 0, 0) + with tik_instance.for_range(0, 7) as cc1: + with tik_instance.for_range(0, 7) as cc2: + with tik_instance.for_range(0, 32) as cc3: + tik_instance.vadds(16, transpose_ub[0, cc1, cc2, cc3, 0], input_x_ub[0, cc3, cc1, cc2, 0], 0, + 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 25088], transpose_ub, 0, 1, 1568, 0, 0) + + elif tuple(input_x_shape) == (32, 32, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + def _inner_compute(split_index): + input_x_ub = tik_instance.Tensor(dtype, [1, 32, 2, 14, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 32, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 32, 28, 168, 0) + with tik_instance.for_range(0, 2) as cc2: + with tik_instance.for_range(0, 14) as cc3: + with tik_instance.for_range(0, 32) as cc4: + tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], + 0, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 100352 + split_index * 2 * 7168], transpose_ub, 0, 1, 896, 0, 0) + + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 6, thread_num=2) as cc1: + _inner_compute(cc1) + _inner_compute(6) + elif tuple(input_x_shape) == (32, 64, 14, 14, 16) and tuple(perm) == (0, 2, 3, 1, 4) and dtype == "float16": + def _inner_compute(split_index, block_idx): + input_x_ub = tik_instance.Tensor(dtype, [1, 64, 2, 14, 16], name="input_1_local_UB", + scope=tik.scope_ubuf) + transpose_ub = tik_instance.Tensor(dtype, [1, 2, 14, 64, 16], name="transpose_local_UB", + scope=tik.scope_ubuf) + tik_instance.data_move(input_x_ub, input_x[block_idx, 0, split_index * 2, 0, 0], 0, 64, 28, 168, 0) + with tik_instance.for_range(0, 2) as cc2: + with tik_instance.for_range(0, 14) as cc3: + with tik_instance.for_range(0, 64) as cc4: + tik_instance.vadds(16, transpose_ub[0, cc2, cc3, cc4, 0], input_x_ub[0, cc4, cc2, cc3, 0], + 0, 1, 1, 1, 0, 0) + tik_instance.data_move(res[block_idx * 200704 + split_index * 2 * 14336], transpose_ub, 0, 1, 1792, 0, 0) + + with tik_instance.for_range(0, 32, block_num=32) as block_idx: + with tik_instance.for_range(0, 6, thread_num=2) as cc1: + _inner_compute(cc1, block_idx) + _inner_compute(6, block_idx) + + tik_instance.BuildCCE(kernel_name, inputs=[input_x], outputs=[res]) + return tik_instance diff --git a/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py deleted file mode 100644 index e2afa96a7d..0000000000 --- a/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""batch_matmul_impl""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusBatchMatMul", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "batchmatmul.so", - "compute_cost": 10, - "kernel_name": "CusBatchMatMul", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): - """CusBatchMatMul""" - return diff --git a/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py b/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py deleted file mode 100644 index 5c38dfc25d..0000000000 --- a/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusCholeskyTrsm""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusCholeskyTrsm", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "choleskytrsm.so", - "compute_cost": 10, - "kernel_name": "CusCholeskyTrsm", - "partial_flag": true, - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusCholeskyTrsm(input_x, output, kernel_name): - """CusCholeskyTrsm""" - return diff --git a/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py b/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py deleted file mode 100644 index b9a0d45273..0000000000 --- a/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusFusedAbsMax1""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusFusedAbsMax1", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "fusedabsmax1.so", - "compute_cost": 10, - "kernel_name": "CusFusedAbsMax1", - "partial_flag": true, - "attr": [ - { - "name": "origin_shape", - "param_type": "required", - "type": "listInt", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): - """CusFusedAbsMax1""" - return diff --git a/mindspore/ops/_op_impl/custom_op/img2col_impl.py b/mindspore/ops/_op_impl/custom_op/img2col_impl.py deleted file mode 100644 index 5137d4d7e7..0000000000 --- a/mindspore/ops/_op_impl/custom_op/img2col_impl.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusImg2ColNC1HWC0""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusImg2ColNC1HWC0", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "img2colnc1hwc0.so", - "compute_cost": 10, - "kernel_name": "CusImg2ColNC1HWC0", - "partial_flag": true, - "attr": [ - { - "name": "ksizes", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "strides", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "dilates", - "param_type": "required", - "type": "listInt", - "value": "all" - }, - { - "name": "padding", - "param_type": "required", - "type": "str", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"): - """CusImg2ColNC1HWC0""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py deleted file mode 100644 index 300410eb4a..0000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeDenseLeft", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubedenseleft.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeDenseLeft", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="matmulcube"): - """CusMatMulCubeDenseLeft""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py deleted file mode 100644 index 3da1593dfd..0000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeFraczLeftCast", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubefraczleftcast.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeFraczLeftCast", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "FracZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="CusMatMulCubeFraczLeftCast"): - """CusMatMulCubeFraczLeftCast""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py deleted file mode 100644 index 7fc2ba35d1..0000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCubeFraczRightMul", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcubefraczrightmul.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCubeFraczRightMul", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 3, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x4", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "FracZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, - kernel_name="matmulcube"): - """CusMatMulCubeFraczRightMul""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py deleted file mode 100644 index 7c2d81e1d6..0000000000 --- a/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul -""" -from __future__ import absolute_import - -from mindspore.ops.op_info_register import op_info_register -from topi.cce import util - -# General limitation of the size for input shape: 2**31 -SHAPE_SIZE_LIMIT = 2147483648 -NoneType = type(None) - - -@op_info_register("""{ - "op_name": "CusMatMulCube", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matmulcube.so", - "compute_cost": 10, - "kernel_name": "CusMatMulCube", - "partial_flag": true, - "attr": [ - { - "name": "transpose_a", - "param_type": "required", - "type": "bool", - "value": "all" - }, - { - "name": "transpose_b", - "param_type": "required", - "type": "bool", - "value": "all" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "x3", - "need_compile": false, - "param_type": "optional", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements -@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) -def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): - """CusMatMulCube""" - return diff --git a/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py b/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py deleted file mode 100644 index 32045e7ccb..0000000000 --- a/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusMatrixCombine""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusMatrixCombine", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "matrixcombine.so", - "compute_cost": 10, - "kernel_name": "CusMatrixCombine", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): - """CusMatrixCombine""" - return diff --git a/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py deleted file mode 100644 index c5aebe523d..0000000000 --- a/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CusTranspose02314""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "CusTranspose02314", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "transpose02314.so", - "compute_cost": 10, - "kernel_name": "CusTranspose02314", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x1", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ] -}""") -def CusTranspose02314(input_x, output, kernel_name="transpose021354"): - """CusTranspose02314""" - return diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 3ca616f2dc..3096e90250 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -23,6 +23,7 @@ from mindspore._checkparam import Validator as validator # path of built-in op info register. BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" +BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op" def op_info_register(op_info): @@ -47,7 +48,10 @@ def op_info_register(op_info): op_lib = Oplib() file_path = os.path.realpath(inspect.getfile(func)) # keep the path custom ops implementation. - imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path + if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path: + imply_path = file_path + else: + imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path if not op_lib.reg_op(op_info_real, imply_path): raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real)) diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 40fdee5946..9d6ca945bf 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from . import _quant_ops from ._quant_ops import * +from .thor_ops import * __all__ = [ 'TensorAdd', diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py index 23593a2630..9d99ae22f0 100644 --- a/mindspore/ops/operations/thor_ops.py +++ b/mindspore/ops/operations/thor_ops.py @@ -13,19 +13,51 @@ # limitations under the License. # ============================================================================ """thor_ops""" -import mindspore as ms from mindspore.ops import prim_attr_register, PrimitiveWithInfer from mindspore.ops.composite import multitype_ops as C +import mindspore as ms + +__all__ = ["CusBatchMatMul", + "CusCholeskyTrsm", + "CusFusedAbsMax1", + "CusImg2Col", + "CusMatMulCubeDenseLeft", + "CusMatMulCubeFraczRightMul", + "CusMatMulCube", + "CusMatrixCombine", + "CusTranspose02314", + "CusMatMulCubeDenseRight", + "CusMatMulCubeFraczLeftCast", + ] + class CusBatchMatMul(PrimitiveWithInfer): - """CusMatMulCube definition""" + """ + Multiplies matrix `a` by matrix `b` in batch. + + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If + `transpose_b` is True. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, D, D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cus_batch_matmul = P.CusBatchMatMul() + >>> output = cus_batch_matmul(input_x, input_y) + """ @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusBatchMatMul""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -40,13 +72,30 @@ class CusBatchMatMul(PrimitiveWithInfer): class CusCholeskyTrsm(PrimitiveWithInfer): - """CusCholeskyTrsm definition""" + """ + L * LT = A. + LT * (LT)^-1 = I. + return (LT)^-1. + Only compute the res of the diag part of input matrix with dim 128. + The rank of input tensors must be `2`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32) + >>> cus_choleskytrsm = P.CusCholeskyTrsm() + >>> output = matmul(input_x) + """ @prim_attr_register def __init__(self): """init CusCholeskyTrsm""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm def infer_shape(self, data1_shape): ll = [] m, _ = data1_shape @@ -61,14 +110,28 @@ class CusCholeskyTrsm(PrimitiveWithInfer): class CusFusedAbsMax1(PrimitiveWithInfer): - """CusCholeskyTrsm definition""" + """ + Compute the abs max of Tensor input. + + The rank of input tensors must be `4` or `2`. + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N0, M0, N1, M1)` + or math:`(32, 64)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(32, 64)` or math:`(1, )`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32) + >>> cus_fused_abs_max1 = P.CusFusedAbsMax1() + >>> output = cus_fused_abs_max1(input_x) + """ @prim_attr_register def __init__(self, origin_shape=[-1, -1]): - """init CusCholeskyTrsm""" + """init CusFusedAbsMax1""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.origin_shape = origin_shape - + from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1 def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -88,7 +151,21 @@ class CusFusedAbsMax1(PrimitiveWithInfer): class CusImg2Col(PrimitiveWithInfer): - """CusImg2Col definition""" + """ + Img2col the feature map and the result in reorganized in NC1HWC0. + + Args: + - **strides** (listInt) - the stride of the ops. + - **ksizes** (listInt) - the kernel size of the ops. + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16) + >>> cusimg2col = P.CusImg2Col() + >>> output = cusimg2col(input_x) + """ @prim_attr_register def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"): @@ -98,7 +175,7 @@ class CusImg2Col(PrimitiveWithInfer): self.strides = strides self.dilates = dilates self.mode = mode - + from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -122,13 +199,30 @@ class CusImg2Col(PrimitiveWithInfer): class CusMatMulCubeDenseLeft(PrimitiveWithInfer): - """CusMatMulCube definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 must be `4`, the fractal format of the normal matrix. + The rank of input_x2 must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(N0, M0, N1, M1)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, C)`. + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft() + >>> output = matmulcubedenseleft(input_x, input_y) + """ @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusMatMulCubeDenseLeft""" self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -143,13 +237,32 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer): class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): - """CusMatMulCubeFraczRightMul definition""" + """ + Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`. + + The rank of input_x1 tensors must be `2`. + The rank of input_x2 tensors must be `4`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, C0, M0)`. + - **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor if :math`(1, )`. + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + Examples: + >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16) + >>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul() + >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3) + """ @prim_attr_register def __init__(self): """init CusMatMulCubeFraczRightMul""" self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul def get_bprop(self): def bprop(x1, x2, x3, out, dout): return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) @@ -164,7 +277,30 @@ class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): class CusMatMulCube(PrimitiveWithInfer): - """CusMatMulCube definition""" + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input tensors must be `2`. + + Args: + transpose_a (bool): If True, `a` is transposed before multiplication. Default: False. + transpose_b (bool): If True, `b` is transposed before multiplication. Default: False. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If + `transpose_a` is True, its shape should be :math:`(N, C)` after transposing. + - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If + `transpose_b` is True, its shape should be :math:`(C, M)` after transpose. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcube = P.CusMatMulCube() + >>> output = matmul(input_x, input_y) + """ @prim_attr_register def __init__(self, transpose_a=False, transpose_b=False): @@ -172,7 +308,7 @@ class CusMatMulCube(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) self.transpose_a = transpose_a self.transpose_b = transpose_b - + from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube def get_bprop(self): def bprop(x1, x2, out, dout): return (C.zeros_like(x1), C.zeros_like(x2)) @@ -199,13 +335,27 @@ class CusMatMulCube(PrimitiveWithInfer): class CusMatrixCombine(PrimitiveWithInfer): - """CusMatMulCube definition""" + """ + move the batch matrix to result matrix diag part. + The rank of input tensors must be `3`. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N * D, N * D)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32) + >>> cusmatrixcombine = P.CusMatrixCombine() + >>> output = cusmatrixcombine(input_x) + """ @prim_attr_register def __init__(self): - """init CusMatMulCube""" + """init CusMatrixCombine""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -223,13 +373,28 @@ class CusMatrixCombine(PrimitiveWithInfer): class CusTranspose02314(PrimitiveWithInfer): - """CusTranspose02314 definition""" + """ + Permute input tensor with perm (0, 2, 3, 1, 4) + + The rank of input tensors must be `5` with format NC1HWC0. + + Inputs: + - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16) + >>> custranspose02314 = P.CusTranspose02314() + >>> output = custranspose02314(input_x) + """ @prim_attr_register def __init__(self): """init CusTranspose02314""" self.init_prim_io_names(inputs=['x1'], outputs=['y']) - + from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 def get_bprop(self): def bprop(x, out, dout): return (C.zeros_like(x),) @@ -246,3 +411,83 @@ class CusTranspose02314(PrimitiveWithInfer): def infer_dtype(self, data1_dtype): return data1_dtype + + +class CusMatMulCubeDenseRight(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `2`. + The rank of input_x2 tensor must be `4`. + + Inputs: + - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. + - **input_y** (Tensor) - The second tensor to be multiplied. + The shape of the tensor is :math:`(C1, M1, M0, C0)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight() + >>> output = cusmatmulcubedenseright(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeDenseRight""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight + def get_bprop(self): + def bprop(x1, x2, x3, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape, data3_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float32")) + + +class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): + """ + Multiplies matrix `a` by matrix `b`. + + The rank of input_x1 tensor must be `4`. + The rank of input_x2 tensors must be `2`. + + Inputs: + - **input_x1** (Tensor) - The first tensor to be multiplied. + The shape of the tensor is :math:`(C1, N1, N0, C0)`. + - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N, M)`. + + Examples: + >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16) + >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16) + >>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast() + >>> output = cusmatmulcubefraczleftcast(input_x, input_y) + """ + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeFraczLeftCast""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast + def get_bprop(self): + def bprop(x1, x2, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape): + return data2_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float16"))