|
- import oneflow
- import contextlib
- import oneflow.nn as nn
-
def split_batch(batch):
    """Split a batch into ``(inputs, targets)``.

    If ``batch`` is a list or tuple, its first element is taken as the
    inputs and the remaining elements as the targets:

    * exactly one remaining element is unwrapped and returned as-is;
    * several remaining elements are returned as a list;
    * no remaining elements yields an empty list (not ``None``).

    Any other object is treated as bare inputs with no targets.

    Args:
        batch: A ``(inputs, *targets)`` list/tuple, or a bare inputs object.

    Returns:
        A 2-tuple ``(inputs, targets)``; ``targets`` is ``None`` when
        ``batch`` is not a list or tuple.

    Raises:
        ValueError: If ``batch`` is an empty list or tuple (nothing to unpack).
    """
    if isinstance(batch, (list, tuple)):
        inputs, *targets = batch
        if len(targets) == 1:
            targets = targets[0]
        return inputs, targets
    # Return a tuple here as well so both branches have the same return type.
    return batch, None
-
@contextlib.contextmanager
def set_mode(model, training=True):
    """Temporarily switch ``model`` into training or evaluation mode.

    The model's previous ``training`` flag is restored on exit, including
    when the ``with`` body raises.

    Args:
        model: Object exposing a ``training`` attribute and a
            ``train(mode)`` method (e.g. an ``nn.Module``).
        training: Mode to use inside the ``with`` block.

    Yields:
        The same ``model`` object, for convenience in ``with ... as m:``.
    """
    ori_mode = model.training
    model.train(training)
    try:
        yield model
    finally:
        # Without ``finally`` an exception in the with-body would leave the
        # model stuck in the temporary mode.
        model.train(ori_mode)
-
-
def move_to_device(obj, device):
    """Move a tensor, a module, or a (nested) container of them to ``device``.

    Args:
        obj: A ``oneflow.Tensor``, an ``nn.Module``, or a list / tuple /
            dict that may contain such objects at any nesting depth.
        device: Target device, forwarded as ``.to(device=device)``.

    Returns:
        * Tensors and modules: the result of ``obj.to(device=device)``.
        * Lists and tuples: a *list* of recursively moved elements (tuples
          are returned as lists, matching the historical behavior).
        * Dicts: a new dict with the same keys and recursively moved values.
        * Anything else (``None``, numbers, strings, ...): returned unchanged.
    """
    if isinstance(obj, (oneflow.Tensor, nn.Module)):
        return obj.to(device=device)
    if isinstance(obj, (list, tuple)):
        # Recurse so nested containers and non-tensor items do not crash.
        return [move_to_device(o, device) for o in obj]
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    # Previously unsupported objects fell through and became ``None``;
    # pass them through untouched instead of silently dropping them.
    return obj
-
-
def flatten_dict(dic, sep='/'):
    """Flatten a nested dict into a single-level dict with path-like keys.

    Nested keys are converted with ``str`` and joined with ``sep``, e.g.
    ``{'a': {'b': 1}, 'c': 2}`` becomes ``{'a/b': 1, 'c': 2}``. Only
    ``dict`` values (including subclasses) are descended into; an empty
    nested dict therefore contributes no entries.

    Args:
        dic: The (possibly nested) dict to flatten.
        sep: Separator placed between path components. Defaults to ``'/'``.

    Returns:
        A new flat ``dict`` mapping joined key paths to leaf values.
    """
    flattened = {}

    def _flatten(path, d):
        for k, v in d.items():
            # ``str(k)`` instead of ``'%s' % k`` so tuple keys do not raise.
            key_path = path + (str(k),)
            if isinstance(v, dict):
                _flatten(key_path, v)
            else:
                flattened[sep.join(key_path)] = v

    _flatten((), dic)
    return flattened
|