Browse Source

fix output dtype error for control_flow.ref_switch

tags/v0.12
Oceania2018 6 years ago
parent
commit
389f7dd2e3
8 changed files with 21 additions and 17 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  2. +3
    -4
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +3
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  5. +3
    -3
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +1
    -1
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  7. +2
    -2
      src/TensorFlowNET.Core/ops.cs
  8. +2
    -2
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
public partial class Operation public partial class Operation
{ {
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index));


public int OutputListLength(string name) public int OutputListLength(string name)
{ {


+ 3
- 4
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -44,7 +44,6 @@ namespace Tensorflow
public partial class Operation : ITensorOrOperation public partial class Operation : ITensorOrOperation
{ {
private readonly IntPtr _handle; // _c_op in python private readonly IntPtr _handle; // _c_op in python
private readonly IntPtr _operDesc;
private readonly Graph _graph; private readonly Graph _graph;
private NodeDef _node_def; private NodeDef _node_def;


@@ -91,7 +90,7 @@ namespace Tensorflow
{ {
_graph = g; _graph = g;


_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
var _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
using (var status = new Status()) using (var status = new Status())
@@ -161,7 +160,7 @@ namespace Tensorflow
op_def = g.GetOpDef(node_def.Op); op_def = g.GetOpDef(node_def.Op);


var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());


// Initialize self._outputs. // Initialize self._outputs.
output_types = new TF_DataType[NumOutputs]; output_types = new TF_DataType[NumOutputs];
@@ -170,7 +169,7 @@ namespace Tensorflow


_outputs = new Tensor[NumOutputs]; _outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++) for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
_outputs[i] = new Tensor(this, i, output_types[i]);


graph._add_op(this); graph._add_op(this);




+ 3
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -275,7 +275,7 @@ namespace Tensorflow
/// </returns> /// </returns>
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{ {
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
data = ops.convert_to_tensor_or_composite(data, name: "data");
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
// addresses the following scenario. // addresses the following scenario.
// //
@@ -296,9 +296,8 @@ namespace Tensorflow
{ {
if (data is Tensor) if (data is Tensor)
{ {
// TODO: ref_switch
//if (data.dtype._is_ref_dtype)
// return control_flow_ops.ref_switch(data, pred, name = name);
if (data.dtype.is_ref_dtype())
return gen_control_flow_ops.ref_switch(data, pred, name: name);
} }
return @switch(data, pred, name: name); return @switch(data, pred, name: name);
} }


+ 6
- 0
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -114,6 +114,12 @@ namespace Tensorflow
return _op; return _op;
} }


public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred });
return _op.outputs;
}

/// <summary> /// <summary>
/// Forwards `data` to the output port determined by `pred`. /// Forwards `data` to the output port determined by `pred`.
/// ///


+ 3
- 3
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow> <TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.11.2</Version>
<Version>0.11.3</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,7 +17,7 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description> Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.2.0</AssemblyVersion>
<AssemblyVersion>0.11.3.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0: <PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20. 1. Upgrade NumSharp to v0.20.
2. Add DisposableObject class to manage object lifetime. 2. Add DisposableObject class to manage object lifetime.
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
9. MultiThread is safe. 9. MultiThread is safe.
10. Support n-dim indexing for tensor.</PackageReleaseNotes> 10. Support n-dim indexing for tensor.</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.11.2.0</FileVersion>
<FileVersion>0.11.3.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 1
- 1
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -117,7 +117,7 @@ namespace Tensorflow
/// Key to collect Variable objects that are global (shared across machines). /// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones. /// Default collection for all variables, except local ones.
/// </summary> /// </summary>
public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_;
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;


public string TRAIN_OP => TRAIN_OP_; public string TRAIN_OP => TRAIN_OP_;




+ 2
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -206,7 +206,7 @@ namespace Tensorflow
/// </param> /// </param>
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
/// <returns>A wrapped TF_Operation*.</returns> /// <returns>A wrapped TF_Operation*.</returns>
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
{ {
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
@@ -249,7 +249,7 @@ namespace Tensorflow


status.Check(true); status.Check(true);


return (c_op, op_desc);
return c_op;
} }
} }




+ 2
- 2
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test
using (var g = tf.Graph().as_default()) using (var g = tf.Graph().as_default())
{ {
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}});
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op); var op = g._create_op_from_tf_operation(c_op);
Assert.AreEqual("myop", op.name); Assert.AreEqual("myop", op.name);
@@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test
var true_fn = new Func<Tensor>(() => var true_fn = new Func<Tensor>(() =>
{ {
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var new_ops = g._add_new_tf_operations(); var new_ops = g._add_new_tf_operations();
self.assertEqual(len(new_ops), 1); self.assertEqual(len(new_ops), 1);
return x; return x;


Loading…
Cancel
Save