Browse Source

Add missing trackable class but not implemented.

tags/v0.100.5-BERT-load
Haiping Chen 2 years ago
parent
commit
8550dccc56
8 changed files with 120 additions and 4 deletions
  1. +40
    -0
      src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs
  2. +12
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Trackables/Asset.cs
  4. +7
    -0
      src/TensorFlowNET.Core/Trackables/CapturableResource.cs
  5. +12
    -0
      src/TensorFlowNET.Core/Trackables/RestoredResource.cs
  6. +11
    -0
      src/TensorFlowNET.Core/Trackables/TrackableConstant.cs
  7. +5
    -0
      src/TensorFlowNET.Core/Trackables/TrackableResource.cs
  8. +22
    -4
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

+ 40
- 0
src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using Tensorflow.Util;

namespace Tensorflow;

public sealed class SafeOperationHandle : SafeTensorflowHandle
{
private SafeOperationHandle()
{
}

public SafeOperationHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
var status = new Status();
// c_api.TF_CloseSession(handle, status);
c_api.TF_DeleteSession(handle, status);
SetHandle(IntPtr.Zero);
return true;
}
}

+ 12
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -65,6 +65,18 @@ namespace Tensorflow
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator(); => GetEnumerator();


public string[] StringData()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData();
}

public string StringData(int index)
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData(index);
}

public NDArray numpy() public NDArray numpy()
{ {
EnsureSingleTensor(this, "nnumpy"); EnsureSingleTensor(this, "nnumpy");


+ 11
- 0
src/TensorFlowNET.Core/Trackables/Asset.cs View File

@@ -0,0 +1,11 @@
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class Asset : Trackable
{
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
{
return (null, null);
}
}

+ 7
- 0
src/TensorFlowNET.Core/Trackables/CapturableResource.cs View File

@@ -0,0 +1,7 @@
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class CapturableResource : Trackable
{
}

+ 12
- 0
src/TensorFlowNET.Core/Trackables/RestoredResource.cs View File

@@ -0,0 +1,12 @@
using System.Runtime.CompilerServices;
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class RestoredResource : TrackableResource
{
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
{
return (null, null);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Trackables/TrackableConstant.cs View File

@@ -0,0 +1,11 @@
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class TrackableConstant : Trackable
{
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
{
return (null, null);
}
}

+ 5
- 0
src/TensorFlowNET.Core/Trackables/TrackableResource.cs View File

@@ -0,0 +1,5 @@
namespace Tensorflow.Trackables;

public class TrackableResource : CapturableResource
{
}

+ 22
- 4
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -13,6 +13,7 @@ using System.Runtime.CompilerServices;
using Tensorflow.Variables; using Tensorflow.Variables;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Trackables;


namespace Tensorflow namespace Tensorflow
{ {
@@ -51,9 +52,13 @@ namespace Tensorflow
_node_filters = filters; _node_filters = filters;
_node_path_to_id = _convert_node_paths_to_ints(); _node_path_to_id = _convert_node_paths_to_ints();
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); _loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
foreach(var filter in filters)

if (filters != null)
{ {
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
foreach (var filter in filters)
{
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
}
} }


_filtered_nodes = _retrieve_all_filtered_nodes(); _filtered_nodes = _retrieve_all_filtered_nodes();
@@ -535,7 +540,13 @@ namespace Tensorflow
dependencies[item.Key] = nodes[item.Value]; dependencies[item.Key] = nodes[item.Value];
} }


return _recreate_default(proto, node_id, dependencies);
return proto.KindCase switch
{
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(),
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(),
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(),
_ => _recreate_default(proto, node_id, dependencies)
};
} }


/// <summary> /// <summary>
@@ -549,7 +560,7 @@ namespace Tensorflow
return proto.KindCase switch return proto.KindCase switch
{ {
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException()
@@ -609,6 +620,13 @@ namespace Tensorflow
} }
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
{
throw new NotImplementedException();
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
}

private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies) Dictionary<Maybe<string, int>, Trackable> dependencies)
{ {


Loading…
Cancel
Save