Browse Source

fix output dtype error for control_flow.ref_switch

tags/v0.12
Oceania2018 6 years ago
parent
commit
389f7dd2e3
8 changed files with 21 additions and 17 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  2. +3
    -4
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +3
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  5. +3
    -3
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +1
    -1
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  7. +2
    -2
      src/TensorFlowNET.Core/ops.cs
  8. +2
    -2
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
public partial class Operation
{
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index));

public int OutputListLength(string name)
{


+ 3
- 4
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -44,7 +44,6 @@ namespace Tensorflow
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
private readonly IntPtr _operDesc;
private readonly Graph _graph;
private NodeDef _node_def;

@@ -91,7 +90,7 @@ namespace Tensorflow
{
_graph = g;

_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
var _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
lock (Locks.ProcessWide)
using (var status = new Status())
@@ -161,7 +160,7 @@ namespace Tensorflow
op_def = g.GetOpDef(node_def.Op);

var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

// Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
@@ -170,7 +169,7 @@ namespace Tensorflow

_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
_outputs[i] = new Tensor(this, i, output_types[i]);

graph._add_op(this);



+ 3
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -275,7 +275,7 @@ namespace Tensorflow
/// </returns>
public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
{
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
data = ops.convert_to_tensor_or_composite(data, name: "data");
// NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
// addresses the following scenario.
//
@@ -296,9 +296,8 @@ namespace Tensorflow
{
if (data is Tensor)
{
// TODO: ref_switch
//if (data.dtype._is_ref_dtype)
// return control_flow_ops.ref_switch(data, pred, name = name);
if (data.dtype.is_ref_dtype())
return gen_control_flow_ops.ref_switch(data, pred, name: name);
}
return @switch(data, pred, name: name);
}


+ 6
- 0
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -114,6 +114,12 @@ namespace Tensorflow
return _op;
}

public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null)
{
var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred });
return _op.outputs;
}

/// <summary>
/// Forwards `data` to the output port determined by `pred`.
///


+ 3
- 3
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.11.2</Version>
<Version>0.11.3</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,7 +17,7 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.2.0</AssemblyVersion>
<AssemblyVersion>0.11.3.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.
2. Add DisposableObject class to manage object lifetime.
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
9. MultiThread is safe.
10. Support n-dim indexing for tensor.</PackageReleaseNotes>
<LangVersion>7.3</LangVersion>
<FileVersion>0.11.2.0</FileVersion>
<FileVersion>0.11.3.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 1
- 1
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -117,7 +117,7 @@ namespace Tensorflow
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
/// </summary>
public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_;
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;

public string TRAIN_OP => TRAIN_OP_;



+ 2
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -206,7 +206,7 @@ namespace Tensorflow
/// </param>
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
/// <returns>A wrapped TF_Operation*.</returns>
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
{
lock (Locks.ProcessWide)
{
@@ -249,7 +249,7 @@ namespace Tensorflow

status.Check(true);

return (c_op, op_desc);
return c_op;
}
}



+ 2
- 2
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test
using (var g = tf.Graph().as_default())
{
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}});
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op);
Assert.AreEqual("myop", op.name);
@@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test
var true_fn = new Func<Tensor>(() =>
{
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var new_ops = g._add_new_tf_operations();
self.assertEqual(len(new_ops), 1);
return x;


Loading…
Cancel
Save