Browse Source

overload ndarray == operator

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
e748a8ff07
3 changed files with 15 additions and 3 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
  2. +8
    -2
      src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs
  3. +4
    -1
      src/TensorFlowNET.Core/Training/Saving/Saver.cs

+ 3
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs View File

@@ -22,5 +22,8 @@ namespace Tensorflow.NumPy
public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs)); public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs));
[AutoNumPy] [AutoNumPy]
public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs)); public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs));
[AutoNumPy]
public static bool operator ==(NDArray lhs, NDArray rhs) => rhs is null ? false : (bool)math_ops.equal(lhs, rhs);
public static bool operator !=(NDArray lhs, NDArray rhs) => !(lhs == rhs);
} }
} }

+ 8
- 2
src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs View File

@@ -62,10 +62,16 @@ namespace Tensorflow.NumPy
[AutoNumPy] [AutoNumPy]
public static NDArray load(string file) => tf.numpy.load(file); public static NDArray load(string file) => tf.numpy.load(file);


public static T Load<T>(string path)
where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
using (var stream = new FileStream(path, FileMode.Open))
return Load<T>(stream);
}

[AutoNumPy] [AutoNumPy]
public static T Load<T>(Stream stream) public static T Load<T>(Stream stream)
where T : class,
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
=> tf.numpy.Load<T>(stream); => tf.numpy.Load<T>(stream);


[AutoNumPy] [AutoNumPy]


+ 4
- 1
src/TensorFlowNET.Core/Training/Saving/Saver.cs View File

@@ -211,7 +211,10 @@ namespace Tensorflow
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
} }


return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0];
return checkpoint_file;
//var x = model_checkpoint_path[0];
//var str = x.StringData();
//return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0];
} }


public (Saver, object) import_meta_graph(string meta_graph_or_file, public (Saver, object) import_meta_graph(string meta_graph_or_file,


Loading…
Cancel
Save