# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""funcs for gen_explicit_map"""
import ast
from decimal import Decimal
from functools import partial

# Expression node types that bind tighter than binary minus, so the source
# text can follow "1.0 - " without parentheses.
_ATOMIC_NODES = (ast.Name, ast.Attribute, ast.Call, ast.Subscript)


def _lookup(params_pt, args_pt, key):
    """Return ``args_pt[key]`` when present, otherwise ``params_pt[key]``."""
    return args_pt[key] if key in args_pt else params_pt[key]


def _is_zero_padding(padding):
    """
    Tell whether the source text of a padding argument denotes all-zero padding.

    Args:
        padding (str): Source text of the padding argument.

    Returns:
        bool, True for the literal ``0`` or a non-empty tuple/list literal of
        integer zeros (whatever its spacing), False for anything else.
    """
    try:
        value = ast.literal_eval(padding.strip())
    except (ValueError, SyntaxError, TypeError):
        # Not a plain literal (e.g. a variable name): cannot be proven zero.
        return False
    items = value if isinstance(value, (tuple, list)) else (value,)
    return bool(items) and all(
        isinstance(item, int) and not isinstance(item, bool) and item == 0
        for item in items
    )


def _pad_mode_literal(params_pt, args_pt):
    """Return ``"'valid'"`` for all-zero padding and ``"'same'"`` otherwise."""
    padding = _lookup(params_pt, args_pt, 'padding')
    return "'valid'" if _is_zero_padding(padding) else "'same'"


def _needs_parentheses(expr):
    """
    Tell whether an expression must be parenthesized after ``"1.0 - "``.

    Args:
        expr (str): Source text of the expression.

    Returns:
        bool, False for names, attribute accesses, calls and subscripts, and
        for text that does not parse as an expression (left untouched);
        True for every other expression.
    """
    try:
        node = ast.parse(expr, mode="eval").body
    except (SyntaxError, ValueError):
        return False
    return not isinstance(node, _ATOMIC_NODES)


def _one_minus(number):
    """
    Return ``str(1 - number)`` computed without binary rounding noise.

    Args:
        number (str): Text that ``float()`` accepts.

    Returns:
        str, the float-formatted difference, e.g. ``"0.1"`` for ``"0.9"``.
    """
    try:
        # Subtract in decimal so "0.9" yields 0.1 instead of
        # 0.09999999999999998, then format as a float as before.
        delta = float(Decimal(1) - Decimal(number))
    except (ArithmeticError, ValueError):
        # Text float() accepts but decimal rejects or cannot handle.
        delta = 1 - float(number)
    return str(delta)


def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
    """
    Generate explicit_map for F.MaxPool2d.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames. Key ``padding`` holds ``"'valid'"`` when the
        padding text is all zeros and ``"'same'"`` otherwise.
    """
    return {"padding": _pad_mode_literal(params_pt, args_pt)}


def gen_explicit_map_nn_sequential(_, args_pt):
    """
    Generate explicit_map for nn.Sequential.

    Args:
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames. Key ``*args`` holds the original ``*args``
        text wrapped in square brackets.
    """
    return {"*args": "[{}]".format(args_pt['*args'])}


def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
    """
    Generate explicit_map for a pair of arguments related by `k_ms = 1 - k_pt`.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.
        k_ms (str): Key used in the returned map.
        k_pt (str): Key looked up in `args_pt`, then in `params_pt`.

    Returns:
        dict, map between frames. For numeric text the value is the computed
        difference (``"0.9"`` gives ``"0.1"``); for any other text it is the
        expression ``"1.0 - <text>"``, with compound expressions parenthesized
        so operator precedence is preserved.
    """
    value = _lookup(params_pt, args_pt, k_pt).strip()
    try:
        float(value)
    except ValueError:
        operand = "(" + value + ")" if _needs_parentheses(value) else value
        return {k_ms: "1.0 - " + operand}
    return {k_ms: _one_minus(value)}


def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
    """
    Generate explicit_map for nn.MaxPool2d.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames. Key ``pad_mode`` holds ``"'valid'"`` when the
        padding text is all zeros and ``"'same'"`` otherwise.
    """
    return {"pad_mode": _pad_mode_literal(params_pt, args_pt)}


def tensor_dot_view_gen_explicit_map(params_pt, args_pt):
    """
    Generate explicit_map for the tensor `view` method.

    Args:
        params_pt (dict): Params for APIPt (unused).
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames. Key ``shape`` holds the ``*shape`` text
        wrapped as ``"(<text>,)"``.
    """
    return {"shape": "(" + args_pt["*shape"] + ",)"}


# `reshape` is mapped exactly like `view`.
tensor_dot_reshape_gen_explicit_map = tensor_dot_view_gen_explicit_map


def nn_conv2d_gen_explicit_map(params_pt, args_pt):
    """
    Generate explicit_map for nn.Conv2d.

    Args:
        params_pt (dict): Params for APIPt (unused).
        args_pt (dict): Args for APIPt (unused).

    Returns:
        dict, map between frames. Always ``{"pad_mode": "'pad'"}``.
    """
    return {"pad_mode": "'pad'"}


nn_batchnorm2d_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum")
nn_dropout_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p")