Browse Source

as_base_type #136

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
0ca9485dea
5 changed files with 14 additions and 11 deletions
  1. +0
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -7
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 0
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -3,7 +3,6 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using TF_DataType = Tensorflow.DataType;


namespace Tensorflow namespace Tensorflow
{ {


+ 3
- 7
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -72,12 +72,8 @@ namespace Tensorflow
} }
else else
{ {
var base_type = value.dtype;
// base type
if ((int)value.dtype > 100)
{
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString());
}
var base_type = value.dtype.as_base_dtype();
input_types.Add(base_type); input_types.Add(base_type);
} }
} }
@@ -151,7 +147,7 @@ namespace Tensorflow


public DataType _MakeType(TF_DataType v, AttrDef attr_def) public DataType _MakeType(TF_DataType v, AttrDef attr_def)
{ {
return v.as_datatype_enum();
return v.as_base_dtype().as_datatype_enum();
} }
} }
} }

+ 8
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -24,6 +24,7 @@ namespace Tensorflow
throw new NotImplementedException("as_numpy_datatype failed"); throw new NotImplementedException("as_numpy_datatype failed");
} }
} }

public static TF_DataType as_dtype(Type type) public static TF_DataType as_dtype(Type type)
{ {
TF_DataType dtype = TF_DataType.DtInvalid; TF_DataType dtype = TF_DataType.DtInvalid;
@@ -62,5 +63,12 @@ namespace Tensorflow


return dtype; return dtype;
} }

public static TF_DataType as_base_dtype(this TF_DataType type)
{
return (int)type > 100 ?
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) :
type;
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow


private Operation _initializer_op; private Operation _initializer_op;
public Operation initializer => _initializer_op; public Operation initializer => _initializer_op;
public Operation op => _initializer_op;
public Operation op => _variable.op;


public string name => _variable.name; public string name => _variable.name;


@@ -77,7 +77,7 @@ namespace Tensorflow


var shape = _initial_value.shape; var shape = _initial_value.shape;
dtype = _initial_value.dtype; dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
} }


// Manually overrides the variable's shape with the initial value's. // Manually overrides the variable's shape with the initial value's.


+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void Add() public void Add()
{ {
var x = tf.Variable(0, name: "x");
var x = tf.Variable(10, name: "x");


var model = tf.global_variables_initializer(); var model = tf.global_variables_initializer();




Loading…
Cancel
Save