|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import os
- import tensorlayer as tl
- from tensorlayer.files import utils
- from tensorlayer import logging
-
# Lookup table consulted by ``str2act``: activation name -> ``tl.ops`` activation entry.
# "lrelu" and "leaky_relu" are two spellings of the same entry.
_act_dict = {
    "relu": tl.ops.ReLU,
    "relu6": tl.ops.ReLU6,
    "leaky_relu": tl.ops.LeakyReLU,
    "lrelu": tl.ops.LeakyReLU,
    "softplus": tl.ops.Softplus,
    "tanh": tl.ops.Tanh,
    "sigmoid": tl.ops.Sigmoid,
    "softmax": tl.ops.Softmax,
}
-
-
def str2act(act):
    """Resolve an activation name to a TensorLayer activation.

    Parameters
    ----------
    act : str
        Either a key of ``_act_dict`` (e.g. ``"relu"``, ``"tanh"``), or a leaky-ReLU name with a
        numeric slope appended directly to it, e.g. ``"lrelu0.2"`` or ``"leaky_relu0.2"``.

    Returns
    -------
    object
        For a name carrying a slope suffix, the result of ``tl.ops.LeakyReLU(alpha=<slope>)``.
        Otherwise the entry stored in ``_act_dict`` under ``act``, returned as-is (it is not
        called here).

    Raises
    ------
    ValueError
        If the slope suffix cannot be parsed as a float, or if ``act`` is not a supported name.
        ``ValueError`` is a subclass of ``Exception``, so callers catching ``Exception`` still work.
    """
    # "lrelu" is tested before "leaky_relu", matching the original order of the checks.
    for prefix in ("lrelu", "leaky_relu"):
        if len(act) > len(prefix) and act[:len(prefix)] == prefix:
            suffix = act[len(prefix):]
            # Only the float conversion is guarded, so an error raised while building the
            # activation itself is not misreported as a parsing problem.
            try:
                alpha = float(suffix)
            except ValueError as err:
                raise ValueError("{} can not be parsed as a float".format(suffix)) from err
            return tl.ops.LeakyReLU(alpha=alpha)

    if act not in _act_dict:
        raise ValueError("Unsupported act: {}".format(act))
    return _act_dict[act]
-
-
- def _save_weights(net, file_path, format=None):
- """Input file_path, save model weights into a file of given format.
- Use net.load_weights() to restore.
-
- Parameters
- ----------
- file_path : str
- Filename to which the model weights will be saved.
- format : str or None
- Saved file format.
- Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other format is not supported now.
- 1) If this is set to None, then the postfix of file_path will be used to decide saved format.
- If the postfix is not in ['h5', 'hdf5', 'npz', 'ckpt'], then file will be saved in hdf5 format by default.
- 2) 'hdf5' will save model weights name in a list and each layer has its weights stored in a group of
- the hdf5 file.
- 3) 'npz' will save model weights sequentially into a npz file.
- 4) 'npz_dict' will save model weights along with its name as a dict into a npz file.
- 5) 'ckpt' will save model weights into a tensorflow ckpt file.
-
- Default None.
-
- Examples
- --------
- 1) Save model weights in hdf5 format by default.
- >>> net = vgg16()
- >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
- >>> metric = tl.metric.Accuracy()
- >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
- >>> model.save_weights('./model.h5')
- ...
- >>> model.load_weights('./model.h5')
-
- 2) Save model weights in npz/npz_dict format
- >>> model.save_weights('./model.npz')
- >>> model.save_weights('./model.npz', format='npz_dict')
-
- """
-
- if net.all_weights is None or len(net.all_weights) == 0:
- logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
- return
-
- if format is None:
- postfix = file_path.split('.')[-1]
- if postfix in ['h5', 'hdf5', 'npz', 'ckpt']:
- format = postfix
- else:
- format = 'hdf5'
-
- if format == 'hdf5' or format == 'h5':
- raise NotImplementedError("hdf5 load/save is not supported now.")
- # utils.save_weights_to_hdf5(file_path, net)
- elif format == 'npz':
- utils.save_npz(net.all_weights, file_path)
- elif format == 'npz_dict':
- utils.save_npz_dict(net.all_weights, file_path)
- elif format == 'ckpt':
- # TODO: enable this when tf save ckpt is enabled
- raise NotImplementedError("ckpt load/save is not supported now.")
- else:
- raise ValueError(
- "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'."
- "Other format is not supported now."
- )
-
-
- def _load_weights(net, file_path, format=None, in_order=True, skip=False):
- """Load model weights from a given file, which should be previously saved by net.save_weights().
-
- Parameters
- ----------
- file_path : str
- Filename from which the model weights will be loaded.
- format : str or None
- If not specified (None), the postfix of the file_path will be used to decide its format. If specified,
- value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other format is not supported now.
- In addition, it should be the same format when you saved the file using net.save_weights().
- Default is None.
- in_order : bool
- Allow loading weights into model in a sequential way or by name. Only useful when 'format' is 'hdf5'.
- If 'in_order' is True, weights from the file will be loaded into model in a sequential way.
- If 'in_order' is False, weights from the file will be loaded into model by matching the name
- with the weights of the model, particularly useful when trying to restore model in eager(graph) mode from
- a weights file which is saved in graph(eager) mode.
- Default is True.
- skip : bool
- Allow skipping weights whose name is mismatched between the file and model. Only useful when 'format' is
- 'hdf5' or 'npz_dict'. If 'skip' is True, 'in_order' argument will be ignored and those loaded weights
- whose name is not found in model weights (net.all_weights) will be skipped. If 'skip' is False, error will
- occur when mismatch is found.
- Default is False.
-
- Examples
- --------
- 1) load model from a hdf5 file.
- >>> net = vgg16()
- >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
- >>> metric = tl.metric.Accuracy()
- >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
- >>> model.load_weights('./model_graph.h5', in_order=False, skip=True) # load weights by name, skipping mismatch
- >>> model.load_weights('./model_eager.h5') # load sequentially
-
- 2) load model from a npz file
- >>> model.load_weights('./model.npz')
-
- 3) load model from a npz file, which is saved as npz_dict previously
- >>> model.load_weights('./model.npz', format='npz_dict')
-
- Notes
- -------
- 1) 'in_order' is only useful when 'format' is 'hdf5'. If you are trying to load a weights file which is
- saved in a different mode, it is recommended to set 'in_order' be True.
- 2) 'skip' is useful when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True,
- 'in_order' argument will be ignored.
-
- """
- if not os.path.exists(file_path):
- raise FileNotFoundError("file {} doesn't exist.".format(file_path))
-
- if format is None:
- format = file_path.split('.')[-1]
-
- if format == 'hdf5' or format == 'h5':
- raise NotImplementedError("hdf5 load/save is not supported now.")
- # if skip ==True or in_order == False:
- # # load by weights name
- # utils.load_hdf5_to_weights(file_path, net, skip)
- # else:
- # # load in order
- # utils.load_hdf5_to_weights_in_order(file_path, net)
- elif format == 'npz':
- utils.load_and_assign_npz(file_path, net)
- elif format == 'npz_dict':
- utils.load_and_assign_npz_dict(file_path, net, skip)
- elif format == 'ckpt':
- # TODO: enable this when tf save ckpt is enabled
- raise NotImplementedError("ckpt load/save is not supported now.")
- else:
- raise ValueError(
- "File format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
- "Other format is not supported now."
- )
|