
_utils.py 1.1 kB

3 years ago
import contextlib

import oneflow
import oneflow.nn as nn


def split_batch(batch):
    """Split a batch into (inputs, targets)."""
    if isinstance(batch, (list, tuple)):
        inputs, *targets = batch
        # A single target is unwrapped; several targets stay as a list.
        if len(targets) == 1:
            targets = targets[0]
        return inputs, targets
    else:
        # A bare input carries no targets.
        return [batch, None]


@contextlib.contextmanager
def set_mode(model, training=True):
    """Temporarily switch the model's train/eval mode."""
    ori_mode = model.training
    model.train(training)
    try:
        yield
    finally:
        # Restore the original mode even if the body raises.
        model.train(ori_mode)


def move_to_device(obj, device):
    """Move a tensor, a list/tuple of tensors, or a module to `device`."""
    if isinstance(obj, oneflow.Tensor):
        return obj.to(device=device)
    elif isinstance(obj, (list, tuple)):
        # Note: a tuple input comes back as a list.
        return [o.to(device=device) for o in obj]
    elif isinstance(obj, nn.Module):
        return obj.to(device=device)
    # Any other type falls through and returns None.


def flatten_dict(dic):
    """Flatten a nested dict into one level with '/'-joined keys."""
    flattened = dict()

    def _flatten(prefix, d):
        for k, v in d.items():
            if isinstance(v, dict):
                _flatten(prefix + '%s/' % k, v)
            else:
                flattened[(prefix + '%s/' % k).strip('/')] = v

    _flatten('', dic)
    return flattened
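
The two pure-Python helpers above can be tried without OneFlow installed. The sketch below copies `split_batch` and `flatten_dict` in condensed form and runs them on made-up inputs; the `metrics` dictionary and the string placeholders standing in for tensors are assumptions for illustration, not data from this repository.

```python
def split_batch(batch):
    # Same logic as in _utils.py: first element is the input, the rest are targets.
    if isinstance(batch, (list, tuple)):
        inputs, *targets = batch
        if len(targets) == 1:
            targets = targets[0]
        return inputs, targets
    return [batch, None]


def flatten_dict(dic):
    # Same logic as in _utils.py: nested keys are joined with '/'.
    flattened = {}

    def _flatten(prefix, d):
        for k, v in d.items():
            key = prefix + '%s/' % k
            if isinstance(v, dict):
                _flatten(key, v)
            else:
                flattened[key.strip('/')] = v

    _flatten('', dic)
    return flattened


# Hypothetical nested metrics, e.g. as a logger might receive them.
metrics = {"train": {"loss": 0.5, "acc": {"top1": 0.9}}, "lr": 0.01}
flat = flatten_dict(metrics)
print(flat)

# Strings stand in for tensors here.
single = split_batch(("x", "y"))        # one target is unwrapped
multi = split_batch(("x", "y1", "y2"))  # several targets stay in a list
bare = split_batch("x")                 # a bare input has no targets
print(single, multi, bare)
```

Flattened keys such as `train/acc/top1` are convenient when scalar metrics are passed to a logger that expects flat tag names. Note that `split_batch` returns a tuple for list/tuple batches but a list for a bare input, so callers should unpack the result rather than compare it against a fixed type.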

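Knowledge amalgamation (模型炼知) was proposed by the VIPA team at Zhejiang University during 2019-2020. Its goals are to build lightweight knowledge-fusion algorithms and to address the problem of measuring the transferability of deep models. This repository contains three example knowledge amalgamation algorithms, TTL, THL, and TFL, for computer vision. By recombining multiple homogeneous or heterogeneous teachers, they fuse the teachers' knowledge into a customized, all-round student model that handles every teacher's task, with performance significantly better than that of conventionally trained models. Knowledge amalgamation is therefore of value for both in-depth research and practical application.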
