Browse Source

bunch of updates.

tags/v0.8.0
haiping008 6 years ago
parent
commit
11572bf770
18 changed files with 405 additions and 93 deletions
  1. +26
    -0
      src/TensorFlowNET.Core/APIs/tf.random.cs
  2. +14
    -0
      src/TensorFlowNET.Core/Framework/random_seed.py.cs
  3. +2
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +18
    -7
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  5. +6
    -4
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  7. +33
    -0
      src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
  8. +2
    -0
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  9. +35
    -0
      src/TensorFlowNET.Core/Operations/random_ops.py.cs
  10. +10
    -2
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  11. +16
    -11
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  12. +14
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  13. +155
    -12
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  14. +2
    -8
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  16. +39
    -42
      src/TensorFlowNET.Core/ops.py.cs
  17. +1
    -1
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  18. +29
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 26
- 0
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
/// <summary>
/// Outputs random values from a normal distribution.
/// </summary>
/// <param name="shape"></param>
/// <param name="mean"></param>
/// <param name="stddev"></param>
/// <param name="dtype"></param>
/// <param name="seed"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor random_normal(int[] shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
}
}

+ 14
- 0
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class random_seed
{
public static (int?, int?) get_seed(int? op_seed = null)
{
return (null, null);
}
}
}

+ 2
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -141,10 +141,9 @@ namespace Tensorflow
}

if (String.IsNullOrEmpty(name))
{
name = op_type;
}
// If a names ends with a '/' it is a "name scope" and we use it as-is,
// after removing the trailing '/'.
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);



+ 18
- 7
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow
{
public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null)
{
var keywords = ConvertToDict(args);
Dictionary<string, object> keywords = ConvertToDict(args);
var g = ops.get_default_graph();
var op_def = g.GetOpDef(op_type_name);

@@ -42,7 +42,8 @@ namespace Tensorflow
var attrs = new Dictionary<string, object>();
var inputs = new List<Tensor>();
var input_types = new List<TF_DataType>();
dynamic values = null;

return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope =>
{
var inferred_from = new Dictionary<string, object>();
@@ -53,7 +54,17 @@ namespace Tensorflow
foreach (var input_arg in op_def.InputArg)
{
var input_name = input_arg.Name;
var values = keywords[input_name];
if (keywords.ContainsKey(input_name))
values = keywords[input_name];
else if (keywords.ContainsKey(input_name + "_"))
{
input_name += "_";
values = keywords[input_name];
}
else
throw new TypeError("No argument for input " + input_name);
// Goals:
// * Convert values to Tensors if it contains constants.
// * Verify that values is a list if that matches the input_arg's
@@ -92,8 +103,8 @@ namespace Tensorflow

values = ops.internal_convert_n_to_tensor(values,
name: input_arg.Name,
dtype: dtype,
preferred_dtype: default_dtype,
dtype: dtype.as_tf_dtype(),
preferred_dtype: default_dtype.as_tf_dtype(),
as_ref: input_arg.IsRef);
}
else
@@ -107,9 +118,9 @@ namespace Tensorflow

values = ops.internal_convert_to_tensor(values,
name: input_name,
dtype: dtype,
dtype: dtype.as_tf_dtype(),
as_ref: input_arg.IsRef,
preferred_dtype: default_dtype);
preferred_dtype: default_dtype.as_tf_dtype());

//if (!String.IsNullOrEmpty(input_arg.TypeAttr))
//attrs[input_arg.TypeAttr] = values.dtype;


+ 6
- 4
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -7,6 +7,8 @@ namespace Tensorflow
{
public class array_ops
{
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") => gen_array_ops.placeholder_with_default(input, shape, name);

public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
{
dtype = dtype.as_base_dtype();
@@ -35,13 +37,13 @@ namespace Tensorflow
var nd = np.zeros<T>(shape);
if (shape.Size < 1000)
{
return constant_op.constant(nd, name);
return constant_op.constant(nd, name: name);
}
else
{
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape());
var c = constant_op.constant(0);
return gen_array_ops.fill(tShape, c, name);
return gen_array_ops.fill(tShape, c, name: name);
}
}

@@ -99,7 +101,7 @@ namespace Tensorflow
if (optimize && input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
return constant_op.constant(nd, name);
return constant_op.constant(nd, name: name);
}
}

@@ -122,7 +124,7 @@ namespace Tensorflow
if (input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
return constant_op.constant(nd, name);
return constant_op.constant(nd, name: name);
}
}



+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -113,7 +113,7 @@ namespace Tensorflow
/// <param name="shape"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "")
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "")
{
var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name });
return _op.outputs[0];


+ 33
- 0
src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gen_random_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

/// <summary>
/// Outputs random values from a normal distribution.
/// </summary>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="seed"></param>
/// <param name="seed2"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = "")
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name,
args: new { shape, dtype, seed, seed2 });

return _op.outputs[0];
}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -6,6 +6,8 @@ namespace Tensorflow
{
public class math_ops
{
public static Tensor add(Tensor x, Tensor y, string name = "") => gen_math_ops.add(x, y, name);

public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
{
var base_type = dtype.as_base_dtype();


+ 35
- 0
src/TensorFlowNET.Core/Operations/random_ops.py.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class random_ops
{
public static Tensor random_normal(int[] shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = "")
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope =>
{
var shape_tensor = _ShapeTensor(shape);
var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev");
var (seed1, seed2) = random_seed.get_seed(seed);
var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2);
var mul = rnd * stddev_tensor;
var value = math_ops.add(mul, mean_tensor, name: name);
return value;
});
}

private static Tensor _ShapeTensor(int[] shape)
{
return ops.convert_to_tensor(shape, name: "shape");
}
}
}


+ 10
- 2
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -55,11 +55,19 @@ namespace Tensorflow
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
switch(subfeed.Value)
{
case float floatVal:
feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
break;
case int intVal:
feed_dict_tensor[subfeed_t] = (NDArray)intVal;
break;
case string str:
feed_dict_tensor[subfeed_t] = np.array(str);
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value);
feed_dict_tensor[subfeed_t] = (NDArray)str;
break;
default:
throw new NotImplementedException("_run subfeed");
}
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value);
}
}



+ 16
- 11
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -19,7 +19,12 @@ namespace Tensorflow
/// <param name="name">Optional name for the tensor.</param>
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param>
/// <returns></returns>
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const")
{
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
}

private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
{
if (tf.context.executing_eagerly())
{
@@ -27,13 +32,13 @@ namespace Tensorflow
}

Graph g = ops.get_default_graph();
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);
var tensor_value = new AttrValue
{
Type = tensor_pb.Dtype,
Tensor = tensor_pb
};
var tensor_value = new AttrValue();
tensor_value.Tensor = tensor_util.make_tensor_proto(value,
dtype: dtype,
shape: shape,
verify_shape: verify_shape,
allow_broadcast: allow_broadcast);
var dtype_value = new AttrValue
{
Type = tensor_value.Tensor.Dtype,
@@ -44,8 +49,8 @@ namespace Tensorflow
attrs["dtype"] = dtype_value;

var op = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
new Tensor[0],
new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
attrs: attrs,
name: name);

@@ -81,7 +86,7 @@ namespace Tensorflow
if (string.IsNullOrEmpty(name))
name = "shape_as_tensor";

return constant_op.constant(s_list, name);
return constant_op.constant(s_list, name: name);
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -83,6 +83,20 @@ namespace Tensorflow
type;
}

public static TF_DataType as_tf_dtype(this DataType type)
{
TF_DataType dtype = TF_DataType.DtInvalid;

switch (type)
{
default:
Enum.TryParse(((int)type).ToString(), out dtype);
break;
}

return dtype;
}

public static TF_DataType as_ref(this TF_DataType type)
{
return (int)type < 100 ?


+ 155
- 12
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -10,33 +10,166 @@ namespace Tensorflow
{
public static class tensor_util
{
public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false)
public static TF_DataType[] _TENSOR_CONTENT_TYPES =
{
var shape = nd.Storage.Shape;
TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, TF_DataType.TF_INT32, TF_DataType.TF_UINT8, TF_DataType.TF_INT16,
TF_DataType.TF_INT8, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16,
TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};

/// <summary>
/// Create a TensorProto.
/// </summary>
/// <param name="values"></param>
/// <param name="dtype"></param>
/// <param name="shape"></param>
/// <param name="verify_shape"></param>
/// <param name="allow_broadcast"></param>
/// <returns></returns>
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false)
{
if (allow_broadcast && verify_shape)
throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
if (values is TensorProto tp)
return tp;

if (dtype != TF_DataType.DtInvalid)
;

bool is_quantized = new TF_DataType[]
{
TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16,
TF_DataType.TF_QINT32
}.Contains(dtype);

// We first convert value to a numpy array or scalar.
NDArray nparray = null;

if (values is NDArray nd)
{
nparray = nd;
}
else
{
if (values == null)
throw new ValueError("None values not supported.");

switch (values)
{
/*case bool boolVal:
nparray = boolVal;
break;*/
case int intVal:
nparray = intVal;
break;
case int[] intVals:
nparray = np.array(intVals);
break;
case float floatVal:
nparray = floatVal;
break;
case double doubleVal:
nparray = doubleVal;
break;
case string strVal:
nparray = strVal;
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
}

var numpy_dtype = dtypes.as_dtype(nparray.dtype);
if (numpy_dtype == TF_DataType.DtInvalid)
throw new TypeError($"Unrecognized data type: {nparray.dtype}");

// If dtype was specified and is a quantized type, we convert
// numpy_dtype back into the quantized version.
if (is_quantized)
numpy_dtype = dtype;

bool is_same_size = false;
int shape_size = 0;

// If shape is not given, get the shape from the numpy array.
if (shape == null)
{
shape = nparray.shape;
is_same_size = true;
shape_size = nparray.size;
}
else
{
throw new NotImplementedException("make_tensor_proto shape not implemented");
}

var numpy_dtype = dtypes.as_dtype(nd.dtype);
var tensor_proto = new tensor_pb2.TensorProto
{
Dtype = numpy_dtype.as_datatype_enum(),
TensorShape = shape.reshape(nd.shape).as_proto()
TensorShape = tensor_util.as_shape(shape)
};

switch (nd.dtype.Name)
if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
{
var bytes = new List<byte>();
var nd2 = nparray.ravel();
switch (nparray.dtype.Name)
{
case "Int32":
nd2.Data<int>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Single":
nd2.Data<float>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
case "Double":
nd2.Data<double>().Select(x =>
{
bytes.AddRange(BitConverter.GetBytes(x));
return x;
}).ToArray();
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
return tensor_proto;
}

if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str)
{
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
return tensor_proto;
}

var proto_values = nparray.ravel();

switch (nparray.dtype.Name)
{
case "Bool":
tensor_proto.BoolVal.AddRange(proto_values.Data<bool>());
break;
case "Int32":
tensor_proto.IntVal.AddRange(nd.Data<int>());
tensor_proto.IntVal.AddRange(proto_values.Data<int>());
break;
case "Single":
tensor_proto.FloatVal.AddRange(nd.Data<float>());
tensor_proto.FloatVal.AddRange(proto_values.Data<float>());
break;
case "Double":
tensor_proto.DoubleVal.AddRange(nd.Data<double>());
tensor_proto.DoubleVal.AddRange(proto_values.Data<double>());
break;
case "String":
tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString())));
break;
default:
throw new Exception("Not Implemented");
throw new Exception("make_tensor_proto Not Implemented");
}

return tensor_proto;
@@ -73,14 +206,24 @@ namespace Tensorflow
return nd;
}

public static TensorShapeProto as_shape(long[] dims)
public static TensorShapeProto as_shape<T>(T[] dims)
{
TensorShapeProto shape = new TensorShapeProto();

for (int i = 0; i < dims.Length; i++)
{
var dim = new TensorShapeProto.Types.Dim();
dim.Size = dims[i];
switch(dims[i])
{
case int n:
dim.Size = n;
break;
case long l:
dim.Size = l;
break;
default:
throw new NotImplementedException("as_shape Not Implemented");
}
dim.Name = $"dim_{i}";

shape.Dim.Add(dim);


+ 2
- 8
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -7,14 +7,8 @@ namespace Tensorflow
{
public static partial class tf
{
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{
return constant_op.constant(nd, name, verify_shape);
}
public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name);

public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
{
return array_ops.zeros(shape, dtype, name);
}
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name);
}
}

+ 2
- 2
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -84,8 +84,8 @@ namespace Tensorflow
name = scope;

// Add a placeholder string tensor for the filename.
var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename");
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const");
var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename");
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const");
// Keep the name "Const" for backwards compatibility.

// Add the save ops.


+ 39
- 42
src/TensorFlowNET.Core/ops.py.cs View File

@@ -68,16 +68,14 @@ namespace Tensorflow
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid)
{
switch (value)
{
case Tensor val:
return val;
default:
var nd = tensor_util.convert_to_numpy_ndarray(value);
return constant_op.constant(nd, name);
}
return convert_to_tensor_v2(value, dtype, preferred_dtype, name);
}

public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = "")
{
return internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false);
}

public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
@@ -87,7 +85,7 @@ namespace Tensorflow

public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
{
return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref);
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
}

/// <summary>
@@ -117,17 +115,14 @@ namespace Tensorflow
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

// Add inputs
if(inputs != null)
foreach (var op_input in inputs)
{
foreach (var op_input in inputs)
{
if (op_input is Tensor[] op_inputs)
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
else if (op_input is Tensor op_input1)
c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
else
throw new NotImplementedException("_create_c_op");
}
if (op_input is Tensor[] op_inputs)
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
else if (op_input is Tensor op_input1)
c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
else
throw new NotImplementedException("_create_c_op");
}

var status = new Status();
@@ -142,8 +137,8 @@ namespace Tensorflow
var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);

status.Check(true);
}
@@ -385,8 +380,8 @@ namespace Tensorflow
return ret.ToArray();
}

public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, TF_DataType dtype = TF_DataType.DtInvalid,
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid,
bool as_ref = false)
{
var ret = new List<Tensor>();
@@ -400,28 +395,30 @@ namespace Tensorflow
return ret.ToArray();
}

public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid,
string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid,
bool as_ref = false)
{
switch (typeof(T).Name)
switch (value)
{
case "Tensor":
return value as Tensor;
case "String":
return constant_op.constant(Convert.ToString(value), name);
case "String[]":
return constant_op.constant(value as string[], name);
case "Int32":
return constant_op.constant(Convert.ToInt32(value), name);
case "Single":
return constant_op.constant(Convert.ToSingle(value), name);
case "Double":
return constant_op.constant(Convert.ToDouble(value), name);
case "RefVariable":
return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref);
case Tensor tensor:
return tensor;
case string str:
return constant_op.constant(str, dtype: dtype, name: name);
case string[] strArray:
return constant_op.constant(strArray, dtype: dtype, name: name);
case int intVal:
return constant_op.constant(intVal, dtype: dtype, name: name);
case int[] intArray:
return constant_op.constant(intArray, dtype: dtype, name: name);
case float floatVal:
return constant_op.constant(floatVal, dtype: dtype, name: name);
case double doubleVal:
return constant_op.constant(doubleVal, dtype: dtype, name: name);
case RefVariable varVal:
return varVal._TensorConversionFunction(as_ref: as_ref);
default:
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor");
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor");
}
}
}


+ 1
- 1
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -79,7 +79,7 @@ namespace TensorFlowNET.UnitTest

Assert.AreEqual(result.shape[0], 2);
Assert.AreEqual(result.shape[1], 3);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 2, 1, 1, 1, 3 }, data));
});
}



+ 29
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -17,6 +17,35 @@ namespace TensorFlowNET.UnitTest
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
}

[TestMethod]
public void ImportGraph()
{
var v = tf.Variable(0, name: "my_variable");
var sess = tf.Session();
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
}

[TestMethod]
public void SaveSimple()
{
var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1");
var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2");

var init_op = tf.global_variables_initializer();

// Add ops to save and restore all the variables.
var saver = tf.train.Saver();

with<Session>(tf.Session(), sess =>
{
sess.run(init_op);

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
});
}

[TestMethod]
public void Save()
{


Loading…
Cancel
Save