
Loading of HDF5 files from Keras #554

tags/yolov3
Oceania2018 4 years ago
commit 02ce65b47f
5 changed files with 119 additions and 103 deletions
  1. +8  -0   src/TensorFlowNET.Core/ops.cs
  2. +13 -0   src/TensorFlowNET.Keras/BackendImpl.cs
  3. +18 -38  src/TensorFlowNET.Keras/Engine/Model.Training.cs
  4. +80 -64  src/TensorFlowNET.Keras/Saving/fdf5_format.cs
  5. +0  -1   src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
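
Overall, the commit wires Keras-style HDF5 weight loading into Model.load_weights. A minimal usage sketch of the new entry point; the model builder and file name below are placeholders for illustration, not part of this commit:

    // Assumes an already-built Keras model whose architecture matches the saved weights.
    Model model = BuildModel();            // hypothetical helper that constructs your model
    model.load_weights("weights.h5");      // opens the HDF5 file and assigns every layer's variables
    // Note: by_name: true is still unimplemented and throws NotImplementedException (see Model.Training.cs below).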

+8 -0  src/TensorFlowNET.Core/ops.cs

@@ -532,5 +532,13 @@ namespace Tensorflow
             var g = get_default_graph();
             return g.get_name_scope();
         }

+        public static bool executing_eagerly_outside_functions()
+        {
+            if (tf.Context.executing_eagerly())
+                return true;
+            else
+                throw new NotImplementedException("");
+        }
     }
 }
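
The new ops.executing_eagerly_outside_functions helper only covers the eager case for now; graph mode throws. A hedged sketch of the intended guard pattern (the variable and value here are illustrative, not from this diff):

    if (ops.executing_eagerly_outside_functions())
        variable.assign(value);                  // eager path: assign the NDArray immediately
    else
        throw new NotImplementedException("");   // graph-mode path not ported yet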

+13 -0  src/TensorFlowNET.Keras/BackendImpl.cs

@@ -142,6 +142,19 @@ namespace Tensorflow.Keras
             _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
         }

+        public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
+        {
+            if (ops.executing_eagerly_outside_functions())
+            {
+                foreach (var (x, value) in tuples)
+                    x.assign(value);
+            }
+            else
+            {
+                throw new NotImplementedException("");
+            }
+        }
+
         /// <summary>
         /// Pads the 2nd and 3rd dimensions of a 4D tensor.
         /// </summary>
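
For context, a rough sketch of how the new backend helper gets called; the variables and value arrays are hypothetical, not from this diff:

    // `kernel` and `bias` are existing IVariableV1 instances; the NDArrays have matching shapes.
    var updates = new List<(IVariableV1, NDArray)>
    {
        (kernel, kernelValues),
        (bias, biasValues),
    };
    keras.backend.batch_set_value(updates);   // eager mode: each pair becomes variable.assign(value)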


+18 -38  src/TensorFlowNET.Keras/Engine/Model.Training.cs

@@ -9,48 +9,28 @@ namespace Tensorflow.Keras.Engine
 {
     public partial class Model
     {
-        private long fileId = -1;
-        private long f = -1;
-        public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null)
+        public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
         {
-            long root = Hdf5.OpenFile(filepath, true);
+            long fileId = Hdf5.OpenFile(filepath, true);

-            long fileId = root;
-            //try
-            //{
-            bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
-            bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");

+            bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
+            bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");

-            if (!lsuccess && msuccess)
-            {
-                f = H5G.open(fileId, "model_weights");
-            }
-            if (by_name)
-            {
-                //fdf5_format.load_weights_from_hdf5_group_by_name();
-            }
-            else
-            {
-                fdf5_format.load_weights_from_hdf5_group(f, this);
-            }
-            H5G.close(f);
-            //}
-            //catch (Exception ex)
-            //{
-            //    if (fileId != -1)
-            //    {
-            //        Hdf5.CloseFile(fileId);
-            //    }
-            //    if (f != -1)
-            //    {
-            //        H5G.close(f);
-            //    }
-            //    throw new Exception(ex.ToString());
-            //}
+            if (!lsuccess && msuccess)
+            {
+                fileId = H5G.open(fileId, "model_weights");
+            }
+            if (by_name)
+            {
+                //fdf5_format.load_weights_from_hdf5_group_by_name();
+                throw new NotImplementedException("");
+            }
+            else
+            {
+                fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+            }
+            H5G.close(fileId);
         }
     }
 }


+80 -64  src/TensorFlowNET.Keras/Saving/fdf5_format.cs

@@ -7,6 +7,8 @@ using Tensorflow.Keras.Engine;
 using HDF5CSharp;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
+using System.Linq;

 namespace Tensorflow.Keras.Saving
 {
     public class fdf5_format
@@ -45,13 +47,29 @@ namespace Tensorflow.Keras.Saving
         {

         }
-        public static void preprocess_weights_for_loading(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
-        {

-        }
-        public static void _convert_rnn_weights(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
+        /// <summary>
+        /// Preprocess layer weights between different Keras formats.
+        /// </summary>
+        /// <param name="layer"></param>
+        /// <param name="weights"></param>
+        /// <param name="original_keras_version"></param>
+        /// <param name="original_backend"></param>
+        public static List<NDArray> preprocess_weights_for_loading(ILayer layer, List<NDArray> weights, string original_keras_version = null, string original_backend = null)
         {
+            // convert CuDNN layers
+            return _convert_rnn_weights(layer, weights);
+        }

+        /// <summary>
+        /// Converts weights for RNN layers between native and CuDNN format.
+        /// </summary>
+        /// <param name="layer"></param>
+        /// <param name="weights"></param>
+        static List<NDArray> _convert_rnn_weights(ILayer layer, List<NDArray> weights)
         {
+            var target_class = layer.GetType().Name;
+            return weights;
         }
         public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
         {
@@ -65,56 +83,79 @@ namespace Tensorflow.Keras.Saving
         {

         }
-        public static void load_weights_from_hdf5_group(long f=-1,Model model=null)
+        public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
         {
-            string original_keras_version = "1";
+            string original_keras_version = "2.4.0";
             string original_backend = null;
             if (Hdf5.AttributeExists(f, "keras_version"))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
+                var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
                 if (success)
-                {
-                    original_keras_version = attr[0];
-                }
+                    original_keras_version = attr.First();
+                // keras version should be 2.5.0+
+                var ver_major = int.Parse(original_keras_version.Split('.')[0]);
+                var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
+                if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
+                    throw new ValueError("keras version should be 2.5.0 or later.");
             }
             if (Hdf5.AttributeExists(f, "backend"))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", "");
+                var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
                 if (success)
-                {
-                    original_backend = attr[0];
-                }
+                    original_backend = attr.First();
             }
             List<ILayer> filtered_layers = new List<ILayer>();
-            List<Tensor> weights;
-            foreach (var layer in model.Layers)
+            List<IVariableV1> weights;
+            foreach (var layer in layers)
             {
                 weights = _legacy_weights(layer);
-                if (weights.Count>0)
+                if (weights.Count > 0)
                 {
                     filtered_layers.append(layer);
                 }
             }
-            string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names");
-            List<NDArray> weight_values=new List<NDArray>();
-            foreach (var i in filtered_layers) {
-                long g = H5G.open(f, i.Name);
-                string[] weight_names = null;
-                if (g != -1)
-                {
-                    weight_names = load_attributes_from_hdf5_group(g, "weight_names");
-                }
-                if (weight_names != null)
+            string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
+            var filtered_layer_names = new List<string>();
+            foreach(var name in layer_names)
+            {
+                long g = H5G.open(f, name);
+                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+                if (weight_names.Count() > 0)
+                    filtered_layer_names.Add(name);
+                H5G.close(g);
+            }
+            layer_names = filtered_layer_names.ToArray();
+            if (layer_names.Length != filtered_layers.Count())
+                throw new ValueError("You are trying to load a weight file " +
+                    $"containing {layer_names}" +
+                    $" layers into a model with {filtered_layers.Count} layers.");

+            var weight_value_tuples = new List<(IVariableV1, NDArray)>();
+            foreach (var (k, name) in enumerate(layer_names))
+            {
+                var weight_values = new List<NDArray>();
+                long g = H5G.open(f, name);
+                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+                foreach (var i_ in weight_names)
+                {
-                foreach (var i_ in weight_names) {
-                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
-                    //
+                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
                     if (success)
                         weight_values.Add(np.array(result));
                 }
-            }
+                H5G.close(g);
+                var layer = filtered_layers[k];
+                var symbolic_weights = _legacy_weights(layer);
+                preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
+                if (weight_values.Count() != symbolic_weights.Count())
+                    throw new ValueError($"Layer #{k} (named {layer.Name}" +
+                        "in the current model) was found to " +
+                        $"correspond to layer {name} in the save file." +
+                        $"However the new layer {layer.Name} expects " +
+                        $"{symbolic_weights.Count()} weights, but the saved weights have " +
+                        $"{weight_values.Count()} elements.");
+                weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
+            }

+            keras.backend.batch_set_value(weight_value_tuples);
         }
         public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
         {
@@ -128,15 +169,13 @@ namespace Tensorflow.Keras.Saving
         {

         }
-        public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "")
+        public static string[] load_attributes_from_hdf5_group(long group, string name)
         {
-            if (Hdf5.AttributeExists(f, name))
+            if (Hdf5.AttributeExists(group, name))
             {
-                (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, "");
+                var (success, attr) = Hdf5.ReadStringAttributes(group, name, "");
                 if (success)
-                {
-                    return attr;
-                }
+                    return attr.ToArray();
             }
             return null;
         }
@@ -145,33 +184,10 @@ namespace Tensorflow.Keras.Saving

}

-        public static List<Tensor> _legacy_weights(ILayer layer)
+        public static List<IVariableV1> _legacy_weights(ILayer layer)
         {
-            List<Tensor> weights= new List<Tensor>();
-            if (layer.trainable_weights.Count != 0)
-            {
-                Tensor[] trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.trainable_weights.ToArray(), s => s.AsTensor());
-                Tensor[] non_trainable_weights =null;
-                if (layer.non_trainable_weights.Count != 0)
-                {
-                    non_trainable_weights = Array.ConvertAll<IVariableV1, Tensor>(layer.non_trainable_weights.ToArray(), s => s.AsTensor());
-                }
-                foreach (var i in trainable_weights) {
-                    if (non_trainable_weights != null)
-                    {
-                        foreach (var i_ in non_trainable_weights)
-                        {
-                            weights.Add(i + i_);
-                        }
-                    }
-                    else {
-                        weights.Add(i);
-                    };
-
-                }
-            }
+            var weights = layer.trainable_weights.Select(x => x).ToList();
+            weights.AddRange(layer.non_trainable_weights);
             return weights;
         }
}
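
To make the loader's expectations concrete, here is a hedged sketch that reads the same attributes and datasets directly with the HDF5CSharp calls used above (assuming `using System.Linq;`). The file name and the layer/weight names are only illustrative of a file written by Keras, not guaranteed values:

    long f = Hdf5.OpenFile("weights.h5", true);
    if (Hdf5.GroupExists(f, "model_weights"))
        f = H5G.open(f, "model_weights");                                     // files from model.save() nest weights here
    var (_, layerNames) = Hdf5.ReadStringAttributes(f, "layer_names", "");    // e.g. ["conv2d", "dense"]
    long g = H5G.open(f, layerNames.First());
    var (_, weightNames) = Hdf5.ReadStringAttributes(g, "weight_names", "");  // e.g. ["conv2d/kernel:0", "conv2d/bias:0"]
    var (ok, kernel) = Hdf5.ReadDataset<float>(g, weightNames.First());       // raw float buffer for the first weight
    H5G.close(g);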


+0 -1  src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

@@ -46,7 +46,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
   </PropertyGroup>

   <ItemGroup>
-    <PackageReference Include="HDF.PInvoke.1.10" Version="1.10.500" />
     <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
     <PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
     <PackageReference Include="NumSharp.Lite" Version="0.1.10" />

