Browse Source

add overload for Layer call function, be able to input array and return array.

v0.20-tensorflow2.3
Oceania2018 Haiping 5 years ago
parent
commit
f226ad704f
28 changed files with 199 additions and 98 deletions
  1. +17
    -0
      src/TensorFlowNET.Console/MemoryTestingCases.cs
  2. +3
    -0
      src/TensorFlowNET.Console/Program.cs
  3. +1
    -1
      src/TensorFlowNET.Core/APIs/c_api.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  5. +1
    -0
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  6. +4
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
  8. +37
    -3
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  14. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
  17. +43
    -6
      src/TensorFlowNET.Core/Layers/Layer.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  19. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  20. +3
    -3
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  22. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  23. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  24. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  25. +0
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  26. +67
    -64
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  27. +1
    -0
      src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
  28. +1
    -1
      test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj

+ 17
- 0
src/TensorFlowNET.Console/MemoryTestingCases.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using NumSharp;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -18,6 +19,22 @@ namespace Tensorflow
var tensor = tf.constant(3112.0f); var tensor = tf.constant(3112.0f);
} }
}; };

public Action<int> Constant2x3
=> (iterate) =>
{
var nd = np.array(new byte[,]
{
{1, 2, 3},
{4, 5, 6}
});
for (int i = 0; i < iterate; i++)
{
var tensor = tf.constant(nd);
var data = tensor.numpy();
}
};

public Action<int> Variable public Action<int> Variable
=> (iterate) => => (iterate) =>
{ {


+ 3
- 0
src/TensorFlowNET.Console/Program.cs View File

@@ -15,6 +15,9 @@ namespace Tensorflow


int batchSize = 1000; int batchSize = 1000;


// explaination of constant
mm.Execute(10, 100 * batchSize, cases.Constant2x3);

// 1 million float tensor 68M. // 1 million float tensor 68M.
mm.Execute(10, 100 * batchSize, cases.Constant); mm.Execute(10, 100 * batchSize, cases.Constant);




+ 1
- 1
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public partial class c_api public partial class c_api
{ {
public const string TensorFlowLibName = "tensorflow";
public const string TensorFlowLibName = @"C:\Users\haipi\Documents\Projects\tensorflow\bazel-bin\tensorflow\tensorflow";


public static string StringPiece(IntPtr handle) public static string StringPiece(IntPtr handle)
{ {


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -70,8 +70,8 @@ namespace Tensorflow.Eager


protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
base.DisposeUnmanagedResources(handle);
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
c_api.TF_DeleteTensor(_handle);
} }
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -311,6 +311,7 @@ namespace Tensorflow
while (queue.Count > 0) while (queue.Count > 0)
{ {
var op = queue.Dequeue(); var op = queue.Dequeue();

if (reached_ops.Contains(op)) if (reached_ops.Contains(op))
{ {
between_ops.Add(op); between_ops.Add(op);


+ 4
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -278,7 +278,11 @@ namespace Tensorflow
// after removing the trailing '/'. // after removing the trailing '/'.
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, attrs: attrs); var node_def = ops._NodeDef(op_type, name, attrs: attrs);
if (name == "rnn/while/basic_rnn_cell/MatMul"
|| name == "rnn/while/basic_rnn_cell/MatMul/Enter")
{


}
var input_ops = inputs.Select(x => x.op).ToArray(); var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops); var control_inputs = _control_dependencies_for_inputs(input_ops);




+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Flatten.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first"; _channels_first = args.DataFormat == "channels_first";
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
if (_channels_first) if (_channels_first)
{ {


+ 37
- 3
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="input"></param> /// <param name="input"></param>
/// <param name="is_training"></param> /// <param name="is_training"></param>
/// <returns></returns> /// <returns></returns>
public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null)
public Tensor Apply(Tensor inputs, bool is_training = false)
{ {
Tensor outputs = null; Tensor outputs = null;


@@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine
if (!built) if (!built)
MaybeBuild(inputs); MaybeBuild(inputs);


outputs = call(inputs, is_training: is_training, state: state);
outputs = call(inputs, is_training: is_training);


outputs = _set_connectivity_metadata_(inputs, outputs); outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs); _handle_activity_regularization(inputs, outputs);
@@ -161,6 +161,35 @@ namespace Tensorflow.Keras.Engine
return outputs; return outputs;
} }


public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false)
{
Tensor[] outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = name;
else
nameScope = _name_scope();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs[0]);

outputs = call(inputs, is_training: is_training, state: state);
});

return outputs;
}

private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
{ {
/*var returnOutputs = new List<Tensor>(); /*var returnOutputs = new List<Tensor>();
@@ -200,7 +229,12 @@ namespace Tensorflow.Keras.Engine
return null; return null;
} }


protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected virtual Tensor call(Tensor inputs, bool is_training = false)
{
throw new NotImplementedException("");
}

protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
Tensor outputs = null; Tensor outputs = null;




+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool training = false)
{ {
var outputs = _convolution_op.__call__(inputs, kernel); var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias) if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool training = false)
{ {
Tensor outputs = null; Tensor outputs = null;
var rank = inputs.rank; var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dropout.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args; this.args = args;
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
var output = tf_utils.smart_cond(is_training, var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs, () => tf.nn.dropout(inputs,


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
var dtype = inputs.dtype; var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64) if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/LSTM.cs View File

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray(); .ToArray();
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
return base.call(inputs, is_training, state);
return base.call(inputs, is_training);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4); input_spec = new InputSpec(ndim: 4);
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
int[] pool_shape; int[] pool_shape;
int[] strides; int[] strides;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args; this.args = args;
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
scale = math_ops.cast(args.Scale, args.DType); scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType); offset = math_ops.cast(args.Offset, args.DType);


+ 43
- 6
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -61,9 +61,8 @@ namespace Tensorflow.Layers
return (results[0], results[1]); return (results[0], results[1]);
} }


public Tensor[] __call__(Tensor inputs,
public Tensor __call__(Tensor inputs,
Tensor training = null, Tensor training = null,
Tensor state = null,
VariableScope scope = null) VariableScope scope = null)
{ {
_set_scope(scope); _set_scope(scope);
@@ -88,16 +87,54 @@ namespace Tensorflow.Layers
{ {
_current_scope = scope2; _current_scope = scope2;
// Actually call layer // Actually call layer
outputs = base.Apply(inputs,
is_training: training == null ? false : false,
state: state);
outputs = base.Apply(inputs[0],
is_training: training == null ? false : false);
});


// Update global default collections.
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

return outputs;
}

public Tensor[] __call__(Tensor[] inputs,
Tensor state = null,
Tensor training = null,
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(inputs, graph: _graph);

variable_scope scope_context_manager = null;
if (built)
{
scope_context_manager = tf.variable_scope(_scope,
reuse: true,
auxiliary_name_scope: false);
}
else
{
scope_context_manager = tf.variable_scope(_scope,
reuse: _reuse,
auxiliary_name_scope: false);
}

Tensor[] outputs = null;
tf_with(scope_context_manager, scope2 =>
{
_current_scope = scope2;
// Actually call layer
outputs = base.Apply(inputs,
state,
is_training: training == null ? false : false);
}); });




// Update global default collections. // Update global default collections.
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });


return new Tensor[] { outputs };
return outputs;
} }


protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)


+ 1
- 1
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -326,7 +326,7 @@ namespace Tensorflow.Operations


protected override void _AddOpInternal(Operation op) protected override void _AddOpInternal(Operation op)
{ {
if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad")
if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape")
{ {


} }


+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
built = true; built = true;
} }


public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
public Tensor __call__(Tensor inputs, LSTMStateTuple state)
{ {
_state = state; _state = state;
return base.__call__(inputs); return base.__call__(inputs);
@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param> /// <param name="training"></param>
/// <param name="state"></param> /// <param name="state"></param>
/// <returns></returns> /// <returns></returns>
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor call(Tensor inputs, bool is_training = false)
{ {
var one = constant_op.constant(1, dtype: dtypes.int32); var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency. // Parameters of gates are concatenated into one multiply for efficiency.


+ 3
- 3
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,14 +67,14 @@ namespace Tensorflow
built = true; built = true;
} }


protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
{ {
// Most basic RNN: output = new_state = act(W * input + U * state + B). // Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);
var concat = array_ops.concat(new[] { inputs[0], state }, 1);
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
var output = _activation(gate_inputs, null); var output = _activation(gate_inputs, null);
return output;
return new[] { output, output };
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -364,7 +364,7 @@ namespace Tensorflow.Operations
if (sequence_length != null) if (sequence_length != null)
throw new NotImplementedException("sequence_length != null"); throw new NotImplementedException("sequence_length != null");
else else
outputs = cell.__call__(input_t_t, state: state1);
outputs = cell.__call__(new[] { input_t_t }, state: state1);


var (output, new_state) = (outputs[0], outputs[1]); var (output, new_state) = (outputs[0], outputs[1]);
// Keras cells always wrap state as list, even if it's a single tensor. // Keras cells always wrap state as list, even if it's a single tensor.


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -326,7 +326,7 @@ namespace Tensorflow
// the updated inputs are reloaded from the c_api // the updated inputs are reloaded from the c_api
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
// c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
//var updated_inputs = inputs; //var updated_inputs = inputs;
tf.Status.Check(); tf.Status.Check();
} }


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow> <TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.20.0</Version>
<Version>0.20.1</Version>
<LangVersion>8.0</LangVersion> <LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
@@ -19,13 +19,13 @@
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.20.0.0</AssemblyVersion>
<AssemblyVersion>0.20.1.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.


* Eager Mode is added finally. * Eager Mode is added finally.
* tf.keras is partially working. * tf.keras is partially working.
* tf.data is added.</PackageReleaseNotes> * tf.data is added.</PackageReleaseNotes>
<FileVersion>0.20.0.0</FileVersion>
<FileVersion>0.20.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -50,6 +50,8 @@ namespace Tensorflow
/// </summary> /// </summary>
public AllocationType AllocationType { get; protected set; } public AllocationType AllocationType { get; protected set; }


public IntPtr TensorDataPointer => TF_TensorData(_handle);

/// <summary> /// <summary>
/// Create a Tensor object from an existing TF handle /// Create a Tensor object from an existing TF handle
/// </summary> /// </summary>


+ 0
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -261,7 +261,6 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
c_api.TF_DeleteTensor(handle); c_api.TF_DeleteTensor(handle);

if (AllocationHandle == null) if (AllocationHandle == null)
return; return;




+ 67
- 64
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -88,80 +88,83 @@ namespace Tensorflow


if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
_in_graph_mode = !tf.Context.executing_eagerly(); _in_graph_mode = !tf.Context.executing_eagerly();
tf_with(ops.name_scope(name, "Variable"), scope =>
tf_with(ops.init_scope2(), delegate
{ {
name = scope;
var handle_name = ops.name_from_scope_name(name);
string unique_id = "";
string shared_name = "";

if (_in_graph_mode)
{
shared_name = handle_name;
unique_id = shared_name;
}
else
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
{ {
unique_id = $"{handle_name}_{ops.uid()}";
shared_name = tf.Context.shared_name();
}

var attr = new AttrValue();
attr.List = new AttrValue.Types.ListValue();
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}"));
tf_with(ops.name_scope("Initializer"), delegate
{
if (initial_value.GetType().GetInterface("IInitializer") != null)
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
name = scope;
var handle_name = ops.name_from_scope_name(name);
string unique_id = "";
string shared_name = "";

if (_in_graph_mode)
{
shared_name = handle_name;
unique_id = shared_name;
}
else else
{ {
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
initial_value = ops.convert_to_tensor(value,
name: "initial_value",
dtype: dtype);
unique_id = $"{handle_name}_{ops.uid()}";
shared_name = tf.Context.shared_name();
} }
});
_shape = shape ?? (initial_value as Tensor).TensorShape;
_initial_value = initial_value as Tensor;


var attr = new AttrValue();
attr.List = new AttrValue.Types.ListValue();
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}"));
tf_with(ops.name_scope("Initializer"), delegate
{
if (initial_value.GetType().GetInterface("IInitializer") != null)
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
else
{
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
initial_value = ops.convert_to_tensor(value,
name: "initial_value",
dtype: dtype);
}
});
_shape = shape ?? (initial_value as Tensor).TensorShape;
_initial_value = initial_value as Tensor;



if (_in_graph_mode)
{
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;


if (_in_graph_mode)
{
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;
ops.colocate_with(initializer_op);


ops.colocate_with(initializer_op);
_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
}
else
{
handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value,
shape: _shape,
shared_name: shared_name,
name: name,
graph_mode: _in_graph_mode);

gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
is_initialized_op = null;
initializer_op = null;
_graph_element = null;
_dtype = _initial_value.dtype.as_base_dtype();
initial_value = _in_graph_mode ? initial_value : null;
}


_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
}
else
{
handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value,
shape: _shape,
shared_name: shared_name,
name: name,
graph_mode: _in_graph_mode);

gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
is_initialized_op = null;
initializer_op = null;
_graph_element = null;
_dtype = _initial_value.dtype.as_base_dtype();
initial_value = _in_graph_mode ? initial_value : null;
}

base.__init__(trainable: trainable,
handle: handle,
name: name,
unique_id: unique_id,
handle_name: handle_name);
base.__init__(trainable: trainable,
handle: handle,
name: name,
unique_id: unique_id,
handle_name: handle_name);
});
}); });
} }




+ 1
- 0
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -30,6 +30,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" /> <PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
<PackageReference Include="TensorFlow.NET" Version="0.20.0" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 1
- 1
test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj View File

@@ -43,7 +43,7 @@


<ItemGroup> <ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" /> <PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> <PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />


Loading…
Cancel
Save