Browse Source

Add gradients for the GatherV2 and MaxPool ops.

tags/v0.9
Oceania2018 6 years ago
parent
commit
8e2e31cc67
15 changed files with 224 additions and 61 deletions
  1. +6
    -0
      TensorFlow.NET.sln
  2. +1
    -0
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  3. +13
    -0
      src/TensorFlowNET.Core/Framework/CompositeTensor.cs
  4. +25
    -0
      src/TensorFlowNET.Core/Framework/IndexedSlices.cs
  5. +50
    -4
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  6. +29
    -10
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  7. +17
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  8. +17
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  10. +5
    -2
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  11. +7
    -2
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  12. +30
    -16
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  13. +2
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  14. +15
    -15
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  15. +6
    -3
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 6
- 0
TensorFlow.NET.sln View File

@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\Kera
EndProject EndProject
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}"
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{92762DCB-64C8-41B4-BEF7-780A969CE68F}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU Debug|Any CPU = Debug|Any CPU
@@ -51,6 +53,10 @@ Global
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection EndGlobalSection
GlobalSection(SolutionProperties) = preSolution GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE HideSolutionNode = FALSE


+ 1
- 0
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -142,6 +142,7 @@ namespace Tensorflow


var layer = new Dense(units, activation, var layer = new Dense(units, activation,
use_bias: use_bias, use_bias: use_bias,
bias_initializer: bias_initializer,
kernel_initializer: kernel_initializer); kernel_initializer: kernel_initializer);


return layer.apply(inputs); return layer.apply(inputs);


+ 13
- 0
src/TensorFlowNET.Core/Framework/CompositeTensor.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
    /// <summary>
    /// Common abstract parent for tensor-like objects that are built out of
    /// one or more Tensors. It declares no members of its own; it exists so
    /// that such composite types share a single base type.
    /// </summary>
    public abstract class CompositeTensor
    {
    }
}

+ 25
- 0
src/TensorFlowNET.Core/Framework/IndexedSlices.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
    /// <summary>
    /// A sparse representation of a set of tensor slices at given indices.
    /// </summary>
    public class IndexedSlices : CompositeTensor
    {
        readonly Tensor _values;
        readonly Tensor _indices;
        readonly Tensor _dense_shape;

        /// <summary>
        /// The tensor holding the slice values, exactly as passed to the constructor.
        /// </summary>
        public Tensor values => _values;

        /// <summary>
        /// The tensor holding the slice indices, exactly as passed to the constructor.
        /// </summary>
        public Tensor indices => _indices;

        /// <summary>
        /// The optional dense shape passed to the constructor; null when it was omitted.
        /// </summary>
        public Tensor dense_shape => _dense_shape;

        /// <summary>
        /// Creates an IndexedSlices from its component tensors.
        /// </summary>
        /// <param name="values">Tensor exposed through <see cref="values"/>.</param>
        /// <param name="indices">Tensor exposed through <see cref="indices"/>.</param>
        /// <param name="dense_shape">Optional tensor exposed through <see cref="dense_shape"/>.</param>
        public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
        {
            // The arguments must be retained: previously the body was empty, so
            // `values` (and the implicit Tensor conversion below) was always null.
            _values = values;
            _indices = indices;
            _dense_shape = dense_shape;
        }

        /// <summary>
        /// Converts to the underlying <see cref="values"/> tensor.
        /// A null IndexedSlices converts to a null Tensor instead of throwing.
        /// </summary>
        public static implicit operator Tensor(IndexedSlices indexedSlices)
        {
            return indexedSlices?.values;
        }
    }
}

+ 50
- 4
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Framework;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Python; using static Tensorflow.Python;


@@ -42,9 +43,9 @@ namespace Tensorflow.Gradients
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad }; return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };


var concat_dim = op.inputs[dim_index]; var concat_dim = op.inputs[dim_index];
if (end_value_index == -1)
end_value_index = op.inputs.Length - 1;
var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
var input_values = op.inputs._inputs.Skip(start_value_index)
.Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index)
.ToArray();


var out_grads = new List<Tensor>(); var out_grads = new List<Tensor>();
if (constant_op.is_constant(concat_dim)) if (constant_op.is_constant(concat_dim))
@@ -92,10 +93,16 @@ namespace Tensorflow.Gradients
} }


return (end_value_index <= dim_index ? return (end_value_index <= dim_index ?
out_grads.ToArray().Concat(null) :
out_grads.ToArray().Concat(new Tensor[] { null }) :
new Tensor[] { null }.Concat(out_grads)).ToArray(); new Tensor[] { null }.Concat(out_grads)).ToArray();
} }


/// <summary>
/// Gradient function registered for the "ExpandDims" op.
/// </summary>
/// <param name="op">The op being differentiated; forwarded to _ReshapeToInput.</param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>
/// A two-element array: the result of _ReshapeToInput for the first incoming
/// gradient, followed by null for the second slot.
/// </returns>
[RegisterGradient("ExpandDims")]
public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads)
{
    var reshaped_grad = _ReshapeToInput(op, grads[0]);
    return new Tensor[] { reshaped_grad, null };
}

/// <summary> /// <summary>
/// Extract the shapes of a set of input tensors. /// Extract the shapes of a set of input tensors.
/// </summary> /// </summary>
@@ -125,6 +132,45 @@ namespace Tensorflow.Gradients
return gen_ops.shape_n(inputs); return gen_ops.shape_n(inputs);
} }


/// <summary>
/// Gradient for GatherV2 op.
/// </summary>
/// <param name="op">The GatherV2 op; its inputs 0, 1 and 2 are read as params, indices and axis.</param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>
/// One entry per op input (three in total). For a statically known axis of 0 the
/// first entry is an IndexedSlices (converted to Tensor) built from the reshaped
/// gradient and indices; every other entry is null. For any other axis all three
/// entries are null because that case is not implemented here.
/// </returns>
[RegisterGradient("GatherV2")]
public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var @params = op.inputs[0];
    // NOTE(review): the return value of colocate_with is ignored here; if it is
    // meant to be a scope (as in the Python original) it has no effect - confirm.
    ops.colocate_with(@params);

    // Shape is requested as int64 and then cast down to int32.
    var params_shape = array_ops.shape(@params, out_type: tf.int64);
    params_shape = math_ops.cast(params_shape, tf.int32);

    var indices = op.inputs[1];
    var indices_size = array_ops.expand_dims(array_ops.size(indices), 0);
    var axis = op.inputs[2];
    // NOTE(review): presumably constant_value yields null for a non-constant axis,
    // in which case the cast below would fail - confirm against tensor_util.
    var axis_static = tensor_util.constant_value(axis);

    // For axis 0 gathers, build an appropriately shaped IndexedSlices.
    if ((int)axis_static == 0)
    {
        // NOTE(review): the Python original uses params_shape[1:]; this relies on
        // the Tensor int indexer producing the tail slice - confirm.
        var params_tail_shape = params_shape[1];
        var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
        var values = array_ops.reshape(grad, values_shape);
        indices = array_ops.reshape(indices, indices_size);
        return new Tensor[]
        {
            new IndexedSlices(values, indices, params_shape),
            null,
            null
        };
    }

    // Gradient for a non-zero axis is not implemented. Still return one slot per
    // input (params, indices, axis) so the result length matches the axis-0 path;
    // the previous two-element array was one short.
    return new Tensor[] { null, null, null };
}

[RegisterGradient("Reshape")] [RegisterGradient("Reshape")]
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads) public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
{ {


+ 29
- 10
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -106,10 +106,10 @@ namespace Tensorflow.Gradients
[RegisterGradient("Conv2D")] [RegisterGradient("Conv2D")]
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
{ {
var dilations = op.get_attr("dilations");
var strides = op.get_attr("strides");
var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var padding = op.get_attr("padding"); var padding = op.get_attr("padding");
var explicit_paddings = op.get_attr("explicit_paddings");
var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu");
var data_format = op.get_attr("data_format"); var data_format = op.get_attr("data_format");
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -120,21 +120,23 @@ namespace Tensorflow.Gradients
{ {
InputSizes = shape[0], InputSizes = shape[0],
Filter = op.inputs[1], Filter = op.inputs[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(), Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
DataFormat = data_format.ToString(),
}), }),
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
{ {
Input = op.inputs[0], Input = op.inputs[0],
FilterSizes = shape[1], FilterSizes = shape[1],
Dilations = dilations == null ? null : dilations as int[],
Strides = strides == null ? null : strides as int[],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(), Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString() DataFormat = data_format.ToString()
}) })
@@ -155,6 +157,23 @@ namespace Tensorflow.Gradients
return vec * mat; return vec * mat;
} }


/// <summary>
/// Gradient function registered for the "MaxPool" op.
/// </summary>
/// <param name="op">The MaxPool op; its first input, first output and its
/// "ksize", "strides", "padding" and "data_format" attributes are read.</param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>A single-element array holding the result of gen_nn_ops.max_pool_grad.</returns>
[RegisterGradient("MaxPool")]
public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
{
    // Reads a list-valued attribute and converts each entry to Int32.
    Func<string, int[]> read_int_list = attr_name =>
        (op.get_attr(attr_name) as AttrValue.Types.ListValue).I.Select(v => Convert.ToInt32(v)).ToArray();

    var incoming_grad = grads[0];
    var pooled_input = op.inputs[0];
    var pooled_output = op.outputs[0];
    var window = read_int_list("ksize");
    var step = read_int_list("strides");
    var pad_mode = op.get_attr("padding").ToString();
    var layout = op.get_attr("data_format").ToString();

    var input_grad = gen_nn_ops.max_pool_grad(
        pooled_input,
        pooled_output,
        incoming_grad,
        window,
        step,
        padding: pad_mode,
        data_format: layout);

    return new Tensor[] { input_grad };
}

/// <summary> /// <summary>
/// Return the gradients for TopK. /// Return the gradients for TopK.
/// </summary> /// </summary>


+ 17
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -179,6 +179,23 @@ namespace Tensorflow.Operations
return _op.outputs[0]; return _op.outputs[0];
} }


/// <summary>
/// Creates a "MaxPoolGrad" op through _op_def_lib._apply_op_helper and returns its first output.
/// </summary>
/// <param name="orig_input">Passed to the op as "orig_input".</param>
/// <param name="orig_output">Passed to the op as "orig_output".</param>
/// <param name="grad">Passed to the op as "grad".</param>
/// <param name="ksize">Passed to the op as "ksize".</param>
/// <param name="strides">Passed to the op as "strides".</param>
/// <param name="padding">Passed to the op as "padding".</param>
/// <param name="data_format">Passed to the op as "data_format"; defaults to "NHWC".</param>
/// <param name="name">Optional name forwarded to the op helper.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding,
    string data_format = "NHWC", string name = null)
{
    // Member names of the anonymous object are the op argument names, so they
    // must stay identical to the parameter names.
    var op_args = new
    {
        orig_input,
        orig_output,
        grad,
        ksize,
        strides,
        padding,
        data_format
    };

    var created_op = _op_def_lib._apply_op_helper("MaxPoolGrad", name: name, args: op_args);
    return created_op.outputs[0];
}

public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new


+ 17
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,5 +1,7 @@
using Google.Protobuf.Collections; using Google.Protobuf.Collections;
//using Newtonsoft.Json;
#if GRAPH_SERIALIZE
using Newtonsoft.Json;
#endif
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@@ -33,16 +35,23 @@ namespace Tensorflow
private readonly IntPtr _operDesc; private readonly IntPtr _operDesc;


private Graph _graph; private Graph _graph;
//[JsonIgnore]
public string type => OpType;

#if GRAPH_SERIALIZE
[JsonIgnore]
public Graph graph => _graph;
[JsonIgnore]
public int _id => _id_value;
[JsonIgnore]
public int _id_value;
[JsonIgnore]
public Operation op => this;
#else
public Graph graph => _graph; public Graph graph => _graph;
//[JsonIgnore]
public int _id => _id_value; public int _id => _id_value;
//[JsonIgnore]
public int _id_value; public int _id_value;

public string type => OpType;
//[JsonIgnore]
public Operation op => this; public Operation op => this;
#endif
public TF_DataType dtype => TF_DataType.DtInvalid; public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status(); private Status status = new Status();


@@ -51,7 +60,7 @@ namespace Tensorflow
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));


private NodeDef _node_def; private NodeDef _node_def;
//[JsonIgnore]
[JsonIgnore]
public NodeDef node_def public NodeDef node_def
{ {
get get


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -277,7 +277,7 @@ namespace Tensorflow
var input_shape = tensor_util.to_shape(input_tensor.shape); var input_shape = tensor_util.to_shape(input_tensor.shape);
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
{ {
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
return constant_op.constant(nd, name: name); return constant_op.constant(nd, name: name);
} }
} }


+ 5
- 2
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -123,7 +123,7 @@ namespace Tensorflow
return with(ops.name_scope(name, "tuple", tensors), scope => return with(ops.name_scope(name, "tuple", tensors), scope =>
{ {
name = scope; name = scope;
var gating_ops = tensors.Select(x => x.op).ToList();
var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList();


if(control_inputs != null) if(control_inputs != null)
{ {
@@ -139,7 +139,10 @@ namespace Tensorflow
var tpl = new List<Tensor>(); var tpl = new List<Tensor>();
foreach(var t in tensors) foreach(var t in tensors)
{ {
tpl.Add(with_dependencies(new Operation[] { gate }, t));
if (t != null)
tpl.Add(with_dependencies(new Operation[] { gate }, t));
else
tpl.Add(null);
} }


return tpl.ToArray(); return tpl.ToArray();


+ 7
- 2
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -29,7 +29,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
</PropertyGroup> </PropertyGroup>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -47,7 +47,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.7.0" />
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="NumSharp" Version="0.10.2" /> <PackageReference Include="NumSharp" Version="0.10.2" />
</ItemGroup> </ItemGroup>


@@ -62,4 +63,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<Folder Include="Keras\Initializers\" /> <Folder Include="Keras\Initializers\" />
</ItemGroup> </ItemGroup>


<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project> </Project>

+ 30
- 16
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -1,4 +1,6 @@
//using Newtonsoft.Json;
#if GRAPH_SERIALIZE
using Newtonsoft.Json;
#endif
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
@@ -19,15 +21,22 @@ namespace Tensorflow
private readonly IntPtr _handle; private readonly IntPtr _handle;


private int _id; private int _id;
//[JsonIgnore]
private Operation _op;
#if GRAPH_SERIALIZE
[JsonIgnore]
public int Id => _id;
[JsonIgnore]
public Graph graph => op?.graph;
[JsonIgnore]
public Operation op => _op;
[JsonIgnore]
public Tensor[] outputs => op.outputs;
#else
public int Id => _id; public int Id => _id;
//[JsonIgnore]
public Graph graph => op?.graph; public Graph graph => op?.graph;
private Operation _op;
//[JsonIgnore]
public Operation op => _op; public Operation op => _op;
//[JsonIgnore]
public Tensor[] outputs => op.outputs; public Tensor[] outputs => op.outputs;
#endif


/// <summary> /// <summary>
/// The string name of this tensor. /// The string name of this tensor.
@@ -210,11 +219,11 @@ namespace Tensorflow
} }
} }


public Tensor this[int slice_spec]
public Tensor this[int start, int? stop, int? step]
{ {
get get
{ {
var slice_spec_s = new int[] { slice_spec };
var slice_spec = new int[] { start };
var begin = new List<int>(); var begin = new List<int>();
var end = new List<int>(); var end = new List<int>();
var strides = new List<int>(); var strides = new List<int>();
@@ -224,22 +233,25 @@ namespace Tensorflow
var (begin_mask, end_mask) = (0, 0); var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0; var ellipsis_mask = 0;


foreach(var s in slice_spec_s)
foreach (var s in slice_spec)
{ {
begin.Add(s);
if (stop == null)
{ {
begin.Add(s);
end.Add(s + 1);
strides.Add(1);
shrink_axis_mask |= (1 << index);
end.Add(0);
end_mask |= (1 << index);
} }
else
end.Add(s + 1);
strides.Add(1);

index += 1; index += 1;
} }


return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
{ {
string name = scope; string name = scope;
if(begin != null)
if (begin != null)
{ {
var (packed_begin, packed_end, packed_strides) = var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()), (array_ops.stack(begin.ToArray()),
@@ -256,15 +268,17 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask, shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask, new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask, ellipsis_mask: ellipsis_mask,

name: name); name: name);
} }


throw new NotImplementedException(""); throw new NotImplementedException("");
}); });
} }
} }


public Tensor this[int slice_spec] => this[slice_spec, null, null];

public override string ToString() public override string ToString()
{ {
// this can throw IndexOutOfRangeException // this can throw IndexOutOfRangeException


+ 2
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -16,6 +16,8 @@ namespace Tensorflow
{ {
case TF_DataType.TF_BOOL: case TF_DataType.TF_BOOL:
return typeof(bool); return typeof(bool);
case TF_DataType.TF_INT64:
return typeof(long);
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
return typeof(int); return typeof(int);
case TF_DataType.TF_INT16: case TF_DataType.TF_INT16:


+ 15
- 15
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -57,24 +57,24 @@ namespace Tensorflow
if (initializer is IInitializer init) if (initializer is IInitializer init)
{ {
return _get_single_variable(name: name, return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
shape: shape,
dtype: dtype,
initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
} }
else if (initializer is Tensor tensor) else if (initializer is Tensor tensor)
{ {
return _get_single_variable(name: name, return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: tensor,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
shape: shape,
dtype: dtype,
initializer: tensor,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
} }
else else
{ {
@@ -141,7 +141,7 @@ namespace Tensorflow
v = variable_scope.default_variable_creator(init_val, v = variable_scope.default_variable_creator(init_val,
name: name, name: name,
trainable: trainable, trainable: trainable,
dtype: TF_DataType.DtInvalid,
dtype: variable_dtype,
validate_shape: validate_shape, validate_shape: validate_shape,
synchronization: synchronization, synchronization: synchronization,
aggregation: aggregation); aggregation: aggregation);


+ 6
- 3
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -5,6 +5,7 @@ using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Newtonsoft.Json;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
@@ -197,17 +198,17 @@ namespace TensorFlowNET.Examples


var h_pool = tf.concat(pooled_outputs, 3); var h_pool = tf.concat(pooled_outputs, 3);
var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
Tensor h_drop = null;
with(tf.name_scope("dropout"), delegate with(tf.name_scope("dropout"), delegate
{ {
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
}); });


Tensor logits = null; Tensor logits = null;
Tensor predictions = null; Tensor predictions = null;
with(tf.name_scope("output"), delegate with(tf.name_scope("output"), delegate
{ {
logits = tf.layers.dense(h_pool_flat, NUM_CLASS);
logits = tf.layers.dense(h_drop, NUM_CLASS);
predictions = tf.argmax(logits, -1, output_type: tf.int32); predictions = tf.argmax(logits, -1, output_type: tf.int32);
}); });


@@ -307,6 +308,8 @@ namespace TensorFlowNET.Examples
{ {
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); var graph = IsImportingGraph ? ImportGraph() : BuildGraph();


var imported_graph = JsonConvert.SerializeObject(graph, new JsonSerializerSettings { Formatting = Formatting.Indented });

return with(tf.Session(graph), sess => Train(sess, graph)); return with(tf.Session(graph), sess => Train(sess, graph));
} }




Loading…
Cancel
Save