@@ -80,6 +80,14 @@ namespace Tensorflow | |||||
{ | { | ||||
return np.array(tensor.IntVal.ToArray()).reshape(shape); | return np.array(tensor.IntVal.ToArray()).reshape(shape); | ||||
} | } | ||||
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) | |||||
{ | |||||
return np.array(tensor.Int64Val.ToArray()).reshape(shape); | |||||
} | |||||
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) | |||||
{ | |||||
return np.array(tensor.Uint64Val.ToArray()).reshape(shape); | |||||
} | |||||
else if (tensor.Dtype == DataType.DtBool) | else if (tensor.Dtype == DataType.DtBool) | ||||
{ | { | ||||
return np.array(tensor.BoolVal.ToArray()).reshape(shape); | return np.array(tensor.BoolVal.ToArray()).reshape(shape); | ||||
@@ -1,11 +0,0 @@ | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Trackables; | |||||
public class Asset : Trackable | |||||
{ | |||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
{ | |||||
return (null, null); | |||||
} | |||||
} |
@@ -0,0 +1,18 @@ | |||||
using Google.Protobuf.Collections; | |||||
using System.IO; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Trackables; | |||||
public class AssetResource : Trackable | |||||
{ | |||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
string export_dir, | |||||
RepeatedField<AssetFileDef> asset_file_def, | |||||
Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
{ | |||||
var proto = object_proto.Asset; | |||||
var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename); | |||||
return (new AssetResource(), null); | |||||
} | |||||
} |
@@ -1,12 +1,13 @@ | |||||
using System.Runtime.CompilerServices; | |||||
using Google.Protobuf.Collections; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
namespace Tensorflow.Trackables; | namespace Tensorflow.Trackables; | ||||
public class RestoredResource : TrackableResource | public class RestoredResource : TrackableResource | ||||
{ | { | ||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
{ | { | ||||
return (null, null); | |||||
return (new RestoredResource(), null); | |||||
} | } | ||||
} | } |
@@ -1,11 +1,22 @@ | |||||
using Tensorflow.Train; | |||||
using Google.Protobuf.Collections; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Trackables; | namespace Tensorflow.Trackables; | ||||
public class TrackableConstant : Trackable | public class TrackableConstant : Trackable | ||||
{ | { | ||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
Tensor _constant; | |||||
public TrackableConstant(Tensor constant) | |||||
{ | { | ||||
return (null, null); | |||||
_constant = constant; | |||||
} | |||||
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
{ | |||||
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; | |||||
var ndarray = tensor_util.MakeNdarray(tensor_proto); | |||||
var imported_constant = constant_op.constant(ndarray); | |||||
return (new TrackableConstant(imported_constant), null); | |||||
} | } | ||||
} | } |
@@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
{ | { | ||||
public static class function_deserialization | public static class function_deserialization | ||||
{ | { | ||||
/// <summary> | |||||
/// Creates a `Function` from a `SavedFunction`. | |||||
/// </summary> | |||||
/// <param name="saved_concrete_function"></param> | |||||
/// <param name="concrete_functions"></param> | |||||
/// <returns></returns> | |||||
public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, | |||||
IDictionary<string, ConcreteFunction> concrete_functions) | |||||
{ | |||||
var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); | |||||
return null; | |||||
} | |||||
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | ||||
IDictionary<string, ConcreteFunction> concrete_functions) | IDictionary<string, ConcreteFunction> concrete_functions) | ||||
{ | { | ||||
@@ -387,13 +387,6 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
// skip the function and concrete function. | |||||
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||||
{ | |||||
nodes[node_id] = null; | |||||
node_setters[node_id] = null; | |||||
continue; | |||||
} | |||||
var (node, setter) = _recreate(proto, node_id, nodes); | var (node, setter) = _recreate(proto, node_id, nodes); | ||||
nodes[node_id] = node; | nodes[node_id] = node; | ||||
node_setters[node_id] = setter; | node_setters[node_id] = setter; | ||||
@@ -471,6 +464,11 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private void _setup_function_captures() | |||||
{ | |||||
// TODO: implement it with concrete functions. | |||||
} | |||||
private void _setup_remaining_functions() | private void _setup_remaining_functions() | ||||
{ | { | ||||
// TODO: implement it with concrete functions. | // TODO: implement it with concrete functions. | ||||
@@ -542,9 +540,9 @@ namespace Tensorflow | |||||
return proto.KindCase switch | return proto.KindCase switch | ||||
{ | { | ||||
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), | |||||
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), | |||||
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), | |||||
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes), | |||||
SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes), | |||||
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes), | |||||
_ => _recreate_default(proto, node_id, dependencies) | _ => _recreate_default(proto, node_id, dependencies) | ||||
}; | }; | ||||
} | } | ||||
@@ -563,7 +561,8 @@ namespace Tensorflow | |||||
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | ||||
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | ||||
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | ||||
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() | |||||
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||||
_ => throw new NotImplementedException() | |||||
}; | }; | ||||
} | } | ||||
@@ -623,8 +622,12 @@ namespace Tensorflow | |||||
private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | ||||
Dictionary<Maybe<string, int>, Trackable> dependencies) | Dictionary<Maybe<string, int>, Trackable> dependencies) | ||||
{ | { | ||||
throw new NotImplementedException(); | |||||
//var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||||
var fn = function_deserialization.recreate_function(proto, null); | |||||
foreach (var name in proto.ConcreteFunctions) | |||||
{ | |||||
_setup_function_captures(); | |||||
} | |||||
return (fn, setattr); | |||||
} | } | ||||
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | ||||