Browse Source

Adjust location of KerasTensor.

tags/v0.110.4-Transformer-Model
Haiping Chen 2 years ago
parent
commit
7b26d6699a
7 changed files with 173 additions and 164 deletions
  1. +64
    -0
      src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
  2. +0
    -53
      src/TensorFlowNET.Core/Tensors/KerasTensor.cs
  3. +4
    -13
      src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
  4. +27
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs
  5. +0
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +9
    -16
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  7. +69
    -77
      src/TensorFlowNET.Keras/Saving/SavedModel/load.cs

+ 64
- 0
src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs View File

@@ -0,0 +1,64 @@
namespace Tensorflow.Keras.Engine;

/// <summary>
/// A representation of a Keras in/output during Functional API construction.
/// </summary>
public class KerasTensor
{
private Tensors _original_tensors;
public Tensors original_tensors
{
get => _original_tensors;
set => _original_tensors = value;
}

private Shape _inferred_value;
public Shape inferred_value => _inferred_value;

private string _name;
private TensorSpec _type_spec;
public Shape shape => _type_spec.shape;
public TF_DataType dtype => _type_spec.dtype;

public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null)
{
_type_spec = type_spec;
_inferred_value = inferred_value;
_name = name;
}

public static KerasTensor from_tensor(Tensor tensor)
{
var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
kt.original_tensors = tensor;
return kt;
}

public override string ToString()
=> _original_tensors.Length switch
{
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
_ => _original_tensors.ToString(),
};

private string GetInferredValueString()
=> _inferred_value == null ? "" : "";

public static implicit operator Tensors(KerasTensor kt)
=> kt._original_tensors;

public static implicit operator Tensor(KerasTensor kt)
{
Tensor tensor = kt._original_tensors;
tensor.IsFromKerasTensor = true;
return tensor;
}

public static implicit operator KerasTensor(Tensor tensor)
=> from_tensor(tensor);

public static implicit operator KerasTensor(Tensors tensors)
=> from_tensor(tensors.First());
}

+ 0
- 53
src/TensorFlowNET.Core/Tensors/KerasTensor.cs View File

@@ -1,53 +0,0 @@
namespace Tensorflow.Keras.Engine;

/// <summary>
/// A representation of a Keras in/output during Functional API construction.
/// </summary>
public class KerasTensor
{
private Tensors _inferred_value;
public Tensors inferred_value
{
get => _inferred_value;
set => _inferred_value = value;
}

private string _name;
private TensorSpec _type_spec;
public Shape shape => _type_spec.shape;
public TF_DataType dtype => _type_spec.dtype;

public KerasTensor(TensorSpec type_spec, string name = null)
{
_type_spec = type_spec;
_name = name;
}

public static KerasTensor from_tensor(Tensor tensor)
{
var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
kt.inferred_value = tensor;
return kt;
}

public override string ToString()
=> _inferred_value.Length switch
{
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]",
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>",
_ => _inferred_value.ToString(),
};

public static implicit operator Tensors(KerasTensor kt)
=> kt._inferred_value;

public static implicit operator Tensor(KerasTensor kt)
=> kt._inferred_value;

public static implicit operator KerasTensor(Tensor tensor)
=> from_tensor(tensor);

public static implicit operator KerasTensor(Tensors tensors)
=> from_tensor(tensors.First());
}

+ 4
- 13
src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs View File

@@ -14,19 +14,10 @@
limitations under the License.
******************************************************************************/

using Tensorflow.NumPy;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;
namespace Tensorflow;

namespace Tensorflow
public partial class Tensor
{
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor
{
public TensorSpec ToTensorSpec()
=> new TensorSpec(shape, dtype, name);
}
public TensorSpec ToTensorSpec()
=> new TensorSpec(shape, dtype, name);
}

+ 27
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs View File

@@ -0,0 +1,27 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

namespace Tensorflow;

public partial class Tensor
{
public bool IsFromKerasTensor { get; set; }

/// <summary>
/// Keras History: (Layer, (node_index, tensor_index))
/// </summary>
public KerasHistory KerasHistory { get; set; }
}

+ 0
- 5
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -146,11 +146,6 @@ namespace Tensorflow
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
}

/// <summary>
/// Keras History: (Layer, (node_index, tensor_index))
/// </summary>
public KerasHistory KerasHistory { get; set; }

/// <summary>
/// Updates the shape of this tensor.
/// </summary>


+ 9
- 16
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -1,22 +1,15 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;

namespace Tensorflow.Keras.Models
namespace Tensorflow.Keras.Models;

public class ModelsApi: IModelsApi
{
public class ModelsApi: IModelsApi
{
public Functional from_config(FunctionalConfig config)
=> Functional.from_config(config);
public Functional from_config(FunctionalConfig config)
=> Functional.from_config(config);

public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
{
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
}
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
{
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
}
}

+ 69
- 77
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -1,97 +1,89 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using System.IO;
using Tensorflow.Train;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving.SavedModel
namespace Tensorflow.Keras.Saving.SavedModel;

public class KerasLoadModelUtils
{
public class KerasLoadModelUtils
/// <summary>
/// Corresponding to keras/saving/save.py/load_model
/// </summary>
/// <param name="filepath"></param>
/// <param name="custom_objects"></param>
/// <param name="compile"></param>
/// <param name="options"></param>
/// <returns></returns>
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
bool compile = true, LoadOptions? options = null)
{
/// <summary>
/// Corresponding to keras/saving/save.py/load_model
/// </summary>
/// <param name="filepath"></param>
/// <param name="custom_objects"></param>
/// <param name="compile"></param>
/// <param name="options"></param>
/// <returns></returns>
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
bool compile = true, LoadOptions? options = null)
using var savingScope = SharedObjectSavingScope.Enter();

using var ctx = LoadContext.load_context(options);

if (!File.Exists(filepath) && !Directory.Exists(filepath))
{
using (SharedObjectSavingScope.Enter())
{
using (LoadContext.load_context(options))
{
if (!File.Exists(filepath) && !Directory.Exists(filepath))
{
throw new IOException($"No file or directory found at {filepath}.");
}
if (Directory.Exists(filepath))
{
return load(filepath, compile, options);
}
else
{
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
}
}
}
throw new IOException($"No file or directory found at {filepath}.");
}

private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
if (Directory.Exists(filepath))
{
return load(filepath, compile, options);
}
else
{
SavedMetadata metadata = new SavedMetadata();
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
if (File.Exists(path_to_metadata_pb))
{
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read));
}
else
{
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
}
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
}
}

if (metadata.Nodes is null || metadata.Nodes.Count == 0)
{
return Loader.load(path, options: options) as Model;
}
private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
{
SavedMetadata metadata;
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
if (File.Exists(path_to_metadata_pb))
{
using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read);
metadata = SavedMetadata.Parser.ParseFrom(stream);
}
else
{
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
}

var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);
if (metadata.Nodes is null || metadata.Nodes.Count == 0)
{
return Loader.load(path, options: options) as Model;
}

Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
nodes_to_load["root"] = (null, null);
foreach(var item in keras_loader.LoadedNodes)
{
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
}
var loaded = Loader.load_partial(path, nodes_to_load, options);
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);

keras_loader.finalize_objects();
keras_loader.del_tracking();
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
nodes_to_load["root"] = (null, null);
foreach(var item in keras_loader.LoadedNodes)
{
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
}
var loaded = Loader.load_partial(path, nodes_to_load, options);

var model = loaded["root"];
keras_loader.finalize_objects();
keras_loader.del_tracking();

if(model is Model && compile)
{
// TODO(Rinne): implement it.
}
var model = loaded["root"];

if (!tf.Context.executing_eagerly())
{
// TODO(Rinne): implement it.
}
if (model is Model && compile)
{
// TODO(Rinne): implement it.
}

return model;
if (!tf.Context.executing_eagerly())
{
// TODO(Rinne): implement it.
}

return model;
}
}

Loading…
Cancel
Save