- """
- Copyright 2020 Tianshu AI Platform. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =============================================================
- """
- import torch
- import networkx as nx
- from . import depara
- import os, abc
- from typing import Callable
- from kamal import hub
- import json, numbers
-
- from tqdm import tqdm
-
class Node(object):
    """A single (entry, spec) model record inside a kamal hub model zoo."""

    def __init__(self, hub_root, entry_name, spec_name):
        self.hub_root = hub_root
        self.entry_name = entry_name
        self.spec_name = spec_name

    @property
    def model(self):
        # Lazily load the model from the hub and switch it to eval mode.
        return hub.load(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name).eval()

    @property
    def tag(self):
        return hub.load_tags(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)

    @property
    def metadata(self):
        return hub.load_metadata(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)

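# Example (sketch): reading one registered model through a Node. The zoo path and
# the entry/spec names below are hypothetical; any directory produced by kamal's
# hub tooling (i.e. one containing code/hubconf.py) should work.
#
#   node = Node('~/model_zoo/segmentation/deeplabv3_camvid', 'deeplabv3', 'camvid')
#   print(node.metadata['name'], node.metadata['task'])
#   model = node.model   # torch.nn.Module, already in eval() mode
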
class TransferabilityGraph(object):
    def __init__(self, model_zoo_set):
        self.model_zoo_set = model_zoo_set
        self._graphs = dict()   # metric name -> nx.DiGraph over the registered nodes
        self._models = dict()   # model name -> Node
        self._register_models()

    def _register_models(self):
        # Walk every model zoo directory and register each (entry, spec) pair as a Node.
        cnt = 0
        for model_zoo in self.model_zoo_set:
            model_zoo = os.path.abspath(os.path.expanduser(model_zoo))
            for hub_root in self._list_modelzoo(model_zoo):
                for entry_name, spec_name in hub.list_spec(hub_root):
                    node = Node(hub_root, entry_name, spec_name)
                    name = node.metadata['name']
                    self._models[name] = node
                    cnt += 1
        print("%d models have been registered!" % cnt)

    def _list_modelzoo(self, zoo_dir):
        # Recursively collect every sub-directory that looks like a hub package,
        # i.e. one that contains a code/hubconf.py file.
        zoo_list = []
        def _traverse(path):
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path):
                    if os.path.exists(os.path.join(item_path, 'code/hubconf.py')):
                        zoo_list.append(item_path)
                    else:
                        _traverse(item_path)
        _traverse(zoo_dir)
        return zoo_list

    def add_metric(self, metric_name, metric):
        # `metric` is expected to be callable as metric(n1, n2) -> distance between
        # two Nodes, and to expose a `device` attribute so the computation can be
        # retried on CPU if it fails on the current device.
        self._graphs[metric_name] = g = nx.DiGraph()
        g.add_nodes_from(self._models.values())
        for n1 in self._models.values():
            for n2 in tqdm(self._models.values()):
                if n1 != n2 and not g.has_edge(n1, n2):
                    try:
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                    except Exception:
                        # e.g. CUDA out-of-memory: fall back to CPU, then restore the device.
                        ori_device = metric.device
                        metric.device = torch.device('cpu')
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                        metric.device = ori_device

    def export_to_json(self, metric_name, output_filename, topk=None, normalize=False):
        graph = self._graphs.get(metric_name, None)
        assert graph is not None, "unknown metric: %s" % metric_name
        graph_data = {
            'nodes': [],
            'edges': [],
        }
        # Record nodes: keep only scalar/string tags plus a few metadata fields.
        node_to_idx = {}
        for i, node in enumerate(self._models.values()):
            tags = node.tag
            metadata = node.metadata
            node_data = {k: v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str))}
            node_data['name'] = metadata['name']
            node_data['task'] = metadata['task']
            node_data['dataset'] = metadata['dataset']
            node_data['url'] = metadata['url']
            node_data['id'] = i
            graph_data['nodes'].append({'tags': node_data})
            node_to_idx[node] = i

        # Record edges; also gather, per source node, the distances of its outgoing
        # edges so the graph can optionally be pruned to the top-k closest targets.
        edge_list = graph_data['edges']
        topk_dist = {idx: [] for idx in range(len(self._models))}
        for edge in graph.edges.data('dist'):
            s, t, d = int(node_to_idx[edge[0]]), int(node_to_idx[edge[1]]), float(edge[2])
            topk_dist[s].append(d)
            edge_list.append([s, t, d])  # source, target, distance

        if isinstance(topk, int):
            # For each source node keep only the edges whose distance is below the
            # (topk+1)-th smallest outgoing distance, i.e. roughly its top-k edges.
            for i, dist in topk_dist.items():
                dist.sort()
                topk_dist[i] = dist[topk] if len(dist) > topk else float('inf')
            graph_data['edges'] = [edge for edge in edge_list if edge[2] < topk_dist[edge[0]]]

        if normalize:
            # Min-max normalize distances into (0, 1]; the epsilon avoids a
            # division by zero when all distances are identical.
            edge_dist = [e[2] for e in graph_data['edges']]
            min_dist, max_dist = min(edge_dist), max(edge_dist)
            for e in graph_data['edges']:
                e[2] = (e[2] - min_dist + 1e-8) / (max_dist - min_dist + 1e-8)

        with open(output_filename, 'w') as fp:
            json.dump(graph_data, fp)
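

# Example usage (sketch, not part of the library API). The zoo paths are
# hypothetical, and `metric` stands for any object satisfying the interface that
# add_metric expects: callable as metric(node_a, node_b) -> float and carrying a
# `device` attribute (e.g. a metric built on top of the `depara` module above).
#
#   tg = TransferabilityGraph(['~/model_zoo/classification', '~/model_zoo/segmentation'])
#   tg.add_metric('depara', metric)
#   tg.export_to_json('depara', 'transferability.json', topk=5, normalize=True)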