@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
/// <summary> | |||
/// Outputs random values from a normal distribution. | |||
/// </summary> | |||
/// <param name="shape"></param> | |||
/// <param name="mean"></param> | |||
/// <param name="stddev"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="seed"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor random_normal(int[] shape, | |||
float mean = 0.0f, | |||
float stddev = 1.0f, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class random_seed | |||
{ | |||
public static (int?, int?) get_seed(int? op_seed = null) | |||
{ | |||
return (null, null); | |||
} | |||
} | |||
} |
@@ -141,10 +141,9 @@ namespace Tensorflow | |||
} | |||
if (String.IsNullOrEmpty(name)) | |||
{ | |||
name = op_type; | |||
} | |||
// If a names ends with a '/' it is a "name scope" and we use it as-is, | |||
// after removing the trailing '/'. | |||
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | |||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||
{ | |||
public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null) | |||
{ | |||
var keywords = ConvertToDict(args); | |||
Dictionary<string, object> keywords = ConvertToDict(args); | |||
var g = ops.get_default_graph(); | |||
var op_def = g.GetOpDef(op_type_name); | |||
@@ -42,7 +42,8 @@ namespace Tensorflow | |||
var attrs = new Dictionary<string, object>(); | |||
var inputs = new List<Tensor>(); | |||
var input_types = new List<TF_DataType>(); | |||
dynamic values = null; | |||
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => | |||
{ | |||
var inferred_from = new Dictionary<string, object>(); | |||
@@ -53,7 +54,17 @@ namespace Tensorflow | |||
foreach (var input_arg in op_def.InputArg) | |||
{ | |||
var input_name = input_arg.Name; | |||
var values = keywords[input_name]; | |||
if (keywords.ContainsKey(input_name)) | |||
values = keywords[input_name]; | |||
else if (keywords.ContainsKey(input_name + "_")) | |||
{ | |||
input_name += "_"; | |||
values = keywords[input_name]; | |||
} | |||
else | |||
throw new TypeError("No argument for input " + input_name); | |||
// Goals: | |||
// * Convert values to Tensors if it contains constants. | |||
// * Verify that values is a list if that matches the input_arg's | |||
@@ -92,8 +103,8 @@ namespace Tensorflow | |||
values = ops.internal_convert_n_to_tensor(values, | |||
name: input_arg.Name, | |||
dtype: dtype, | |||
preferred_dtype: default_dtype, | |||
dtype: dtype.as_tf_dtype(), | |||
preferred_dtype: default_dtype.as_tf_dtype(), | |||
as_ref: input_arg.IsRef); | |||
} | |||
else | |||
@@ -107,9 +118,9 @@ namespace Tensorflow | |||
values = ops.internal_convert_to_tensor(values, | |||
name: input_name, | |||
dtype: dtype, | |||
dtype: dtype.as_tf_dtype(), | |||
as_ref: input_arg.IsRef, | |||
preferred_dtype: default_dtype); | |||
preferred_dtype: default_dtype.as_tf_dtype()); | |||
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
//attrs[input_arg.TypeAttr] = values.dtype; | |||
@@ -7,6 +7,8 @@ namespace Tensorflow | |||
{ | |||
public class array_ops | |||
{ | |||
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") => gen_array_ops.placeholder_with_default(input, shape, name); | |||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
@@ -35,13 +37,13 @@ namespace Tensorflow | |||
var nd = np.zeros<T>(shape); | |||
if (shape.Size < 1000) | |||
{ | |||
return constant_op.constant(nd, name); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
else | |||
{ | |||
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | |||
var c = constant_op.constant(0); | |||
return gen_array_ops.fill(tShape, c, name); | |||
return gen_array_ops.fill(tShape, c, name: name); | |||
} | |||
} | |||
@@ -99,7 +101,7 @@ namespace Tensorflow | |||
if (optimize && input_shape.is_fully_defined()) | |||
{ | |||
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | |||
return constant_op.constant(nd, name); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
} | |||
@@ -122,7 +124,7 @@ namespace Tensorflow | |||
if (input_shape.is_fully_defined()) | |||
{ | |||
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | |||
return constant_op.constant(nd, name); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
} | |||
@@ -113,7 +113,7 @@ namespace Tensorflow | |||
/// <param name="shape"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "") | |||
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | |||
return _op.outputs[0]; | |||
@@ -0,0 +1,33 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class gen_random_ops | |||
{ | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
/// <summary> | |||
/// Outputs random values from a normal distribution. | |||
/// </summary> | |||
/// <param name="shape"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="seed"></param> | |||
/// <param name="seed2"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = "") | |||
{ | |||
if (!seed.HasValue) | |||
seed = 0; | |||
if (!seed2.HasValue) | |||
seed2 = 0; | |||
var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, | |||
args: new { shape, dtype, seed, seed2 }); | |||
return _op.outputs[0]; | |||
} | |||
} | |||
} |
@@ -6,6 +6,8 @@ namespace Tensorflow | |||
{ | |||
public class math_ops | |||
{ | |||
public static Tensor add(Tensor x, Tensor y, string name = "") => gen_math_ops.add(x, y, name); | |||
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
{ | |||
var base_type = dtype.as_base_dtype(); | |||
@@ -0,0 +1,35 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class random_ops | |||
{ | |||
public static Tensor random_normal(int[] shape, | |||
float mean = 0.0f, | |||
float stddev = 1.0f, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = "") | |||
{ | |||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope => | |||
{ | |||
var shape_tensor = _ShapeTensor(shape); | |||
var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | |||
var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); | |||
var (seed1, seed2) = random_seed.get_seed(seed); | |||
var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | |||
var mul = rnd * stddev_tensor; | |||
var value = math_ops.add(mul, mean_tensor, name: name); | |||
return value; | |||
}); | |||
} | |||
private static Tensor _ShapeTensor(int[] shape) | |||
{ | |||
return ops.convert_to_tensor(shape, name: "shape"); | |||
} | |||
} | |||
} | |||
@@ -55,11 +55,19 @@ namespace Tensorflow | |||
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||
switch(subfeed.Value) | |||
{ | |||
case float floatVal: | |||
feed_dict_tensor[subfeed_t] = (NDArray)floatVal; | |||
break; | |||
case int intVal: | |||
feed_dict_tensor[subfeed_t] = (NDArray)intVal; | |||
break; | |||
case string str: | |||
feed_dict_tensor[subfeed_t] = np.array(str); | |||
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||
feed_dict_tensor[subfeed_t] = (NDArray)str; | |||
break; | |||
default: | |||
throw new NotImplementedException("_run subfeed"); | |||
} | |||
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||
} | |||
} | |||
@@ -19,7 +19,12 @@ namespace Tensorflow | |||
/// <param name="name">Optional name for the tensor.</param> | |||
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> | |||
/// <returns></returns> | |||
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") | |||
{ | |||
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); | |||
} | |||
private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
@@ -27,13 +32,13 @@ namespace Tensorflow | |||
} | |||
Graph g = ops.get_default_graph(); | |||
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); | |||
var tensor_value = new AttrValue | |||
{ | |||
Type = tensor_pb.Dtype, | |||
Tensor = tensor_pb | |||
}; | |||
var tensor_value = new AttrValue(); | |||
tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||
dtype: dtype, | |||
shape: shape, | |||
verify_shape: verify_shape, | |||
allow_broadcast: allow_broadcast); | |||
var dtype_value = new AttrValue | |||
{ | |||
Type = tensor_value.Tensor.Dtype, | |||
@@ -44,8 +49,8 @@ namespace Tensorflow | |||
attrs["dtype"] = dtype_value; | |||
var op = g.create_op("Const", | |||
null, | |||
new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||
new Tensor[0], | |||
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, | |||
attrs: attrs, | |||
name: name); | |||
@@ -81,7 +86,7 @@ namespace Tensorflow | |||
if (string.IsNullOrEmpty(name)) | |||
name = "shape_as_tensor"; | |||
return constant_op.constant(s_list, name); | |||
return constant_op.constant(s_list, name: name); | |||
} | |||
} | |||
} |
@@ -83,6 +83,20 @@ namespace Tensorflow | |||
type; | |||
} | |||
public static TF_DataType as_tf_dtype(this DataType type) | |||
{ | |||
TF_DataType dtype = TF_DataType.DtInvalid; | |||
switch (type) | |||
{ | |||
default: | |||
Enum.TryParse(((int)type).ToString(), out dtype); | |||
break; | |||
} | |||
return dtype; | |||
} | |||
public static TF_DataType as_ref(this TF_DataType type) | |||
{ | |||
return (int)type < 100 ? | |||
@@ -10,33 +10,166 @@ namespace Tensorflow | |||
{ | |||
public static class tensor_util | |||
{ | |||
public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false) | |||
public static TF_DataType[] _TENSOR_CONTENT_TYPES = | |||
{ | |||
var shape = nd.Storage.Shape; | |||
TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, TF_DataType.TF_INT32, TF_DataType.TF_UINT8, TF_DataType.TF_INT16, | |||
TF_DataType.TF_INT8, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, | |||
TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | |||
}; | |||
/// <summary> | |||
/// Create a TensorProto. | |||
/// </summary> | |||
/// <param name="values"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="shape"></param> | |||
/// <param name="verify_shape"></param> | |||
/// <param name="allow_broadcast"></param> | |||
/// <returns></returns> | |||
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false) | |||
{ | |||
if (allow_broadcast && verify_shape) | |||
throw new ValueError("allow_broadcast and verify_shape are not both allowed."); | |||
if (values is TensorProto tp) | |||
return tp; | |||
if (dtype != TF_DataType.DtInvalid) | |||
; | |||
bool is_quantized = new TF_DataType[] | |||
{ | |||
TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
TF_DataType.TF_QINT32 | |||
}.Contains(dtype); | |||
// We first convert value to a numpy array or scalar. | |||
NDArray nparray = null; | |||
if (values is NDArray nd) | |||
{ | |||
nparray = nd; | |||
} | |||
else | |||
{ | |||
if (values == null) | |||
throw new ValueError("None values not supported."); | |||
switch (values) | |||
{ | |||
/*case bool boolVal: | |||
nparray = boolVal; | |||
break;*/ | |||
case int intVal: | |||
nparray = intVal; | |||
break; | |||
case int[] intVals: | |||
nparray = np.array(intVals); | |||
break; | |||
case float floatVal: | |||
nparray = floatVal; | |||
break; | |||
case double doubleVal: | |||
nparray = doubleVal; | |||
break; | |||
case string strVal: | |||
nparray = strVal; | |||
break; | |||
default: | |||
throw new Exception("make_tensor_proto Not Implemented"); | |||
} | |||
} | |||
var numpy_dtype = dtypes.as_dtype(nparray.dtype); | |||
if (numpy_dtype == TF_DataType.DtInvalid) | |||
throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | |||
// If dtype was specified and is a quantized type, we convert | |||
// numpy_dtype back into the quantized version. | |||
if (is_quantized) | |||
numpy_dtype = dtype; | |||
bool is_same_size = false; | |||
int shape_size = 0; | |||
// If shape is not given, get the shape from the numpy array. | |||
if (shape == null) | |||
{ | |||
shape = nparray.shape; | |||
is_same_size = true; | |||
shape_size = nparray.size; | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||
} | |||
var numpy_dtype = dtypes.as_dtype(nd.dtype); | |||
var tensor_proto = new tensor_pb2.TensorProto | |||
{ | |||
Dtype = numpy_dtype.as_datatype_enum(), | |||
TensorShape = shape.reshape(nd.shape).as_proto() | |||
TensorShape = tensor_util.as_shape(shape) | |||
}; | |||
switch (nd.dtype.Name) | |||
if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||
{ | |||
var bytes = new List<byte>(); | |||
var nd2 = nparray.ravel(); | |||
switch (nparray.dtype.Name) | |||
{ | |||
case "Int32": | |||
nd2.Data<int>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
case "Single": | |||
nd2.Data<float>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
case "Double": | |||
nd2.Data<double>().Select(x => | |||
{ | |||
bytes.AddRange(BitConverter.GetBytes(x)); | |||
return x; | |||
}).ToArray(); | |||
break; | |||
default: | |||
throw new Exception("make_tensor_proto Not Implemented"); | |||
} | |||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
return tensor_proto; | |||
} | |||
if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||
{ | |||
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
return tensor_proto; | |||
} | |||
var proto_values = nparray.ravel(); | |||
switch (nparray.dtype.Name) | |||
{ | |||
case "Bool": | |||
tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | |||
break; | |||
case "Int32": | |||
tensor_proto.IntVal.AddRange(nd.Data<int>()); | |||
tensor_proto.IntVal.AddRange(proto_values.Data<int>()); | |||
break; | |||
case "Single": | |||
tensor_proto.FloatVal.AddRange(nd.Data<float>()); | |||
tensor_proto.FloatVal.AddRange(proto_values.Data<float>()); | |||
break; | |||
case "Double": | |||
tensor_proto.DoubleVal.AddRange(nd.Data<double>()); | |||
tensor_proto.DoubleVal.AddRange(proto_values.Data<double>()); | |||
break; | |||
case "String": | |||
tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||
tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); | |||
break; | |||
default: | |||
throw new Exception("Not Implemented"); | |||
throw new Exception("make_tensor_proto Not Implemented"); | |||
} | |||
return tensor_proto; | |||
@@ -73,14 +206,24 @@ namespace Tensorflow | |||
return nd; | |||
} | |||
public static TensorShapeProto as_shape(long[] dims) | |||
public static TensorShapeProto as_shape<T>(T[] dims) | |||
{ | |||
TensorShapeProto shape = new TensorShapeProto(); | |||
for (int i = 0; i < dims.Length; i++) | |||
{ | |||
var dim = new TensorShapeProto.Types.Dim(); | |||
dim.Size = dims[i]; | |||
switch(dims[i]) | |||
{ | |||
case int n: | |||
dim.Size = n; | |||
break; | |||
case long l: | |||
dim.Size = l; | |||
break; | |||
default: | |||
throw new NotImplementedException("as_shape Not Implemented"); | |||
} | |||
dim.Name = $"dim_{i}"; | |||
shape.Dim.Add(dim); | |||
@@ -7,14 +7,8 @@ namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||
{ | |||
return constant_op.constant(nd, name, verify_shape); | |||
} | |||
public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); | |||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||
{ | |||
return array_ops.zeros(shape, dtype, name); | |||
} | |||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name); | |||
} | |||
} |
@@ -84,8 +84,8 @@ namespace Tensorflow | |||
name = scope; | |||
// Add a placeholder string tensor for the filename. | |||
var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); | |||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); | |||
var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | |||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||
// Keep the name "Const" for backwards compatibility. | |||
// Add the save ops. | |||
@@ -68,16 +68,14 @@ namespace Tensorflow | |||
/// <param name="dtype"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||
{ | |||
switch (value) | |||
{ | |||
case Tensor val: | |||
return val; | |||
default: | |||
var nd = tensor_util.convert_to_numpy_ndarray(value); | |||
return constant_op.constant(nd, name); | |||
} | |||
return convert_to_tensor_v2(value, dtype, preferred_dtype, name); | |||
} | |||
public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = "") | |||
{ | |||
return internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false); | |||
} | |||
public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
@@ -87,7 +85,7 @@ namespace Tensorflow | |||
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) | |||
{ | |||
return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref); | |||
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
} | |||
/// <summary> | |||
@@ -117,17 +115,14 @@ namespace Tensorflow | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
// Add inputs | |||
if(inputs != null) | |||
foreach (var op_input in inputs) | |||
{ | |||
foreach (var op_input in inputs) | |||
{ | |||
if (op_input is Tensor[] op_inputs) | |||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
else if (op_input is Tensor op_input1) | |||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
else | |||
throw new NotImplementedException("_create_c_op"); | |||
} | |||
if (op_input is Tensor[] op_inputs) | |||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
else if (op_input is Tensor op_input1) | |||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
else | |||
throw new NotImplementedException("_create_c_op"); | |||
} | |||
var status = new Status(); | |||
@@ -142,8 +137,8 @@ namespace Tensorflow | |||
var bytes = attr.Value.ToByteArray(); | |||
var proto = Marshal.AllocHGlobal(bytes.Length); | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); | |||
uint len = (uint)bytes.Length; | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
status.Check(true); | |||
} | |||
@@ -385,8 +380,8 @@ namespace Tensorflow | |||
return ret.ToArray(); | |||
} | |||
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid, | |||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
bool as_ref = false) | |||
{ | |||
var ret = new List<Tensor>(); | |||
@@ -400,28 +395,30 @@ namespace Tensorflow | |||
return ret.ToArray(); | |||
} | |||
public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid, | |||
string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
bool as_ref = false) | |||
{ | |||
switch (typeof(T).Name) | |||
switch (value) | |||
{ | |||
case "Tensor": | |||
return value as Tensor; | |||
case "String": | |||
return constant_op.constant(Convert.ToString(value), name); | |||
case "String[]": | |||
return constant_op.constant(value as string[], name); | |||
case "Int32": | |||
return constant_op.constant(Convert.ToInt32(value), name); | |||
case "Single": | |||
return constant_op.constant(Convert.ToSingle(value), name); | |||
case "Double": | |||
return constant_op.constant(Convert.ToDouble(value), name); | |||
case "RefVariable": | |||
return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref); | |||
case Tensor tensor: | |||
return tensor; | |||
case string str: | |||
return constant_op.constant(str, dtype: dtype, name: name); | |||
case string[] strArray: | |||
return constant_op.constant(strArray, dtype: dtype, name: name); | |||
case int intVal: | |||
return constant_op.constant(intVal, dtype: dtype, name: name); | |||
case int[] intArray: | |||
return constant_op.constant(intArray, dtype: dtype, name: name); | |||
case float floatVal: | |||
return constant_op.constant(floatVal, dtype: dtype, name: name); | |||
case double doubleVal: | |||
return constant_op.constant(doubleVal, dtype: dtype, name: name); | |||
case RefVariable varVal: | |||
return varVal._TensorConversionFunction(as_ref: as_ref); | |||
default: | |||
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor"); | |||
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | |||
} | |||
} | |||
} | |||
@@ -79,7 +79,7 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual(result.shape[0], 2); | |||
Assert.AreEqual(result.shape[1], 3); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 2, 1, 1, 1, 3 }, data)); | |||
}); | |||
} | |||
@@ -17,6 +17,35 @@ namespace TensorFlowNET.UnitTest | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
} | |||
[TestMethod] | |||
public void ImportGraph() | |||
{ | |||
var v = tf.Variable(0, name: "my_variable"); | |||
var sess = tf.Session(); | |||
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
} | |||
[TestMethod] | |||
public void SaveSimple() | |||
{ | |||
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||
var init_op = tf.global_variables_initializer(); | |||
// Add ops to save and restore all the variables. | |||
var saver = tf.train.Saver(); | |||
with<Session>(tf.Session(), sess => | |||
{ | |||
sess.run(init_op); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
}); | |||
} | |||
[TestMethod] | |||
public void Save() | |||
{ | |||