Browse Source

ops.colocate_with

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
0511614861
5 changed files with 36 additions and 1 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs
  2. +7
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +2
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +11
    -0
      src/TensorFlowNET.Core/ops.py.cs

+ 14
- 0
src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class Graph
{
public void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
{

}
}
}

+ 7
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -72,7 +72,13 @@ namespace Tensorflow
}
else
{
input_types.Add(value.dtype);
var base_type = value.dtype;
// base type
if ((int)value.dtype > 100)
{
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString());
}
input_types.Add(base_type);
}
}
}


+ 2
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -124,6 +124,8 @@ namespace Tensorflow
}
}

private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();

private NodeDef _node_def;
public NodeDef node_def
{


+ 2
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -99,6 +99,8 @@ namespace Tensorflow
}
else
{
ops.colocate_with(_initializer_op);

_snapshot = gen_array_ops.identity(_variable, name = "read");
}



+ 11
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -185,5 +185,16 @@ namespace Tensorflow
{
return uid_number++;
}

public static void colocate_with(Operation op, bool ignore_existing = false)
{
_colocate_with_for_gradient(op, null, ignore_existing);
}

private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
{
var default_graph = get_default_graph();
default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
}
}
}

Loading…
Cancel
Save