@@ -3,7 +3,6 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
using TF_DataType = Tensorflow.DataType; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -72,12 +72,8 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var base_type = value.dtype; | |||||
// base type | |||||
if ((int)value.dtype > 100) | |||||
{ | |||||
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString()); | |||||
} | |||||
var base_type = value.dtype.as_base_dtype(); | |||||
input_types.Add(base_type); | input_types.Add(base_type); | ||||
} | } | ||||
} | } | ||||
@@ -151,7 +147,7 @@ namespace Tensorflow | |||||
public DataType _MakeType(TF_DataType v, AttrDef attr_def) | public DataType _MakeType(TF_DataType v, AttrDef attr_def) | ||||
{ | { | ||||
return v.as_datatype_enum(); | |||||
return v.as_base_dtype().as_datatype_enum(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -24,6 +24,7 @@ namespace Tensorflow | |||||
throw new NotImplementedException("as_numpy_datatype failed"); | throw new NotImplementedException("as_numpy_datatype failed"); | ||||
} | } | ||||
} | } | ||||
public static TF_DataType as_dtype(Type type) | public static TF_DataType as_dtype(Type type) | ||||
{ | { | ||||
TF_DataType dtype = TF_DataType.DtInvalid; | TF_DataType dtype = TF_DataType.DtInvalid; | ||||
@@ -62,5 +63,12 @@ namespace Tensorflow | |||||
return dtype; | return dtype; | ||||
} | } | ||||
public static TF_DataType as_base_dtype(this TF_DataType type) | |||||
{ | |||||
return (int)type > 100 ? | |||||
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) : | |||||
type; | |||||
} | |||||
} | } | ||||
} | } |
@@ -16,7 +16,7 @@ namespace Tensorflow | |||||
private Operation _initializer_op; | private Operation _initializer_op; | ||||
public Operation initializer => _initializer_op; | public Operation initializer => _initializer_op; | ||||
public Operation op => _initializer_op; | |||||
public Operation op => _variable.op; | |||||
public string name => _variable.name; | public string name => _variable.name; | ||||
@@ -77,7 +77,7 @@ namespace Tensorflow | |||||
var shape = _initial_value.shape; | var shape = _initial_value.shape; | ||||
dtype = _initial_value.dtype; | dtype = _initial_value.dtype; | ||||
_variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name); | |||||
} | } | ||||
// Manually overrides the variable's shape with the initial value's. | // Manually overrides the variable's shape with the initial value's. | ||||
@@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void Add() | public void Add() | ||||
{ | { | ||||
var x = tf.Variable(0, name: "x"); | |||||
var x = tf.Variable(10, name: "x"); | |||||
var model = tf.global_variables_initializer(); | var model = tf.global_variables_initializer(); | ||||