@@ -28,9 +28,9 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||
/// <returns>A list of control inputs for the op to be created.</returns> | |||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||
{ | |||
Operation[] ret = new Operation[0]; | |||
var ret = new ITensorOrOperation[0]; | |||
foreach(var controller in _control_dependencies_stack) | |||
{ | |||
@@ -54,12 +54,12 @@ namespace Tensorflow | |||
return ret; | |||
} | |||
public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
{ | |||
if (control_inputs == null) | |||
return new _ControlDependenciesController(this, null); | |||
var control_ops = new List<Operation>(); | |||
var control_ops = new List<ITensorOrOperation>(); | |||
foreach (var c in control_inputs) | |||
{ | |||
control_ops.Add(c); | |||
@@ -298,6 +298,11 @@ namespace Tensorflow | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
public string[] get_all_collection_keys() | |||
{ | |||
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||
} | |||
public object get_collection(string name, string scope = "") | |||
{ | |||
return _collections.ContainsKey(name) ? _collections[name] : null; | |||
@@ -11,20 +11,20 @@ namespace Tensorflow | |||
public class _ControlDependenciesController : IPython | |||
{ | |||
private Graph _graph; | |||
private List<Operation> _control_inputs_val; | |||
private List<Operation> _seen_nodes; | |||
private List<ITensorOrOperation> _control_inputs_val; | |||
private List<ITensorOrOperation> _seen_nodes; | |||
private Queue<_ControlDependenciesController> _old_stack; | |||
private bool _new_stack; | |||
private Context _old_control_flow_context; | |||
public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||
public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> control_inputs) | |||
{ | |||
_graph = graph; | |||
if (control_inputs == null) | |||
{ | |||
_control_inputs_val = new List<Operation>(); | |||
_control_inputs_val = new List<ITensorOrOperation>(); | |||
_new_stack = true; | |||
} | |||
else | |||
@@ -33,15 +33,15 @@ namespace Tensorflow | |||
_new_stack = false; | |||
} | |||
_seen_nodes = new List<Operation>(); | |||
_seen_nodes = new List<ITensorOrOperation>(); | |||
} | |||
public void add_op(Operation op) | |||
public void add_op(ITensorOrOperation op) | |||
{ | |||
_seen_nodes.Add(op); | |||
} | |||
public bool op_in_group(Operation op) | |||
public bool op_in_group(ITensorOrOperation op) | |||
{ | |||
return _seen_nodes.Contains(op); | |||
} | |||
@@ -11,5 +11,6 @@ namespace Tensorflow | |||
public interface ITensorOrOperation | |||
{ | |||
string Device { get; } | |||
Operation op { get; } | |||
} | |||
} |
@@ -107,7 +107,9 @@ namespace Tensorflow | |||
values = ops.internal_convert_to_tensor(values, | |||
name: input_name, | |||
as_ref: input_arg.IsRef); | |||
dtype: dtype, | |||
as_ref: input_arg.IsRef, | |||
preferred_dtype: default_dtype); | |||
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
//attrs[input_arg.TypeAttr] = values.dtype; | |||
@@ -163,14 +165,20 @@ namespace Tensorflow | |||
foreach (var arg in op_def.OutputArg) | |||
{ | |||
types = new List<TF_DataType>(); | |||
if (!string.IsNullOrEmpty(arg.NumberAttr)) | |||
{ | |||
} | |||
else if (!string.IsNullOrEmpty(arg.TypeAttr)) | |||
{ | |||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||
types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; | |||
} | |||
if (arg.IsRef) | |||
types = types.Select(x => x.as_ref()).ToList(); | |||
output_types.AddRange(types); | |||
} | |||
// Add Op to graph | |||
@@ -16,6 +16,7 @@ namespace Tensorflow | |||
private int _id_value; | |||
public string type => OpType; | |||
public Operation op => this; | |||
private Status status = new Status(); | |||
@@ -75,7 +76,7 @@ namespace Tensorflow | |||
/// </param> | |||
/// <param name="original_op"></param> | |||
/// <param name="op_def"></param> | |||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
{ | |||
Graph = g; | |||
@@ -120,6 +121,11 @@ namespace Tensorflow | |||
_control_flow_post_processing(); | |||
} | |||
public void run(FeedItem[] feed_dict = null, Session session = null) | |||
{ | |||
ops._run_using_default_session(this, feed_dict, Graph, session); | |||
} | |||
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||
{ | |||
var grouped_inputs = new List<object>(); | |||
@@ -204,7 +210,7 @@ namespace Tensorflow | |||
public override string ToString() | |||
{ | |||
return _handle == IntPtr.Zero ? "Undefined" : $"'{Name}' type={OpType}"; | |||
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{Name}' type={OpType}"; | |||
} | |||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||
@@ -28,10 +28,7 @@ namespace Tensorflow | |||
{ | |||
var dev = ops_on_device.Keys.First(); | |||
var deps = ops_on_device.Values.First(); | |||
if (typeof(T).Name == "Operation") | |||
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); | |||
else | |||
throw new NotImplementedException("control_flow_ops.group"); | |||
return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name); | |||
} | |||
// 2-level tree. The root node is the returned NoOp node. | |||
@@ -35,17 +35,8 @@ namespace Tensorflow | |||
c_api.TF_DeleteSessionOptions(opts); | |||
} | |||
public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
public virtual NDArray run(Operation fetches, FeedItem[] feed_dict = null) | |||
public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
@@ -30,6 +30,11 @@ namespace Tensorflow | |||
return tensor._handle; | |||
} | |||
public static implicit operator Operation(Tensor tensor) | |||
{ | |||
return tensor.op; | |||
} | |||
public static implicit operator Tensor(IntPtr handle) | |||
{ | |||
return new Tensor(handle); | |||
@@ -261,7 +261,7 @@ namespace Tensorflow | |||
} | |||
} | |||
return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
} | |||
public void Dispose() | |||
@@ -83,6 +83,13 @@ namespace Tensorflow | |||
type; | |||
} | |||
public static TF_DataType as_ref(this TF_DataType type) | |||
{ | |||
return (int)type < 100 ? | |||
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : | |||
type; | |||
} | |||
public static bool is_complex(this TF_DataType type) | |||
{ | |||
return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; | |||
@@ -7,9 +7,9 @@ namespace Tensorflow | |||
{ | |||
public class BaseSaverBuilder | |||
{ | |||
protected int _write_version; | |||
protected SaverDef.Types.CheckpointFormatVersion _write_version; | |||
public BaseSaverBuilder(int write_version = 2) | |||
public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) | |||
{ | |||
_write_version = write_version; | |||
} | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
} | |||
} | |||
if (_write_version == 2) | |||
if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | |||
{ | |||
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
} | |||
@@ -60,7 +60,7 @@ namespace Tensorflow | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
float keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
string filename = "model", | |||
@@ -76,7 +76,10 @@ namespace Tensorflow | |||
if (max_to_keep < 0) | |||
max_to_keep = 0; | |||
Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
Tensor save_tensor = null; | |||
Operation restore_op = null; | |||
return Python.with<ops.name_scope, SaverDef>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
{ | |||
name = scope; | |||
@@ -93,14 +96,35 @@ namespace Tensorflow | |||
else | |||
{ | |||
if (build_save) | |||
_AddSaveOps(filename_tensor, saveables); | |||
save_tensor = _AddSaveOps(filename_tensor, saveables); | |||
if (build_restore) | |||
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||
restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||
} | |||
var graph = ops.get_default_graph(); | |||
var check_collection_list = graph.get_all_collection_keys(); | |||
foreach (var collection_type in check_collection_list) | |||
{ | |||
foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>) | |||
{ | |||
} | |||
} | |||
return new SaverDef() | |||
{ | |||
FilenameTensorName = filename_tensor.name, | |||
SaveTensorName = save_tensor.name, | |||
RestoreOpName = restore_op.Name, | |||
MaxToKeep = max_to_keep, | |||
Sharded = sharded, | |||
KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours, | |||
Version = _write_version | |||
}; | |||
}); | |||
throw new NotImplementedException(""); | |||
} | |||
public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
@@ -6,7 +6,7 @@ namespace Tensorflow | |||
{ | |||
public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | |||
{ | |||
public BulkSaverBuilder(int write_version = 2) : base(write_version) | |||
public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version) | |||
{ | |||
} | |||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
float keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
string filename = "model", | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -13,30 +14,33 @@ namespace Tensorflow | |||
private bool _reshape; | |||
private bool _sharded; | |||
private int _max_to_keep; | |||
private double _keep_checkpoint_every_n_hours; | |||
private float _keep_checkpoint_every_n_hours; | |||
private string _name; | |||
private bool _restore_sequentially; | |||
private SaverDef _saver_def; | |||
private ISaverBuilder _builder; | |||
private bool _allow_empty; | |||
private bool _is_built; | |||
private int _write_version; | |||
private SaverDef.Types.CheckpointFormatVersion _write_version; | |||
private bool _pad_step_number; | |||
private string _filename; | |||
private bool _is_empty; | |||
private float _next_checkpoint_time; | |||
private bool _save_relative_paths; | |||
private bool? _object_restore_saver; | |||
public Saver(RefVariable[] var_list = null, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
float keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
SaverDef saver_def = null, | |||
ISaverBuilder builder = null, | |||
bool defer_build = false, | |||
bool allow_empty = false, | |||
int write_version = 2, | |||
SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2, | |||
bool pad_step_number = false, | |||
bool save_relative_paths = false, | |||
string filename = "") | |||
@@ -56,6 +60,14 @@ namespace Tensorflow | |||
if (!defer_build) | |||
build(); | |||
if(_saver_def != null) | |||
{ | |||
_check_saver_def(); | |||
_write_version = _saver_def.Version; | |||
} | |||
_save_relative_paths = save_relative_paths; | |||
_object_restore_saver = null; | |||
} | |||
public void build() | |||
@@ -106,8 +118,56 @@ namespace Tensorflow | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
_check_saver_def(); | |||
_next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; | |||
} | |||
private void _check_saver_def() | |||
{ | |||
if (!tf.context.executing_eagerly()) | |||
{ | |||
if (string.IsNullOrEmpty(_saver_def.SaveTensorName)) | |||
throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}"); | |||
if (string.IsNullOrEmpty(_saver_def.RestoreOpName)) | |||
throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}"); | |||
} | |||
} | |||
public string save(Session sess, | |||
string save_path, | |||
string global_step = "", | |||
string meta_graph_suffix = "meta", | |||
bool write_meta_graph = true, | |||
bool write_state = true, | |||
bool strip_default_attrs = false) | |||
{ | |||
string latest_filename = "checkpoint"; | |||
string model_checkpoint_path = ""; | |||
string checkpoint_file = ""; | |||
if (!string.IsNullOrEmpty(global_step)) | |||
{ | |||
} | |||
else | |||
{ | |||
checkpoint_file = save_path; | |||
} | |||
var save_path_parent = Path.GetDirectoryName(save_path); | |||
if (!_is_empty) | |||
{ | |||
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | |||
});*/ | |||
} | |||
throw new NotImplementedException(""); | |||
return model_checkpoint_path; | |||
} | |||
} | |||
} |
@@ -185,18 +185,12 @@ namespace Tensorflow | |||
/// A `Tensor` that will hold the new value of this variable after | |||
/// the assignment has completed. | |||
/// </returns> | |||
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||
where T : ITensorOrOperation | |||
public ITensorOrOperation assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||
{ | |||
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||
if (read_value) | |||
return (T)Convert.ChangeType(assign, typeof(T)); | |||
return (T)Convert.ChangeType(assign.op, typeof(T)); | |||
} | |||
public Tensor assign(Tensor value, bool use_locking = false, string name = "") | |||
{ | |||
return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||
return assign; | |||
return assign.op; | |||
} | |||
public override string ToString() | |||
@@ -292,6 +292,27 @@ namespace Tensorflow | |||
return tf.Session(); | |||
} | |||
public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) | |||
{ | |||
if (session == null) | |||
{ | |||
session = get_default_session(); | |||
if (session == null) | |||
throw new ValueError("Cannot execute operation using `run()`: No default " + | |||
"session is registered. Use `with " + | |||
"sess.as_default():` or pass an explicit session to " + | |||
"`run(session=sess)`"); | |||
} | |||
if (session.graph != graph) | |||
throw new ValueError("Cannot use the default session to execute operation: " + | |||
"the operation's graph is different from the " + | |||
"session's graph. Pass an explicit session to " + | |||
"run(session=sess)."); | |||
session.run(operation, feed_dict); | |||
} | |||
public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op) | |||
{ | |||
if (op.inputs == null) return null; | |||
@@ -27,6 +27,13 @@ namespace TensorFlowNET.UnitTest | |||
with<Session>(tf.Session(), sess => | |||
{ | |||
sess.run(init_op); | |||
// o some work with the model. | |||
inc_v1.op.run(); | |||
dec_v2.op.run(); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
}); | |||
} | |||
} | |||
@@ -46,6 +46,24 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
[TestMethod] | |||
public void Assign() | |||
{ | |||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
var inc_v1 = v1.assign(v1 + 1.0f); | |||
// Add an op to initialize the variables. | |||
var init_op = tf.global_variables_initializer(); | |||
with<Session>(tf.Session(), sess => | |||
{ | |||
sess.run(init_op); | |||
// o some work with the model. | |||
inc_v1.op.run(); | |||
}); | |||
} | |||
/// <summary> | |||
/// https://databricks.com/tensorflow/variables | |||
/// </summary> | |||