@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public static partial class tf | |||||
{ | |||||
/// <summary> | |||||
/// Outputs random values from a normal distribution. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="mean"></param> | |||||
/// <param name="stddev"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor random_normal(int[] shape, | |||||
float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class random_seed | |||||
{ | |||||
public static (int?, int?) get_seed(int? op_seed = null) | |||||
{ | |||||
return (null, null); | |||||
} | |||||
} | |||||
} |
@@ -141,10 +141,9 @@ namespace Tensorflow | |||||
} | } | ||||
if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
{ | |||||
name = op_type; | name = op_type; | ||||
} | |||||
// If a names ends with a '/' it is a "name scope" and we use it as-is, | |||||
// after removing the trailing '/'. | |||||
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | ||||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null) | public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null) | ||||
{ | { | ||||
var keywords = ConvertToDict(args); | |||||
Dictionary<string, object> keywords = ConvertToDict(args); | |||||
var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
@@ -42,7 +42,8 @@ namespace Tensorflow | |||||
var attrs = new Dictionary<string, object>(); | var attrs = new Dictionary<string, object>(); | ||||
var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
var input_types = new List<TF_DataType>(); | var input_types = new List<TF_DataType>(); | ||||
dynamic values = null; | |||||
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => | return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => | ||||
{ | { | ||||
var inferred_from = new Dictionary<string, object>(); | var inferred_from = new Dictionary<string, object>(); | ||||
@@ -53,7 +54,17 @@ namespace Tensorflow | |||||
foreach (var input_arg in op_def.InputArg) | foreach (var input_arg in op_def.InputArg) | ||||
{ | { | ||||
var input_name = input_arg.Name; | var input_name = input_arg.Name; | ||||
var values = keywords[input_name]; | |||||
if (keywords.ContainsKey(input_name)) | |||||
values = keywords[input_name]; | |||||
else if (keywords.ContainsKey(input_name + "_")) | |||||
{ | |||||
input_name += "_"; | |||||
values = keywords[input_name]; | |||||
} | |||||
else | |||||
throw new TypeError("No argument for input " + input_name); | |||||
// Goals: | // Goals: | ||||
// * Convert values to Tensors if it contains constants. | // * Convert values to Tensors if it contains constants. | ||||
// * Verify that values is a list if that matches the input_arg's | // * Verify that values is a list if that matches the input_arg's | ||||
@@ -92,8 +103,8 @@ namespace Tensorflow | |||||
values = ops.internal_convert_n_to_tensor(values, | values = ops.internal_convert_n_to_tensor(values, | ||||
name: input_arg.Name, | name: input_arg.Name, | ||||
dtype: dtype, | |||||
preferred_dtype: default_dtype, | |||||
dtype: dtype.as_tf_dtype(), | |||||
preferred_dtype: default_dtype.as_tf_dtype(), | |||||
as_ref: input_arg.IsRef); | as_ref: input_arg.IsRef); | ||||
} | } | ||||
else | else | ||||
@@ -107,9 +118,9 @@ namespace Tensorflow | |||||
values = ops.internal_convert_to_tensor(values, | values = ops.internal_convert_to_tensor(values, | ||||
name: input_name, | name: input_name, | ||||
dtype: dtype, | |||||
dtype: dtype.as_tf_dtype(), | |||||
as_ref: input_arg.IsRef, | as_ref: input_arg.IsRef, | ||||
preferred_dtype: default_dtype); | |||||
preferred_dtype: default_dtype.as_tf_dtype()); | |||||
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
//attrs[input_arg.TypeAttr] = values.dtype; | //attrs[input_arg.TypeAttr] = values.dtype; | ||||
@@ -7,6 +7,8 @@ namespace Tensorflow | |||||
{ | { | ||||
public class array_ops | public class array_ops | ||||
{ | { | ||||
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") => gen_array_ops.placeholder_with_default(input, shape, name); | |||||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | ||||
{ | { | ||||
dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
@@ -35,13 +37,13 @@ namespace Tensorflow | |||||
var nd = np.zeros<T>(shape); | var nd = np.zeros<T>(shape); | ||||
if (shape.Size < 1000) | if (shape.Size < 1000) | ||||
{ | { | ||||
return constant_op.constant(nd, name); | |||||
return constant_op.constant(nd, name: name); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | ||||
var c = constant_op.constant(0); | var c = constant_op.constant(0); | ||||
return gen_array_ops.fill(tShape, c, name); | |||||
return gen_array_ops.fill(tShape, c, name: name); | |||||
} | } | ||||
} | } | ||||
@@ -99,7 +101,7 @@ namespace Tensorflow | |||||
if (optimize && input_shape.is_fully_defined()) | if (optimize && input_shape.is_fully_defined()) | ||||
{ | { | ||||
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | ||||
return constant_op.constant(nd, name); | |||||
return constant_op.constant(nd, name: name); | |||||
} | } | ||||
} | } | ||||
@@ -122,7 +124,7 @@ namespace Tensorflow | |||||
if (input_shape.is_fully_defined()) | if (input_shape.is_fully_defined()) | ||||
{ | { | ||||
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | ||||
return constant_op.constant(nd, name); | |||||
return constant_op.constant(nd, name: name); | |||||
} | } | ||||
} | } | ||||
@@ -113,7 +113,7 @@ namespace Tensorflow | |||||
/// <param name="shape"></param> | /// <param name="shape"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "") | |||||
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -0,0 +1,33 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class gen_random_ops | |||||
{ | |||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||||
/// <summary> | |||||
/// Outputs random values from a normal distribution. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <param name="seed2"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = "") | |||||
{ | |||||
if (!seed.HasValue) | |||||
seed = 0; | |||||
if (!seed2.HasValue) | |||||
seed2 = 0; | |||||
var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, | |||||
args: new { shape, dtype, seed, seed2 }); | |||||
return _op.outputs[0]; | |||||
} | |||||
} | |||||
} |
@@ -6,6 +6,8 @@ namespace Tensorflow | |||||
{ | { | ||||
public class math_ops | public class math_ops | ||||
{ | { | ||||
public static Tensor add(Tensor x, Tensor y, string name = "") => gen_math_ops.add(x, y, name); | |||||
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | ||||
{ | { | ||||
var base_type = dtype.as_base_dtype(); | var base_type = dtype.as_base_dtype(); | ||||
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class random_ops | |||||
{ | |||||
public static Tensor random_normal(int[] shape, | |||||
float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = "") | |||||
{ | |||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope => | |||||
{ | |||||
var shape_tensor = _ShapeTensor(shape); | |||||
var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | |||||
var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); | |||||
var (seed1, seed2) = random_seed.get_seed(seed); | |||||
var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | |||||
var mul = rnd * stddev_tensor; | |||||
var value = math_ops.add(mul, mean_tensor, name: name); | |||||
return value; | |||||
}); | |||||
} | |||||
private static Tensor _ShapeTensor(int[] shape) | |||||
{ | |||||
return ops.convert_to_tensor(shape, name: "shape"); | |||||
} | |||||
} | |||||
} | |||||
@@ -55,11 +55,19 @@ namespace Tensorflow | |||||
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | ||||
switch(subfeed.Value) | switch(subfeed.Value) | ||||
{ | { | ||||
case float floatVal: | |||||
feed_dict_tensor[subfeed_t] = (NDArray)floatVal; | |||||
break; | |||||
case int intVal: | |||||
feed_dict_tensor[subfeed_t] = (NDArray)intVal; | |||||
break; | |||||
case string str: | case string str: | ||||
feed_dict_tensor[subfeed_t] = np.array(str); | |||||
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||||
feed_dict_tensor[subfeed_t] = (NDArray)str; | |||||
break; | break; | ||||
default: | |||||
throw new NotImplementedException("_run subfeed"); | |||||
} | } | ||||
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||||
} | } | ||||
} | } | ||||
@@ -19,7 +19,12 @@ namespace Tensorflow | |||||
/// <param name="name">Optional name for the tensor.</param> | /// <param name="name">Optional name for the tensor.</param> | ||||
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> | /// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||||
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") | |||||
{ | |||||
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); | |||||
} | |||||
private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
@@ -27,13 +32,13 @@ namespace Tensorflow | |||||
} | } | ||||
Graph g = ops.get_default_graph(); | Graph g = ops.get_default_graph(); | ||||
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); | |||||
var tensor_value = new AttrValue | |||||
{ | |||||
Type = tensor_pb.Dtype, | |||||
Tensor = tensor_pb | |||||
}; | |||||
var tensor_value = new AttrValue(); | |||||
tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||||
dtype: dtype, | |||||
shape: shape, | |||||
verify_shape: verify_shape, | |||||
allow_broadcast: allow_broadcast); | |||||
var dtype_value = new AttrValue | var dtype_value = new AttrValue | ||||
{ | { | ||||
Type = tensor_value.Tensor.Dtype, | Type = tensor_value.Tensor.Dtype, | ||||
@@ -44,8 +49,8 @@ namespace Tensorflow | |||||
attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
var op = g.create_op("Const", | var op = g.create_op("Const", | ||||
null, | |||||
new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
new Tensor[0], | |||||
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, | |||||
attrs: attrs, | attrs: attrs, | ||||
name: name); | name: name); | ||||
@@ -81,7 +86,7 @@ namespace Tensorflow | |||||
if (string.IsNullOrEmpty(name)) | if (string.IsNullOrEmpty(name)) | ||||
name = "shape_as_tensor"; | name = "shape_as_tensor"; | ||||
return constant_op.constant(s_list, name); | |||||
return constant_op.constant(s_list, name: name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -83,6 +83,20 @@ namespace Tensorflow | |||||
type; | type; | ||||
} | } | ||||
public static TF_DataType as_tf_dtype(this DataType type) | |||||
{ | |||||
TF_DataType dtype = TF_DataType.DtInvalid; | |||||
switch (type) | |||||
{ | |||||
default: | |||||
Enum.TryParse(((int)type).ToString(), out dtype); | |||||
break; | |||||
} | |||||
return dtype; | |||||
} | |||||
public static TF_DataType as_ref(this TF_DataType type) | public static TF_DataType as_ref(this TF_DataType type) | ||||
{ | { | ||||
return (int)type < 100 ? | return (int)type < 100 ? | ||||
@@ -10,33 +10,166 @@ namespace Tensorflow | |||||
{ | { | ||||
public static class tensor_util | public static class tensor_util | ||||
{ | { | ||||
public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false) | |||||
public static TF_DataType[] _TENSOR_CONTENT_TYPES = | |||||
{ | { | ||||
var shape = nd.Storage.Shape; | |||||
TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, TF_DataType.TF_INT32, TF_DataType.TF_UINT8, TF_DataType.TF_INT16, | |||||
TF_DataType.TF_INT8, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, | |||||
TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | |||||
}; | |||||
/// <summary> | |||||
/// Create a TensorProto. | |||||
/// </summary> | |||||
/// <param name="values"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="verify_shape"></param> | |||||
/// <param name="allow_broadcast"></param> | |||||
/// <returns></returns> | |||||
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false) | |||||
{ | |||||
if (allow_broadcast && verify_shape) | |||||
throw new ValueError("allow_broadcast and verify_shape are not both allowed."); | |||||
if (values is TensorProto tp) | |||||
return tp; | |||||
if (dtype != TF_DataType.DtInvalid) | |||||
; | |||||
bool is_quantized = new TF_DataType[] | |||||
{ | |||||
TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||||
TF_DataType.TF_QINT32 | |||||
}.Contains(dtype); | |||||
// We first convert value to a numpy array or scalar. | |||||
NDArray nparray = null; | |||||
if (values is NDArray nd) | |||||
{ | |||||
nparray = nd; | |||||
} | |||||
else | |||||
{ | |||||
if (values == null) | |||||
throw new ValueError("None values not supported."); | |||||
switch (values) | |||||
{ | |||||
/*case bool boolVal: | |||||
nparray = boolVal; | |||||
break;*/ | |||||
case int intVal: | |||||
nparray = intVal; | |||||
break; | |||||
case int[] intVals: | |||||
nparray = np.array(intVals); | |||||
break; | |||||
case float floatVal: | |||||
nparray = floatVal; | |||||
break; | |||||
case double doubleVal: | |||||
nparray = doubleVal; | |||||
break; | |||||
case string strVal: | |||||
nparray = strVal; | |||||
break; | |||||
default: | |||||
throw new Exception("make_tensor_proto Not Implemented"); | |||||
} | |||||
} | |||||
var numpy_dtype = dtypes.as_dtype(nparray.dtype); | |||||
if (numpy_dtype == TF_DataType.DtInvalid) | |||||
throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | |||||
// If dtype was specified and is a quantized type, we convert | |||||
// numpy_dtype back into the quantized version. | |||||
if (is_quantized) | |||||
numpy_dtype = dtype; | |||||
bool is_same_size = false; | |||||
int shape_size = 0; | |||||
// If shape is not given, get the shape from the numpy array. | |||||
if (shape == null) | |||||
{ | |||||
shape = nparray.shape; | |||||
is_same_size = true; | |||||
shape_size = nparray.size; | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||||
} | |||||
var numpy_dtype = dtypes.as_dtype(nd.dtype); | |||||
var tensor_proto = new tensor_pb2.TensorProto | var tensor_proto = new tensor_pb2.TensorProto | ||||
{ | { | ||||
Dtype = numpy_dtype.as_datatype_enum(), | Dtype = numpy_dtype.as_datatype_enum(), | ||||
TensorShape = shape.reshape(nd.shape).as_proto() | |||||
TensorShape = tensor_util.as_shape(shape) | |||||
}; | }; | ||||
switch (nd.dtype.Name) | |||||
if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||||
{ | |||||
var bytes = new List<byte>(); | |||||
var nd2 = nparray.ravel(); | |||||
switch (nparray.dtype.Name) | |||||
{ | |||||
case "Int32": | |||||
nd2.Data<int>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
case "Single": | |||||
nd2.Data<float>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
case "Double": | |||||
nd2.Data<double>().Select(x => | |||||
{ | |||||
bytes.AddRange(BitConverter.GetBytes(x)); | |||||
return x; | |||||
}).ToArray(); | |||||
break; | |||||
default: | |||||
throw new Exception("make_tensor_proto Not Implemented"); | |||||
} | |||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||||
return tensor_proto; | |||||
} | |||||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||||
{ | { | ||||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||||
return tensor_proto; | |||||
} | |||||
var proto_values = nparray.ravel(); | |||||
switch (nparray.dtype.Name) | |||||
{ | |||||
case "Bool": | |||||
tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | |||||
break; | |||||
case "Int32": | case "Int32": | ||||
tensor_proto.IntVal.AddRange(nd.Data<int>()); | |||||
tensor_proto.IntVal.AddRange(proto_values.Data<int>()); | |||||
break; | break; | ||||
case "Single": | case "Single": | ||||
tensor_proto.FloatVal.AddRange(nd.Data<float>()); | |||||
tensor_proto.FloatVal.AddRange(proto_values.Data<float>()); | |||||
break; | break; | ||||
case "Double": | case "Double": | ||||
tensor_proto.DoubleVal.AddRange(nd.Data<double>()); | |||||
tensor_proto.DoubleVal.AddRange(proto_values.Data<double>()); | |||||
break; | break; | ||||
case "String": | case "String": | ||||
tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||||
tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); | |||||
break; | break; | ||||
default: | default: | ||||
throw new Exception("Not Implemented"); | |||||
throw new Exception("make_tensor_proto Not Implemented"); | |||||
} | } | ||||
return tensor_proto; | return tensor_proto; | ||||
@@ -73,14 +206,24 @@ namespace Tensorflow | |||||
return nd; | return nd; | ||||
} | } | ||||
public static TensorShapeProto as_shape(long[] dims) | |||||
public static TensorShapeProto as_shape<T>(T[] dims) | |||||
{ | { | ||||
TensorShapeProto shape = new TensorShapeProto(); | TensorShapeProto shape = new TensorShapeProto(); | ||||
for (int i = 0; i < dims.Length; i++) | for (int i = 0; i < dims.Length; i++) | ||||
{ | { | ||||
var dim = new TensorShapeProto.Types.Dim(); | var dim = new TensorShapeProto.Types.Dim(); | ||||
dim.Size = dims[i]; | |||||
switch(dims[i]) | |||||
{ | |||||
case int n: | |||||
dim.Size = n; | |||||
break; | |||||
case long l: | |||||
dim.Size = l; | |||||
break; | |||||
default: | |||||
throw new NotImplementedException("as_shape Not Implemented"); | |||||
} | |||||
dim.Name = $"dim_{i}"; | dim.Name = $"dim_{i}"; | ||||
shape.Dim.Add(dim); | shape.Dim.Add(dim); | ||||
@@ -7,14 +7,8 @@ namespace Tensorflow | |||||
{ | { | ||||
public static partial class tf | public static partial class tf | ||||
{ | { | ||||
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||||
{ | |||||
return constant_op.constant(nd, name, verify_shape); | |||||
} | |||||
public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); | |||||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||||
{ | |||||
return array_ops.zeros(shape, dtype, name); | |||||
} | |||||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name); | |||||
} | } | ||||
} | } |
@@ -84,8 +84,8 @@ namespace Tensorflow | |||||
name = scope; | name = scope; | ||||
// Add a placeholder string tensor for the filename. | // Add a placeholder string tensor for the filename. | ||||
var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); | |||||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); | |||||
var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | |||||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||||
// Keep the name "Const" for backwards compatibility. | // Keep the name "Const" for backwards compatibility. | ||||
// Add the save ops. | // Add the save ops. | ||||
@@ -68,16 +68,14 @@ namespace Tensorflow | |||||
/// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||||
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
switch (value) | |||||
{ | |||||
case Tensor val: | |||||
return val; | |||||
default: | |||||
var nd = tensor_util.convert_to_numpy_ndarray(value); | |||||
return constant_op.constant(nd, name); | |||||
} | |||||
return convert_to_tensor_v2(value, dtype, preferred_dtype, name); | |||||
} | |||||
public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = "") | |||||
{ | |||||
return internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false); | |||||
} | } | ||||
public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | ||||
@@ -87,7 +85,7 @@ namespace Tensorflow | |||||
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) | public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) | ||||
{ | { | ||||
return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref); | |||||
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -117,17 +115,14 @@ namespace Tensorflow | |||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | ||||
// Add inputs | // Add inputs | ||||
if(inputs != null) | |||||
foreach (var op_input in inputs) | |||||
{ | { | ||||
foreach (var op_input in inputs) | |||||
{ | |||||
if (op_input is Tensor[] op_inputs) | |||||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
else if (op_input is Tensor op_input1) | |||||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
else | |||||
throw new NotImplementedException("_create_c_op"); | |||||
} | |||||
if (op_input is Tensor[] op_inputs) | |||||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
else if (op_input is Tensor op_input1) | |||||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
else | |||||
throw new NotImplementedException("_create_c_op"); | |||||
} | } | ||||
var status = new Status(); | var status = new Status(); | ||||
@@ -142,8 +137,8 @@ namespace Tensorflow | |||||
var bytes = attr.Value.ToByteArray(); | var bytes = attr.Value.ToByteArray(); | ||||
var proto = Marshal.AllocHGlobal(bytes.Length); | var proto = Marshal.AllocHGlobal(bytes.Length); | ||||
Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); | |||||
uint len = (uint)bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
@@ -385,8 +380,8 @@ namespace Tensorflow | |||||
return ret.ToArray(); | return ret.ToArray(); | ||||
} | } | ||||
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid, | |||||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||||
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||||
bool as_ref = false) | bool as_ref = false) | ||||
{ | { | ||||
var ret = new List<Tensor>(); | var ret = new List<Tensor>(); | ||||
@@ -400,28 +395,30 @@ namespace Tensorflow | |||||
return ret.ToArray(); | return ret.ToArray(); | ||||
} | } | ||||
public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid, | |||||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||||
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||||
bool as_ref = false) | bool as_ref = false) | ||||
{ | { | ||||
switch (typeof(T).Name) | |||||
switch (value) | |||||
{ | { | ||||
case "Tensor": | |||||
return value as Tensor; | |||||
case "String": | |||||
return constant_op.constant(Convert.ToString(value), name); | |||||
case "String[]": | |||||
return constant_op.constant(value as string[], name); | |||||
case "Int32": | |||||
return constant_op.constant(Convert.ToInt32(value), name); | |||||
case "Single": | |||||
return constant_op.constant(Convert.ToSingle(value), name); | |||||
case "Double": | |||||
return constant_op.constant(Convert.ToDouble(value), name); | |||||
case "RefVariable": | |||||
return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref); | |||||
case Tensor tensor: | |||||
return tensor; | |||||
case string str: | |||||
return constant_op.constant(str, dtype: dtype, name: name); | |||||
case string[] strArray: | |||||
return constant_op.constant(strArray, dtype: dtype, name: name); | |||||
case int intVal: | |||||
return constant_op.constant(intVal, dtype: dtype, name: name); | |||||
case int[] intArray: | |||||
return constant_op.constant(intArray, dtype: dtype, name: name); | |||||
case float floatVal: | |||||
return constant_op.constant(floatVal, dtype: dtype, name: name); | |||||
case double doubleVal: | |||||
return constant_op.constant(doubleVal, dtype: dtype, name: name); | |||||
case RefVariable varVal: | |||||
return varVal._TensorConversionFunction(as_ref: as_ref); | |||||
default: | default: | ||||
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor"); | |||||
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -79,7 +79,7 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.AreEqual(result.shape[0], 2); | Assert.AreEqual(result.shape[0], 2); | ||||
Assert.AreEqual(result.shape[1], 3); | Assert.AreEqual(result.shape[1], 3); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data)); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 2, 1, 1, 1, 3 }, data)); | |||||
}); | }); | ||||
} | } | ||||
@@ -17,6 +17,35 @@ namespace TensorFlowNET.UnitTest | |||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | ||||
} | } | ||||
[TestMethod] | |||||
public void ImportGraph() | |||||
{ | |||||
var v = tf.Variable(0, name: "my_variable"); | |||||
var sess = tf.Session(); | |||||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||||
} | |||||
[TestMethod] | |||||
public void SaveSimple() | |||||
{ | |||||
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||||
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||||
var init_op = tf.global_variables_initializer(); | |||||
// Add ops to save and restore all the variables. | |||||
var saver = tf.train.Saver(); | |||||
with<Session>(tf.Session(), sess => | |||||
{ | |||||
sess.run(init_op); | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
}); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Save() | public void Save() | ||||
{ | { | ||||