# Copyright (c) OpenMMLab. All rights reserved.
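"""Post-training pipeline for an MMDetection-based anomaly detector.

The script evaluates every checkpoint in --train-work-dir, keeps the one with
the best AUC + bbox_mAP_50, exports it to ONNX, sweeps score thresholds to
report image-level OK/NG recall, fits feature statistics on the training set
(mean plus Ledoit-Wolf covariance, presumably for Mahalanobis-distance
scoring at inference time), and assembles an ``infer/`` bundle in --work-dir.
"""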
import argparse
import json
import os
import os.path as osp
import shutil
import sys
import time
import warnings
from functools import partial

# Make the sibling anomaly_detection package importable before mmdet loads.
sys.path.append("../anomaly_detection")

import mmcv
import numpy as np
import onnx
import onnxruntime as ort
import torch
from pycocotools.coco import COCO
from sklearn.covariance import LedoitWolf

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # uncomment to pin a specific GPU
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import (inference_detector, init_detector, multi_gpu_test,
                        single_gpu_test)
from mmdet.core.export import build_model_from_cfg, preprocess_example_input
from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector

print(f"onnxruntime device: {ort.get_device()}")  # e.g. GPU
print(f"ort available providers: {ort.get_available_providers()}")
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']

def cal_features(config_file, checkpoint_file, data_path_images, data_path_labels):
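    """Run the detector over every image listed in a COCO annotation file and
    collect the per-image feature vectors.

    Relies on this repo's patched ``inference_detector(..., feat=True)``,
    which returns ``(detections, features)`` rather than detections alone.
    """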
    with open(data_path_labels) as f:
        data_coco = json.load(f)
    data_name = data_coco["images"]

    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    imgs_name = [
        osp.join(data_path_images, img["file_name"]) for img in data_name
    ]
    print("before infer")
    results = []
    step = 1
    # Infer in mini-batches of `step` images; the slice clamps at the end.
    for index in range(0, len(imgs_name), step):
        _, results_tmp = inference_detector(
            model, imgs_name[index:index + step], feat=True)
        results += results_tmp
    print("after infer")
    return results

def cal_recall(config_file, checkpoint_file, data_path_images, data_path_labels):
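    """Sweep detection score thresholds and report image-level recall.

    An image is ground-truth NG (defective) if it has at least one COCO
    annotation, and predicted NG if any detection of any class exceeds the
    threshold. Returns a list of dicts with ``score_thr``, ``recall(ok)``
    (rate of OK images kept OK) and ``recall(ng)`` (rate of NG images caught).
    """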
    with open(data_path_labels) as f:
        data_coco = json.load(f)
    data_name = data_coco["images"]
    data_ann = data_coco['annotations']

    # Map image id -> file name once instead of scanning per annotation.
    id_to_name = {img["id"]: img["file_name"] for img in data_name}
    boxes = {}
    for res in data_ann:
        img_name = id_to_name[res["image_id"]]
        bbox = res["bbox"] + [int(res["category_id"])]
        boxes.setdefault(img_name, []).append(bbox)

    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    imgs_labels = []
    imgs_name = []
    num_ng = 0
    for img in data_name:
        res_label = 1 if img["file_name"] in boxes else 0
        num_ng += res_label
        imgs_labels.append(res_label)
        imgs_name.append(osp.join(data_path_images, img["file_name"]))
    num_ok = len(data_name) - num_ng
    print(len(imgs_labels), num_ok, num_ng)

    print("before infer")
    results = []
    step = 1
    for index in range(0, len(imgs_name), step):
        results_tmp = inference_detector(model, imgs_name[index:index + step])
        results += results_tmp
    print("after infer")

    recall_thrs = []
    for score_thr in np.arange(0.01, 0.5, 0.01):
        imgs_results = []
        for result in results:
            # `result` is a list of per-class (n, 5) arrays; column 4 holds
            # the detection score.
            res_predict = 0
            for per_class in result:
                if (per_class[:, 4] > score_thr).any():
                    res_predict = 1
            imgs_results.append(res_predict)

        count_ng = 0
        count_ok = 0
        for gt, pred in zip(imgs_labels, imgs_results):
            if gt == 0 and pred == 0:
                count_ok += 1
            if gt == 1 and pred == 1:
                count_ng += 1
        recall_thrs.append({
            # Cast to plain float so the result stays JSON-serializable.
            'score_thr': float(score_thr),
            'recall(ok)': count_ok / (num_ok + 1e-8),
            'recall(ng)': count_ng / (num_ng + 1e-8),
        })
    return recall_thrs

def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='model.onnx',
                 verify=True,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=True,
                 skip_postprocess=False):
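    """Export ``model`` to ONNX and, if ``verify``, check the exported graph
    against the PyTorch outputs.

    Adapted from MMDetection's ``tools/deployment/pytorch2onnx.py``; the
    extra ``feature``/``entropy``/``learning_loss`` outputs appear to come
    from this repo's customized detector heads.
    """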
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]

    if skip_postprocess:
        warnings.warn('Not all models support exporting to ONNX without '
                      'post-processing, especially two-stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)

        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)

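    # Besides the standard `dets`/`labels` (and `masks`), this model exposes
    # per-image `feature`, `entropy` and `learning_loss` outputs, presumably
    # used for anomaly scoring / active learning elsewhere in this repo.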
    output_names = ['dets', 'labels', 'feature', 'entropy', 'learning_loss']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
            'feature': {
                0: 'batch',
                1: 'feat_dim',
            },
            'entropy': {
                0: 'batch',
                1: '1',
            },
            'learning_loss': {
                0: 'batch',
                1: '1',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(
        model,
        img_list,
        output_file,
        input_names=[input_name],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)

    model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, '
                      'you may have to build mmcv with ONNXRuntime from '
                      'source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires onnxsim (onnx-simplifier) >= {min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)

        if test_img is None:
            input_config['input_path'] = input_img

        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]
        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'

        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between PyTorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between PyTorch and ONNX')


def parse_normalize_cfg(test_pipeline):
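    """Pull the single ``Normalize`` transform out of the test pipeline."""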
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, \
        'the test pipeline should contain exactly one `Normalize` transform'
    norm_config = norm_config_li[0]
    return norm_config


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument(
        '--train-work-dir',
        default='/model',
        help='directory containing config.py and the trained checkpoints')
    parser.add_argument(
        '--work-dir',
        default='/result',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        help='input image shape for ONNX export (one int for a square input, '
        'two for height and width); defaults to the test pipeline img_scale')
    parser.add_argument(
        '--data-path', default='/dataset', help='dataset path')
    parser.add_argument(
        '--out', default=None, help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        default='bbox',
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def main():
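    """Evaluate all checkpoints, export the best one to ONNX and assemble the
    ``infer/`` deployment bundle."""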
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(osp.join(args.train_work_dir, 'config.py'))
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()
    # create the work_dir on rank 0 only
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(args.work_dir)

    if args.data_path is not None:
        ann_file = os.path.join(args.data_path,
                                'annotations/instances_annotations.json')
        coco_config = COCO(ann_file)
        cfg.data.test.img_prefix = os.path.join(args.data_path, 'images')
        cfg.data.test.ann_file = ann_file
        # Derive the class list from the COCO categories.
        cfg.classes = tuple(cat['name'] for cat in coco_config.cats.values())
        cfg.data.test.classes = cfg.classes
    # build the dataloader
    samples_per_gpu = 1
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    eval_results = []
    best_eval_result = {'checkpoint': 'epoch_1.pth', 'AUC': 0, 'bbox_mAP_50': 0}
    checkpoint_files = os.listdir(args.train_work_dir)
    for checkpoint_file in checkpoint_files:
        if not checkpoint_file.endswith('pth'):
            continue
        # build the model and load checkpoint
        cfg.model.train_cfg = None
        model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            wrap_fp16_model(model)
        checkpoint = load_checkpoint(
            model,
            osp.join(args.train_work_dir, checkpoint_file),
            map_location='cpu')
        if args.fuse_conv_bn:
            model = fuse_conv_bn(model)
        # old versions did not save class info in checkpoints, this
        # workaround is for backward compatibility
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            model.CLASSES = dataset.CLASSES

        if not distributed:
            model = MMDataParallel(model, device_ids=[0])
            outputs = single_gpu_test(model, data_loader, args.show,
                                      args.show_dir, args.show_score_thr)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
            outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                     args.gpu_collect)

        rank, _ = get_dist_info()
        if rank == 0:
            if args.out:
                print(f'\nwriting results to {args.out}')
                mmcv.dump(outputs, args.out)
            kwargs = {} if args.eval_options is None else args.eval_options
            if args.format_only:
                dataset.format_results(outputs, **kwargs)
            if args.eval:
                eval_kwargs = cfg.get('evaluation', {}).copy()
                # hard-code way to remove EvalHook args
                for key in [
                        'interval', 'tmpdir', 'start', 'gpu_collect',
                        'save_best', 'rule'
                ]:
                    eval_kwargs.pop(key, None)
                eval_kwargs.update(dict(metric=args.eval, **kwargs))
                metric = dataset.evaluate(outputs, **eval_kwargs)
                print(metric)
                print(metric['AUC'])
                print(metric['bbox_mAP_50'])
                eval_result = {
                    'checkpoint': checkpoint_file,
                    'AUC': metric['AUC'],
                    'bbox_mAP_50': metric['bbox_mAP_50']
                }
                eval_results.append(eval_result)
                # Keep the checkpoint with the best AUC + bbox_mAP_50 sum.
                if (eval_result['AUC'] + eval_result['bbox_mAP_50'] >
                        best_eval_result['AUC'] +
                        best_eval_result['bbox_mAP_50']):
                    best_eval_result = eval_result
    print(eval_results)
    print(best_eval_result)

    if args.shape is None:
        # Fall back to the test pipeline's img_scale (treated as (w, h) here).
        img_scale = cfg.test_pipeline[1]['img_scale'][0]
        print(img_scale)
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # create onnx dir
    onnx_path = osp.join(args.work_dir, 'infer')
    mmcv.mkdir_or_exist(onnx_path)

    # build the best model and load its checkpoint
    model = build_model_from_cfg(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']))

    input_img = osp.join(osp.dirname(__file__), 'demo.jpg')

    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    # convert model to onnx file
    pytorch2onnx(
        model,
        input_img,
        input_shape,
        normalize_cfg,
        output_file=osp.join(onnx_path, 'model.onnx'),
        test_img=input_img)

    recall_thrs = cal_recall(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']),
        os.path.join(args.data_path, 'images'),
        os.path.join(args.data_path, 'annotations/instances_annotations.json'))
    best_eval_result['recall'] = recall_thrs
    print(best_eval_result)
    json_file = osp.join(args.work_dir, 'eval_result.json')
    mmcv.dump(best_eval_result, json_file)

    train_feats = cal_features(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']),
        os.path.join(args.data_path, 'images'),
        os.path.join(args.data_path, 'annotations/instances_annotations.json'))
    train_feats = np.array(train_feats)
    print(train_feats.shape)
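    # Fit Gaussian statistics of the training features; downstream inference
    # presumably scores anomalies with a Mahalanobis distance, roughly
    # (sketch -- `feat` and the load path are assumptions, not this repo's API):
    #   stats = np.load('train_feature.npz')
    #   d = feat - stats['train_mean']
    #   score = float(d @ stats['train_cov_inv'] @ d)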
    train_mean = np.mean(train_feats, axis=0)
    # Ledoit-Wolf shrinkage keeps the covariance well-conditioned when there
    # are few samples relative to the feature dimension.
    train_cov = LedoitWolf().fit(train_feats).covariance_
    train_cov_inv = np.linalg.pinv(train_cov)
    print(train_mean.shape, train_cov.shape, train_cov_inv.shape)

    shutil.copy(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.work_dir, 'infer/config.py'))
    shutil.copy(
        osp.join(args.train_work_dir, best_eval_result['checkpoint']),
        osp.join(args.work_dir, 'infer/' + best_eval_result['checkpoint']))
    shutil.copytree(
        osp.abspath(osp.join(osp.dirname(__file__), '../../transformer/')),
        osp.join(args.work_dir, 'infer/transformer'))

    with open(osp.join(args.work_dir, 'infer/class_names.txt'), 'w') as f:
        for name in cfg.classes:
            f.write(name + '\n')
    print(osp.join(args.work_dir, 'infer/class_names.txt'))
    # np.savez appends `.npz` when the name lacks it, so use `.npz` explicitly
    # to keep the printed path accurate.
    feature_file = osp.join(args.work_dir, 'infer/train_feature.npz')
    np.savez(feature_file, train_mean=train_mean, train_cov=train_cov,
             train_cov_inv=train_cov_inv)
    print(feature_file)

    shutil.copy(
        osp.abspath(osp.join(osp.dirname(__file__), 'serve_desc.yaml')),
        osp.join(args.work_dir, 'infer/serve_desc.yaml'))
    shutil.copy(
        osp.abspath(osp.join(osp.dirname(__file__), 'ext.proto')),
        osp.join(args.work_dir, 'infer/transformer/ext.proto'))


if __name__ == '__main__':
    main()