| @@ -71,9 +71,9 @@ class APIPt: | |||
| or the given args_str not valid. | |||
| """ | |||
| # expr is REQUIRED to meet (**) format | |||
| if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): | |||
| raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str)) | |||
| if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"): | |||
| raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without ' | |||
| 'considering spaces'.format(args_str)) | |||
| try: | |||
| ast_node = ast.parse("whatever_call_name" + args_str) | |||
| call_node = ast_node.body[0].value | |||
| @@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt): | |||
| padding = "'valid'" | |||
| else: | |||
| padding = "'same'" | |||
| return {"padding": padding} | |||
| if 'stride' in args_pt: | |||
| strides = args_pt['stride'] | |||
| else: | |||
| strides = args_pt['kernel_size'] | |||
| return {"padding": padding, | |||
| "strides": strides} | |||
| def gen_explicit_map_nn_sequential(_, args_pt): | |||
| @@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): | |||
| pad_mode = "'valid'" | |||
| else: | |||
| pad_mode = "'same'" | |||
| return {"pad_mode": pad_mode} | |||
| if 'stride' in args_pt: | |||
| stride = args_pt['stride'] | |||
| else: | |||
| stride = args_pt['kernel_size'] | |||
| return {"pad_mode": pad_mode, | |||
| "stride": stride} | |||
| def torch_dot_eye_gen_explicit_map(_, args_pt): | |||
| @@ -21,14 +21,13 @@ | |||
| "kernel_size": "REQUIRED", | |||
| "stride": null, | |||
| "padding": 0, | |||
| "dilation": 1, | |||
| "ceil_mode": false, | |||
| "return_indices": false | |||
| "count_include_pad": true, | |||
| "divisor_override": null | |||
| } | |||
| ], | |||
| "ms2pt_mapping": { | |||
| "ksize": "kernel_size", | |||
| "strides": "stride", | |||
| "input": "input" | |||
| }, | |||
| "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | |||
| @@ -62,7 +61,6 @@ | |||
| ], | |||
| "ms2pt_mapping": { | |||
| "ksize": "kernel_size", | |||
| "strides": "stride", | |||
| "input": "input" | |||
| }, | |||
| "gen_explicit_map": "gen_explicit_map_f_max_pool2d" | |||
| @@ -16,9 +16,7 @@ | |||
| "inplace": false | |||
| } | |||
| ], | |||
| "ms2pt_mapping": { | |||
| "keep_prob": "p" | |||
| }, | |||
| "ms2pt_mapping": {}, | |||
| "gen_explicit_map": "nn_dropout_gen_explicit_map" | |||
| }, | |||
| "nn.AvgPool2d": { | |||
| @@ -36,14 +34,13 @@ | |||
| "kernel_size": "REQUIRED", | |||
| "stride": null, | |||
| "padding": 0, | |||
| "dilation": 1, | |||
| "return_indices": false, | |||
| "ceil_mode": "False" | |||
| "ceil_mode": false, | |||
| "count_include_pad": true, | |||
| "divisor_override": null | |||
| } | |||
| ], | |||
| "ms2pt_mapping": { | |||
| "kernel_size": "kernel_size", | |||
| "stride": "stride" | |||
| "kernel_size": "kernel_size" | |||
| }, | |||
| "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | |||
| }, | |||
| @@ -68,8 +65,7 @@ | |||
| } | |||
| ], | |||
| "ms2pt_mapping": { | |||
| "kernel_size": "kernel_size", | |||
| "stride": "stride" | |||
| "kernel_size": "kernel_size" | |||
| }, | |||
| "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" | |||
| }, | |||
| @@ -64,6 +64,15 @@ class TestConverter: | |||
| assert replaced_code == code.replace('nn.Softmax(dim=1)', | |||
| '{}(axis=1)'.format(expected_ms_api_name)) | |||
| def test_convert_api_nn_dropout(self): | |||
| """Test convert_api function work ok when convert api nn.Dropout""" | |||
| code = """nn.Dropout(0.3)""" | |||
| expected_ms_api_name = 'nn.Dropout' | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('nn.Dropout(0.3)', | |||
| "{}(keep_prob=0.7)".format(expected_ms_api_name)) | |||
| # test convert_api with torch dot ops | |||
| def test_convert_api_torch_dot_abs(self): | |||
| """Test convert_api function work ok when convert api torch.abs""" | |||
| @@ -202,6 +211,33 @@ class TestConverter: | |||
| assert replaced_code == code.replace('F.sigmoid(input)', | |||
| '{}()(input)'.format(expected_ms_api_name)) | |||
| def test_convert_api_f_max_pool2d(self): | |||
| """Test convert_api function work ok when convert api F.max_pool2d""" | |||
| code = """F.max_pool2d(out, 2)""" | |||
| expected_ms_api_name = 'P.MaxPool' | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('F.max_pool2d(out, 2)', | |||
| "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) | |||
| def test_convert_api_f_avg_pool2d_without_strides(self): | |||
| """Test convert_api function work ok when convert api F.avg_pool2d""" | |||
| code = """F.avg_pool2d(out, 2)""" | |||
| expected_ms_api_name = 'P.AvgPool' | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('F.avg_pool2d(out, 2)', | |||
| "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) | |||
| def test_convert_api_f_avg_pool2d_with_strides(self): | |||
| """Test convert_api function work ok when convert api F.avg_pool2d""" | |||
| code = """F.avg_pool2d(out, 2, 3)""" | |||
| expected_ms_api_name = 'P.AvgPool' | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)', | |||
| "{}(2, 3, 'valid')(out)".format(expected_ms_api_name)) | |||
| # test convert_api with tensor dot ops | |||
| def test_convert_api_tensor_dot_repeat(self): | |||
| """Test convert_api function work ok when convert api .repeat""" | |||
| @@ -216,7 +252,6 @@ class TestConverter: | |||
| """Test convert_api function work ok when convert api .permute""" | |||
| code = "x.permute(2, 0, 1)" | |||
| expected_ms_api_name = 'P.Transpose' | |||
| replaced_code = self.converter_ins.convert_api(code) | |||
| assert replaced_code == code.replace('x.permute(2, 0, 1)', | |||
| '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) | |||