
Add an overload for the Layer call function so it can take an array of tensors as input and return an array of tensors.

v0.20-tensorflow2.3
Oceania2018 Haiping, 5 years ago
commit f226ad704f
28 changed files with 199 additions and 98 deletions
  1. src/TensorFlowNET.Console/MemoryTestingCases.cs (+17, -0)
  2. src/TensorFlowNET.Console/Program.cs (+3, -0)
  3. src/TensorFlowNET.Core/APIs/c_api.cs (+1, -1)
  4. src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs (+1, -1)
  5. src/TensorFlowNET.Core/Gradients/gradients_util.cs (+1, -0)
  6. src/TensorFlowNET.Core/Graphs/Graph.cs (+4, -0)
  7. src/TensorFlowNET.Core/Keras/Engine/Flatten.cs (+1, -1)
  8. src/TensorFlowNET.Core/Keras/Engine/Layer.cs (+37, -3)
  9. src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs (+1, -1)
  10. src/TensorFlowNET.Core/Keras/Layers/Conv.cs (+1, -1)
  11. src/TensorFlowNET.Core/Keras/Layers/Dense.cs (+1, -1)
  12. src/TensorFlowNET.Core/Keras/Layers/Dropout.cs (+1, -1)
  13. src/TensorFlowNET.Core/Keras/Layers/Embedding.cs (+1, -1)
  14. src/TensorFlowNET.Core/Keras/Layers/LSTM.cs (+2, -2)
  15. src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs (+1, -1)
  16. src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs (+1, -1)
  17. src/TensorFlowNET.Core/Layers/Layer.cs (+43, -6)
  18. src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs (+1, -1)
  19. src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs (+2, -2)
  20. src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs (+3, -3)
  21. src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+1, -1)
  22. src/TensorFlowNET.Core/Operations/Operation.cs (+1, -1)
  23. src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+3, -3)
  24. src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+2, -0)
  25. src/TensorFlowNET.Core/Tensors/Tensor.cs (+0, -1)
  26. src/TensorFlowNET.Core/Variables/ResourceVariable.cs (+67, -64)
  27. src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj (+1, -0)
  28. test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj (+1, -1)
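
Taken together, the changes below let a layer or RNN cell be called with a tensor array and return a tensor array, so recurrent state can be threaded through the call instead of a separate state parameter. A minimal usage sketch based only on the signatures added in this commit (the variable names are illustrative, not taken from the diff):

// Hypothetical sketch: drive a cell through the new array overload.
// The input is wrapped in an array, the previous state is passed separately,
// and the result is an array holding (output, new_state), exactly as rnn.cs consumes it below.
Tensor[] outputs = cell.__call__(new[] { input_t }, state: prev_state);
var (output, new_state) = (outputs[0], outputs[1]);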

src/TensorFlowNET.Console/MemoryTestingCases.cs (+17, -0)

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -18,6 +19,22 @@ namespace Tensorflow
var tensor = tf.constant(3112.0f);
}
};

public Action<int> Constant2x3
=> (iterate) =>
{
var nd = np.array(new byte[,]
{
{1, 2, 3},
{4, 5, 6}
});
for (int i = 0; i < iterate; i++)
{
var tensor = tf.constant(nd);
var data = tensor.numpy();
}
};

public Action<int> Variable
=> (iterate) =>
{


src/TensorFlowNET.Console/Program.cs (+3, -0)

@@ -15,6 +15,9 @@ namespace Tensorflow

int batchSize = 1000;

// explaination of constant
mm.Execute(10, 100 * batchSize, cases.Constant2x3);

// 1 million float tensor 68M.
mm.Execute(10, 100 * batchSize, cases.Constant);
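
The MemoryMonitor type that drives these cases is not part of this diff, so the exact behaviour of mm.Execute is an assumption; a rough sketch of what such a driver could look like, purely for illustration:

// Hypothetical driver (not from this commit): run the case `epochs` times,
// hand `iterations` to the Action<int>, and report process memory after each epoch.
void Execute(int epochs, int iterations, Action<int> testCase)
{
    for (int epoch = 0; epoch < epochs; epoch++)
    {
        testCase(iterations);
        GC.Collect();
        var mb = System.Diagnostics.Process.GetCurrentProcess().PrivateMemorySize64 / 1024 / 1024;
        Console.WriteLine($"epoch {epoch}: {mb} MB");
    }
}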



src/TensorFlowNET.Core/APIs/c_api.cs (+1, -1)

@@ -43,7 +43,7 @@ namespace Tensorflow
/// </summary>
public partial class c_api
{
- public const string TensorFlowLibName = "tensorflow";
+ public const string TensorFlowLibName = @"C:\Users\haipi\Documents\Projects\tensorflow\bazel-bin\tensorflow\tensorflow";

public static string StringPiece(IntPtr handle)
{


src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs (+1, -1)

@@ -70,8 +70,8 @@ namespace Tensorflow.Eager

protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
c_api.TF_DeleteTensor(_handle);
}
}
}

src/TensorFlowNET.Core/Gradients/gradients_util.cs (+1, -0)

@@ -311,6 +311,7 @@ namespace Tensorflow
while (queue.Count > 0)
{
var op = queue.Dequeue();

if (reached_ops.Contains(op))
{
between_ops.Add(op);


src/TensorFlowNET.Core/Graphs/Graph.cs (+4, -0)

@@ -278,7 +278,11 @@ namespace Tensorflow
// after removing the trailing '/'.
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, attrs: attrs);
if (name == "rnn/while/basic_rnn_cell/MatMul"
|| name == "rnn/while/basic_rnn_cell/MatMul/Enter")
{

}
var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);



src/TensorFlowNET.Core/Keras/Engine/Flatten.cs (+1, -1)

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first";
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
if (_channels_first)
{


src/TensorFlowNET.Core/Keras/Engine/Layer.cs (+37, -3)

@@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="input"></param>
/// <param name="is_training"></param>
/// <returns></returns>
- public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null)
+ public Tensor Apply(Tensor inputs, bool is_training = false)
{
Tensor outputs = null;

@@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

- outputs = call(inputs, is_training: is_training, state: state);
+ outputs = call(inputs, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
@@ -161,6 +161,35 @@ namespace Tensorflow.Keras.Engine
return outputs;
}

public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false)
{
Tensor[] outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = name;
else
nameScope = _name_scope();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs[0]);

outputs = call(inputs, is_training: is_training, state: state);
});

return outputs;
}

private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
{
/*var returnOutputs = new List<Tensor>();
@@ -200,7 +229,12 @@ namespace Tensorflow.Keras.Engine
return null;
}

- protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected virtual Tensor call(Tensor inputs, bool is_training = false)
{
throw new NotImplementedException("");
}

protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
{
throw new NotImplementedException("");
}
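
With the two virtual call overloads above, a stateful layer overrides the array form instead of taking a state parameter on the single-tensor form. A minimal hypothetical sketch of such an override (the computation and the `kernel` field are illustrative only; the real pattern is the BasicRNNCell change further down):

// Hypothetical override in a custom cell: inputs[0] is the input tensor,
// state is the previous state; return [output, new_state].
protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
{
    var concat = array_ops.concat(new[] { inputs[0], state }, 1);  // mirrors BasicRNNCell
    var output = math_ops.matmul(concat, kernel);                  // `kernel` assumed built elsewhere
    return new[] { output, output };                               // the output doubles as the new state
}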


src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs (+1, -1)

@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
Tensor outputs = null;



src/TensorFlowNET.Core/Keras/Layers/Conv.cs (+1, -1)

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

- protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool training = false)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)


src/TensorFlowNET.Core/Keras/Layers/Dense.cs (+1, -1)

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

- protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool training = false)
{
Tensor outputs = null;
var rank = inputs.rank;


src/TensorFlowNET.Core/Keras/Layers/Dropout.cs (+1, -1)

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs,


src/TensorFlowNET.Core/Keras/Layers/Embedding.cs (+1, -1)

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


src/TensorFlowNET.Core/Keras/Layers/LSTM.cs (+2, -2)

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray();
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
- return base.call(inputs, is_training, state);
+ return base.call(inputs, is_training);
}
}
}

src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs (+1, -1)

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
int[] pool_shape;
int[] strides;


src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs (+1, -1)

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);


src/TensorFlowNET.Core/Layers/Layer.cs (+43, -6)

@@ -61,9 +61,8 @@ namespace Tensorflow.Layers
return (results[0], results[1]);
}

- public Tensor[] __call__(Tensor inputs,
+ public Tensor __call__(Tensor inputs,
Tensor training = null,
- Tensor state = null,
VariableScope scope = null)
{
_set_scope(scope);
@@ -88,16 +87,54 @@ namespace Tensorflow.Layers
{
_current_scope = scope2;
// Actually call layer
- outputs = base.Apply(inputs,
- is_training: training == null ? false : false,
- state: state);
+ outputs = base.Apply(inputs[0],
+ is_training: training == null ? false : false);
});


// Update global default collections.
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

return outputs;
}

public Tensor[] __call__(Tensor[] inputs,
Tensor state = null,
Tensor training = null,
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(inputs, graph: _graph);

variable_scope scope_context_manager = null;
if (built)
{
scope_context_manager = tf.variable_scope(_scope,
reuse: true,
auxiliary_name_scope: false);
}
else
{
scope_context_manager = tf.variable_scope(_scope,
reuse: _reuse,
auxiliary_name_scope: false);
}

Tensor[] outputs = null;
tf_with(scope_context_manager, scope2 =>
{
_current_scope = scope2;
// Actually call layer
outputs = base.Apply(inputs,
state,
is_training: training == null ? false : false);
});


// Update global default collections.
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

- return new Tensor[] { outputs };
+ return outputs;
}

protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)


src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs (+1, -1)

@@ -326,7 +326,7 @@ namespace Tensorflow.Operations

protected override void _AddOpInternal(Operation op)
{
- if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad")
+ if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape")
{

}


src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs (+2, -2)

@@ -61,7 +61,7 @@ namespace Tensorflow
built = true;
}

- public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
+ public Tensor __call__(Tensor inputs, LSTMStateTuple state)
{
_state = state;
return base.__call__(inputs);
@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param>
/// <param name="state"></param>
/// <returns></returns>
- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor call(Tensor inputs, bool is_training = false)
{
var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency.


src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs (+3, -3)

@@ -67,14 +67,14 @@ namespace Tensorflow
built = true;
}

- protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
+ protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
- var concat = array_ops.concat(new[] { inputs, state }, 1);
+ var concat = array_ops.concat(new[] { inputs[0], state }, 1);
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
var output = _activation(gate_inputs, null);
- return output;
+ return new[] { output, output };
}
}
}

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+1, -1)

@@ -364,7 +364,7 @@ namespace Tensorflow.Operations
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
- outputs = cell.__call__(input_t_t, state: state1);
+ outputs = cell.__call__(new[] { input_t_t }, state: state1);

var (output, new_state) = (outputs[0], outputs[1]);
// Keras cells always wrap state as list, even if it's a single tensor.


src/TensorFlowNET.Core/Operations/Operation.cs (+1, -1)

@@ -326,7 +326,7 @@ namespace Tensorflow
// the updated inputs are reloaded from the c_api
lock (Locks.ProcessWide)
{
- // c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
+ c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
//var updated_inputs = inputs;
tf.Status.Check();
}


src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+3, -3)

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
- <Version>0.20.0</Version>
+ <Version>0.20.1</Version>
<LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
@@ -19,13 +19,13 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
- <AssemblyVersion>0.20.0.0</AssemblyVersion>
+ <AssemblyVersion>0.20.1.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.

* Eager Mode is added finally.
* tf.keras is partially working.
* tf.data is added.</PackageReleaseNotes>
- <FileVersion>0.20.0.0</FileVersion>
+ <FileVersion>0.20.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+2, -0)

@@ -50,6 +50,8 @@ namespace Tensorflow
/// </summary>
public AllocationType AllocationType { get; protected set; }

public IntPtr TensorDataPointer => TF_TensorData(_handle);

/// <summary>
/// Create a Tensor object from an existing TF handle
/// </summary>
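
The new TensorDataPointer property exposes the raw TF_TensorData buffer of a tensor. A small sketch of one way it could be used (not part of this commit; assumes the tensor stays alive for the duration of the copy):

// Copy a tensor's native buffer into managed memory via the new property.
var t = tf.constant(new float[] { 1f, 2f, 3f });
var bytes = new byte[3 * sizeof(float)];
System.Runtime.InteropServices.Marshal.Copy(t.TensorDataPointer, bytes, 0, bytes.Length);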


src/TensorFlowNET.Core/Tensors/Tensor.cs (+0, -1)

@@ -261,7 +261,6 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteTensor(handle);

if (AllocationHandle == null)
return;



src/TensorFlowNET.Core/Variables/ResourceVariable.cs (+67, -64)

@@ -88,80 +88,83 @@ namespace Tensorflow

if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
_in_graph_mode = !tf.Context.executing_eagerly();
tf_with(ops.name_scope(name, "Variable"), scope =>
tf_with(ops.init_scope2(), delegate
{
name = scope;
var handle_name = ops.name_from_scope_name(name);
string unique_id = "";
string shared_name = "";

if (_in_graph_mode)
{
shared_name = handle_name;
unique_id = shared_name;
}
else
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
{
unique_id = $"{handle_name}_{ops.uid()}";
shared_name = tf.Context.shared_name();
}

var attr = new AttrValue();
attr.List = new AttrValue.Types.ListValue();
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}"));
tf_with(ops.name_scope("Initializer"), delegate
{
if (initial_value.GetType().GetInterface("IInitializer") != null)
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
name = scope;
var handle_name = ops.name_from_scope_name(name);
string unique_id = "";
string shared_name = "";

if (_in_graph_mode)
{
shared_name = handle_name;
unique_id = shared_name;
}
else
{
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
initial_value = ops.convert_to_tensor(value,
name: "initial_value",
dtype: dtype);
unique_id = $"{handle_name}_{ops.uid()}";
shared_name = tf.Context.shared_name();
}
});
_shape = shape ?? (initial_value as Tensor).TensorShape;
_initial_value = initial_value as Tensor;

var attr = new AttrValue();
attr.List = new AttrValue.Types.ListValue();
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}"));
tf_with(ops.name_scope("Initializer"), delegate
{
if (initial_value.GetType().GetInterface("IInitializer") != null)
initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
else
{
var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
initial_value = ops.convert_to_tensor(value,
name: "initial_value",
dtype: dtype);
}
});
_shape = shape ?? (initial_value as Tensor).TensorShape;
_initial_value = initial_value as Tensor;



if (_in_graph_mode)
{
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;

if (_in_graph_mode)
{
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;
ops.colocate_with(initializer_op);

ops.colocate_with(initializer_op);
_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
}
else
{
handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value,
shape: _shape,
shared_name: shared_name,
name: name,
graph_mode: _in_graph_mode);

gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
is_initialized_op = null;
initializer_op = null;
_graph_element = null;
_dtype = _initial_value.dtype.as_base_dtype();
initial_value = _in_graph_mode ? initial_value : null;
}

_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
}
else
{
handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value,
shape: _shape,
shared_name: shared_name,
name: name,
graph_mode: _in_graph_mode);

gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
is_initialized_op = null;
initializer_op = null;
_graph_element = null;
_dtype = _initial_value.dtype.as_base_dtype();
initial_value = _in_graph_mode ? initial_value : null;
}

base.__init__(trainable: trainable,
handle: handle,
name: name,
unique_id: unique_id,
handle_name: handle_name);
base.__init__(trainable: trainable,
handle: handle,
name: name,
unique_id: unique_id,
handle_name: handle_name);
});
});
}
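
This refactor changes how a ResourceVariable is constructed, nesting the name scope and initializer logic inside ops.init_scope2(). A short sketch of code that would exercise this path (assuming the usual tf.Variable factory, which is not part of this diff):

// Creating a variable goes through the refactored constructor above,
// in both eager and graph mode.
var v = tf.Variable(np.array(new float[] { 1f, 2f }), name: "v");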



src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj (+1, -0)

@@ -30,6 +30,7 @@
<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
<PackageReference Include="TensorFlow.NET" Version="0.20.0" />
</ItemGroup>

<ItemGroup>


test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj (+1, -1)

@@ -43,7 +43,7 @@

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />

