Browse Source

Renmae to AssetResource.

tags/v0.100.5-BERT-load
Haiping Chen 2 years ago
parent
commit
bdf229acbb
7 changed files with 73 additions and 30 deletions
  1. +8
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  2. +0
    -11
      src/TensorFlowNET.Core/Trackables/Asset.cs
  3. +18
    -0
      src/TensorFlowNET.Core/Trackables/AssetResource.cs
  4. +4
    -3
      src/TensorFlowNET.Core/Trackables/RestoredResource.cs
  5. +14
    -3
      src/TensorFlowNET.Core/Trackables/TrackableConstant.cs
  6. +13
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  7. +16
    -13
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

+ 8
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -80,6 +80,14 @@ namespace Tensorflow
{ {
return np.array(tensor.IntVal.ToArray()).reshape(shape); return np.array(tensor.IntVal.ToArray()).reshape(shape);
} }
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
{
return np.array(tensor.Int64Val.ToArray()).reshape(shape);
}
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
{
return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
}
else if (tensor.Dtype == DataType.DtBool) else if (tensor.Dtype == DataType.DtBool)
{ {
return np.array(tensor.BoolVal.ToArray()).reshape(shape); return np.array(tensor.BoolVal.ToArray()).reshape(shape);


+ 0
- 11
src/TensorFlowNET.Core/Trackables/Asset.cs View File

@@ -1,11 +0,0 @@
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class Asset : Trackable
{
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
{
return (null, null);
}
}

+ 18
- 0
src/TensorFlowNET.Core/Trackables/AssetResource.cs View File

@@ -0,0 +1,18 @@
using Google.Protobuf.Collections;
using System.IO;
using Tensorflow.Train;

namespace Tensorflow.Trackables;

public class AssetResource : Trackable
{
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
string export_dir,
RepeatedField<AssetFileDef> asset_file_def,
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
{
var proto = object_proto.Asset;
var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename);
return (new AssetResource(), null);
}
}

+ 4
- 3
src/TensorFlowNET.Core/Trackables/RestoredResource.cs View File

@@ -1,12 +1,13 @@
using System.Runtime.CompilerServices;
using Google.Protobuf.Collections;
using Tensorflow.Train; using Tensorflow.Train;


namespace Tensorflow.Trackables; namespace Tensorflow.Trackables;


public class RestoredResource : TrackableResource public class RestoredResource : TrackableResource
{ {
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
{ {
return (null, null);
return (new RestoredResource(), null);
} }
} }

+ 14
- 3
src/TensorFlowNET.Core/Trackables/TrackableConstant.cs View File

@@ -1,11 +1,22 @@
using Tensorflow.Train;
using Google.Protobuf.Collections;
using Tensorflow.Train;


namespace Tensorflow.Trackables; namespace Tensorflow.Trackables;


public class TrackableConstant : Trackable public class TrackableConstant : Trackable
{ {
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
Tensor _constant;
public TrackableConstant(Tensor constant)
{ {
return (null, null);
_constant = constant;
}

public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
{
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
var ndarray = tensor_util.MakeNdarray(tensor_proto);
var imported_constant = constant_op.constant(ndarray);
return (new TrackableConstant(imported_constant), null);
} }
} }

+ 13
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel
{ {
public static class function_deserialization public static class function_deserialization
{ {
/// <summary>
/// Creates a `Function` from a `SavedFunction`.
/// </summary>
/// <param name="saved_concrete_function"></param>
/// <param name="concrete_functions"></param>
/// <returns></returns>
public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions)
{
var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec);
return null;
}

public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions) IDictionary<string, ConcreteFunction> concrete_functions)
{ {


+ 16
- 13
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -387,13 +387,6 @@ namespace Tensorflow
} }
else else
{ {
// skip the function and concrete function.
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
{
nodes[node_id] = null;
node_setters[node_id] = null;
continue;
}
var (node, setter) = _recreate(proto, node_id, nodes); var (node, setter) = _recreate(proto, node_id, nodes);
nodes[node_id] = node; nodes[node_id] = node;
node_setters[node_id] = setter; node_setters[node_id] = setter;
@@ -471,6 +464,11 @@ namespace Tensorflow
} }
} }


private void _setup_function_captures()
{
// TODO: implement it with concrete functions.
}

private void _setup_remaining_functions() private void _setup_remaining_functions()
{ {
// TODO: implement it with concrete functions. // TODO: implement it with concrete functions.
@@ -542,9 +540,9 @@ namespace Tensorflow


return proto.KindCase switch return proto.KindCase switch
{ {
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(),
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(),
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(),
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes),
SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes),
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes),
_ => _recreate_default(proto, node_id, dependencies) _ => _recreate_default(proto, node_id, dependencies)
}; };
} }
@@ -563,7 +561,8 @@ namespace Tensorflow
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException()
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
_ => throw new NotImplementedException()
}; };
} }


@@ -623,8 +622,12 @@ namespace Tensorflow
private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies) Dictionary<Maybe<string, int>, Trackable> dependencies)
{ {
throw new NotImplementedException();
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
var fn = function_deserialization.recreate_function(proto, null);
foreach (var name in proto.ConcreteFunctions)
{
_setup_function_captures();
}
return (fn, setattr);
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,


Loading…
Cancel
Save