|
|
|
@@ -267,16 +267,17 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): |
|
|
|
longest_name = param_not_load[0] |
|
|
|
while prefix_name != longest_name and param_not_load: |
|
|
|
logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load))) |
|
|
|
longest_name = sorted(param_not_load, key=len, reverse=True)[0] |
|
|
|
prefix_name = longest_name |
|
|
|
for net_param_name in param_not_load: |
|
|
|
for dict_name in parameter_dict: |
|
|
|
if dict_name.endswith(net_param_name): |
|
|
|
tmp_name = dict_name[:-len(net_param_name)] |
|
|
|
prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name |
|
|
|
prefix_name = dict_name[:-len(net_param_name)] |
|
|
|
break |
|
|
|
if prefix_name != longest_name: |
|
|
|
break |
|
|
|
|
|
|
|
if prefix_name != longest_name: |
|
|
|
logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) |
|
|
|
logger.warning("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
new_param_name = prefix_name + param.name |
|
|
|
if param.name in param_not_load and new_param_name in parameter_dict: |
|
|
|
|