|
- """
- Copyright 2020 Tianshu AI Platform. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =============================================================
- """
-
- import abc
-
- from . import depara
- from captum.attr import InputXGradient
- import torch
- from kamal.core import hub
- from kamal.vision import sync_transforms as sT
-
class TransMetric(abc.ABC):
    """Base interface for a metric that scores a pair of entries.

    Concrete metrics override ``__call__``. This default implementation
    ignores both arguments and always reports 0.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, a, b) -> float:
        """Score the pair ``(a, b)``; the base class always returns 0."""
        del a, b  # unused by the default implementation
        return 0
-
class DeparaMetric(TransMetric):
    """Score a pair of entries by comparing their DEPARA attribution graphs.

    Each entry ``n`` handed to ``__call__`` must expose ``n.model`` and
    ``n.metadata``. The metadata's ``'input'`` dict describes the
    preprocessing the model expects. Attribution graphs are computed lazily,
    once per entry, and cached on CPU in ``self._cache``. Entries must
    therefore be hashable.

    Args:
        data: iterable of probe samples. Each sample is passed through the
            per-entry transform built by ``_get_transform`` (presumably
            PIL images or similar; TODO confirm against callers).
        device: device forwarded to ``depara.get_attribution_graph``.
    """

    def __init__(self, data, device):
        super().__init__()
        self.data = data
        self.device = device
        # Memo: entry -> attribution result (moved to CPU).
        self._cache = {}

    def _get_transform(self, metadata):
        """Build the preprocessing pipeline described by ``metadata['input']``.

        Reads the keys ``'size'``, ``'space'``, ``'range'`` and ``'normalize'``:

        * ``size``: ``None`` defaults to 224. For a list/tuple, the last
          element is used.
        * ``space``: ``'bgr'`` appends ``sT.FlipChannels()``.
        * ``range``: ``[0, 1]`` selects ``sT.ToTensor()``. ``[0, 255]``
          selects ``sT.ToTensor(normalize=False, dtype=torch.float)``.
        * ``normalize``: when not ``None``, a dict with ``'mean'`` and ``'std'``.

        The resulting order is Resize, CenterCrop, [FlipChannels], ToTensor,
        [Normalize].

        Raises:
            NotImplementedError: if ``range`` is neither ``[0, 1]`` nor ``[0, 255]``.
        """
        input_metadata = metadata['input']
        size = input_metadata['size']
        space = input_metadata['space']
        drange = list(input_metadata['range'])
        normalize = input_metadata['normalize']

        # Validate the value range first so unsupported metadata fails before
        # any transform objects are built.
        if drange == [0, 1]:
            to_tensor = sT.ToTensor()
        elif drange == [0, 255]:
            # Keep the raw 0..255 values, but as floats.
            to_tensor = sT.ToTensor(normalize=False, dtype=torch.float)
        else:
            raise NotImplementedError(
                "unsupported input range %r; expected [0, 1] or [0, 255]" % (drange,)
            )

        if size is None:
            size = 224
        if isinstance(size, (list, tuple)):
            # Only the last element is used as the Resize/CenterCrop size.
            size = size[-1]

        transform = [sT.Resize(size), sT.CenterCrop(size)]
        if space == 'bgr':
            transform.append(sT.FlipChannels())
        transform.append(to_tensor)
        if normalize is not None:
            transform.append(sT.Normalize(mean=normalize['mean'], std=normalize['std']))
        return sT.Compose(transform)

    def _get_attr_graph(self, n):
        """Compute the InputXGradient DEPARA attribution graph of ``n.model``.

        The probe samples are preprocessed with ``n``'s own transform and
        stacked along dim 0 into a single batch.
        """
        transform = self._get_transform(n.metadata)
        data = torch.stack([transform(d) for d in self.data], dim=0)
        return depara.get_attribution_graph(
            n.model,
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=data,
            device=self.device,
        )

    def _cached_attr_graph(self, n):
        """Return ``n``'s attribution graph, computing and caching it (on CPU) on first use."""
        graph = self._cache.get(n)
        if graph is None:
            graph = self._cache[n] = self._get_attr_graph(n).cpu()
        return graph

    def __call__(self, n1, n2):
        """Return ``depara.graph_similarity`` of the attribution graphs of ``n1`` and ``n2``."""
        # Look up/compute sequentially so that n1 == n2 computes the graph only once.
        attrgraph_1 = self._cached_attr_graph(n1)
        attrgraph_2 = self._cached_attr_graph(n2)
        return depara.graph_similarity(attrgraph_1, attrgraph_2)
-
class AttrMapMetric(DeparaMetric):
    """Variant of ``DeparaMetric`` that compares attribution maps instead of graphs.

    For each entry, ``depara.attribution_map`` is computed on the probe data.
    The result is cached on CPU in ``self._cache``. A pair of entries is then
    scored with ``depara.attr_map_distance``.
    """

    def _get_attr_map(self, n):
        """Compute the InputXGradient attribution map of ``n.model`` on the probe data.

        Each probe sample is preprocessed with ``n``'s transform, moved to
        ``self.device``, and stacked along dim 0 into a single batch.
        """
        transform = self._get_transform(n.metadata)
        data = torch.stack([transform(d).to(self.device) for d in self.data], dim=0)
        # NOTE(review): if n.model is a torch.nn.Module, .to() relocates it in
        # place, so the caller's model stays on self.device afterwards. Confirm
        # this is intended.
        return depara.attribution_map(
            n.model.to(self.device),
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=data,
        )

    def __call__(self, n1, n2):
        """Return ``depara.attr_map_distance`` of the attribution maps of ``n1`` and ``n2``."""
        # Look up/compute sequentially so that n1 == n2 computes the map only once.
        attr_map_1 = self._cache.get(n1)
        if attr_map_1 is None:
            attr_map_1 = self._cache[n1] = self._get_attr_map(n1).cpu()
        attr_map_2 = self._cache.get(n2)
        if attr_map_2 is None:
            attr_map_2 = self._cache[n2] = self._get_attr_map(n2).cpu()
        return depara.attr_map_distance(attr_map_1, attr_map_2)
|