Browse Source

as_base_type #136

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
0ca9485dea
5 changed files with 14 additions and 11 deletions
  1. +0
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -7
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 0
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -3,7 +3,6 @@ using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using TF_DataType = Tensorflow.DataType;

namespace Tensorflow
{


+ 3
- 7
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -72,12 +72,8 @@ namespace Tensorflow
}
else
{
var base_type = value.dtype;
// base type
if ((int)value.dtype > 100)
{
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString());
}
var base_type = value.dtype.as_base_dtype();
input_types.Add(base_type);
}
}
@@ -151,7 +147,7 @@ namespace Tensorflow

public DataType _MakeType(TF_DataType v, AttrDef attr_def)
{
return v.as_datatype_enum();
return v.as_base_dtype().as_datatype_enum();
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -24,6 +24,7 @@ namespace Tensorflow
throw new NotImplementedException("as_numpy_datatype failed");
}
}

public static TF_DataType as_dtype(Type type)
{
TF_DataType dtype = TF_DataType.DtInvalid;
@@ -62,5 +63,12 @@ namespace Tensorflow

return dtype;
}

public static TF_DataType as_base_dtype(this TF_DataType type)
{
return (int)type > 100 ?
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) :
type;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow

private Operation _initializer_op;
public Operation initializer => _initializer_op;
public Operation op => _initializer_op;
public Operation op => _variable.op;

public string name => _variable.name;

@@ -77,7 +77,7 @@ namespace Tensorflow

var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
}

// Manually overrides the variable's shape with the initial value's.


+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Add()
{
var x = tf.Variable(0, name: "x");
var x = tf.Variable(10, name: "x");

var model = tf.global_variables_initializer();



Loading…
Cancel
Save