@@ -28,9 +28,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="input_ops">The data input ops for an op to be created.</param> | /// <param name="input_ops">The data input ops for an op to be created.</param> | ||||
/// <returns>A list of control inputs for the op to be created.</returns> | /// <returns>A list of control inputs for the op to be created.</returns> | ||||
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||||
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||||
{ | { | ||||
Operation[] ret = new Operation[0]; | |||||
var ret = new ITensorOrOperation[0]; | |||||
foreach(var controller in _control_dependencies_stack) | foreach(var controller in _control_dependencies_stack) | ||||
{ | { | ||||
@@ -54,12 +54,12 @@ namespace Tensorflow | |||||
return ret; | return ret; | ||||
} | } | ||||
public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||||
{ | { | ||||
if (control_inputs == null) | if (control_inputs == null) | ||||
return new _ControlDependenciesController(this, null); | return new _ControlDependenciesController(this, null); | ||||
var control_ops = new List<Operation>(); | |||||
var control_ops = new List<ITensorOrOperation>(); | |||||
foreach (var c in control_inputs) | foreach (var c in control_inputs) | ||||
{ | { | ||||
control_ops.Add(c); | control_ops.Add(c); | ||||
@@ -298,6 +298,11 @@ namespace Tensorflow | |||||
return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
} | } | ||||
public string[] get_all_collection_keys() | |||||
{ | |||||
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||||
} | |||||
public object get_collection(string name, string scope = "") | public object get_collection(string name, string scope = "") | ||||
{ | { | ||||
return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
@@ -11,20 +11,20 @@ namespace Tensorflow | |||||
public class _ControlDependenciesController : IPython | public class _ControlDependenciesController : IPython | ||||
{ | { | ||||
private Graph _graph; | private Graph _graph; | ||||
private List<Operation> _control_inputs_val; | |||||
private List<Operation> _seen_nodes; | |||||
private List<ITensorOrOperation> _control_inputs_val; | |||||
private List<ITensorOrOperation> _seen_nodes; | |||||
private Queue<_ControlDependenciesController> _old_stack; | private Queue<_ControlDependenciesController> _old_stack; | ||||
private bool _new_stack; | private bool _new_stack; | ||||
private Context _old_control_flow_context; | private Context _old_control_flow_context; | ||||
public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||||
public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||||
public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> control_inputs) | |||||
{ | { | ||||
_graph = graph; | _graph = graph; | ||||
if (control_inputs == null) | if (control_inputs == null) | ||||
{ | { | ||||
_control_inputs_val = new List<Operation>(); | |||||
_control_inputs_val = new List<ITensorOrOperation>(); | |||||
_new_stack = true; | _new_stack = true; | ||||
} | } | ||||
else | else | ||||
@@ -33,15 +33,15 @@ namespace Tensorflow | |||||
_new_stack = false; | _new_stack = false; | ||||
} | } | ||||
_seen_nodes = new List<Operation>(); | |||||
_seen_nodes = new List<ITensorOrOperation>(); | |||||
} | } | ||||
public void add_op(Operation op) | |||||
public void add_op(ITensorOrOperation op) | |||||
{ | { | ||||
_seen_nodes.Add(op); | _seen_nodes.Add(op); | ||||
} | } | ||||
public bool op_in_group(Operation op) | |||||
public bool op_in_group(ITensorOrOperation op) | |||||
{ | { | ||||
return _seen_nodes.Contains(op); | return _seen_nodes.Contains(op); | ||||
} | } | ||||
@@ -11,5 +11,6 @@ namespace Tensorflow | |||||
public interface ITensorOrOperation | public interface ITensorOrOperation | ||||
{ | { | ||||
string Device { get; } | string Device { get; } | ||||
Operation op { get; } | |||||
} | } | ||||
} | } |
@@ -107,7 +107,9 @@ namespace Tensorflow | |||||
values = ops.internal_convert_to_tensor(values, | values = ops.internal_convert_to_tensor(values, | ||||
name: input_name, | name: input_name, | ||||
as_ref: input_arg.IsRef); | |||||
dtype: dtype, | |||||
as_ref: input_arg.IsRef, | |||||
preferred_dtype: default_dtype); | |||||
//if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
//attrs[input_arg.TypeAttr] = values.dtype; | //attrs[input_arg.TypeAttr] = values.dtype; | ||||
@@ -163,14 +165,20 @@ namespace Tensorflow | |||||
foreach (var arg in op_def.OutputArg) | foreach (var arg in op_def.OutputArg) | ||||
{ | { | ||||
types = new List<TF_DataType>(); | |||||
if (!string.IsNullOrEmpty(arg.NumberAttr)) | if (!string.IsNullOrEmpty(arg.NumberAttr)) | ||||
{ | { | ||||
} | } | ||||
else if (!string.IsNullOrEmpty(arg.TypeAttr)) | else if (!string.IsNullOrEmpty(arg.TypeAttr)) | ||||
{ | { | ||||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; | |||||
} | } | ||||
if (arg.IsRef) | |||||
types = types.Select(x => x.as_ref()).ToList(); | |||||
output_types.AddRange(types); | |||||
} | } | ||||
// Add Op to graph | // Add Op to graph | ||||
@@ -16,6 +16,7 @@ namespace Tensorflow | |||||
private int _id_value; | private int _id_value; | ||||
public string type => OpType; | public string type => OpType; | ||||
public Operation op => this; | |||||
private Status status = new Status(); | private Status status = new Status(); | ||||
@@ -75,7 +76,7 @@ namespace Tensorflow | |||||
/// </param> | /// </param> | ||||
/// <param name="original_op"></param> | /// <param name="original_op"></param> | ||||
/// <param name="op_def"></param> | /// <param name="op_def"></param> | ||||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
{ | { | ||||
Graph = g; | Graph = g; | ||||
@@ -120,6 +121,11 @@ namespace Tensorflow | |||||
_control_flow_post_processing(); | _control_flow_post_processing(); | ||||
} | } | ||||
public void run(FeedItem[] feed_dict = null, Session session = null) | |||||
{ | |||||
ops._run_using_default_session(this, feed_dict, Graph, session); | |||||
} | |||||
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | ||||
{ | { | ||||
var grouped_inputs = new List<object>(); | var grouped_inputs = new List<object>(); | ||||
@@ -204,7 +210,7 @@ namespace Tensorflow | |||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
return _handle == IntPtr.Zero ? "Undefined" : $"'{Name}' type={OpType}"; | |||||
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{Name}' type={OpType}"; | |||||
} | } | ||||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | public static implicit operator Operation(IntPtr handle) => new Operation(handle); | ||||
@@ -28,10 +28,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var dev = ops_on_device.Keys.First(); | var dev = ops_on_device.Keys.First(); | ||||
var deps = ops_on_device.Values.First(); | var deps = ops_on_device.Values.First(); | ||||
if (typeof(T).Name == "Operation") | |||||
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); | |||||
else | |||||
throw new NotImplementedException("control_flow_ops.group"); | |||||
return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name); | |||||
} | } | ||||
// 2-level tree. The root node is the returned NoOp node. | // 2-level tree. The root node is the returned NoOp node. | ||||
@@ -35,17 +35,8 @@ namespace Tensorflow | |||||
c_api.TF_DeleteSessionOptions(opts); | c_api.TF_DeleteSessionOptions(opts); | ||||
} | } | ||||
public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null) | |||||
{ | |||||
return _run(fetches, feed_dict); | |||||
} | |||||
public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null) | |||||
{ | |||||
return _run(fetches, feed_dict); | |||||
} | |||||
public virtual NDArray run(Operation fetches, FeedItem[] feed_dict = null) | |||||
public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||||
{ | { | ||||
return _run(fetches, feed_dict); | return _run(fetches, feed_dict); | ||||
} | } | ||||
@@ -30,6 +30,11 @@ namespace Tensorflow | |||||
return tensor._handle; | return tensor._handle; | ||||
} | } | ||||
public static implicit operator Operation(Tensor tensor) | |||||
{ | |||||
return tensor.op; | |||||
} | |||||
public static implicit operator Tensor(IntPtr handle) | public static implicit operator Tensor(IntPtr handle) | ||||
{ | { | ||||
return new Tensor(handle); | return new Tensor(handle); | ||||
@@ -261,7 +261,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -83,6 +83,13 @@ namespace Tensorflow | |||||
type; | type; | ||||
} | } | ||||
public static TF_DataType as_ref(this TF_DataType type) | |||||
{ | |||||
return (int)type < 100 ? | |||||
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : | |||||
type; | |||||
} | |||||
public static bool is_complex(this TF_DataType type) | public static bool is_complex(this TF_DataType type) | ||||
{ | { | ||||
return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; | return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; | ||||
@@ -7,9 +7,9 @@ namespace Tensorflow | |||||
{ | { | ||||
public class BaseSaverBuilder | public class BaseSaverBuilder | ||||
{ | { | ||||
protected int _write_version; | |||||
protected SaverDef.Types.CheckpointFormatVersion _write_version; | |||||
public BaseSaverBuilder(int write_version = 2) | |||||
public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) | |||||
{ | { | ||||
_write_version = write_version; | _write_version = write_version; | ||||
} | } | ||||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
if (_write_version == 2) | |||||
if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | |||||
{ | { | ||||
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | ||||
} | } | ||||
@@ -60,7 +60,7 @@ namespace Tensorflow | |||||
bool reshape = false, | bool reshape = false, | ||||
bool sharded = false, | bool sharded = false, | ||||
int max_to_keep = 5, | int max_to_keep = 5, | ||||
double keep_checkpoint_every_n_hours = 10000, | |||||
float keep_checkpoint_every_n_hours = 10000, | |||||
string name = "", | string name = "", | ||||
bool restore_sequentially = false, | bool restore_sequentially = false, | ||||
string filename = "model", | string filename = "model", | ||||
@@ -76,7 +76,10 @@ namespace Tensorflow | |||||
if (max_to_keep < 0) | if (max_to_keep < 0) | ||||
max_to_keep = 0; | max_to_keep = 0; | ||||
Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||||
Tensor save_tensor = null; | |||||
Operation restore_op = null; | |||||
return Python.with<ops.name_scope, SaverDef>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||||
{ | { | ||||
name = scope; | name = scope; | ||||
@@ -93,14 +96,35 @@ namespace Tensorflow | |||||
else | else | ||||
{ | { | ||||
if (build_save) | if (build_save) | ||||
_AddSaveOps(filename_tensor, saveables); | |||||
save_tensor = _AddSaveOps(filename_tensor, saveables); | |||||
if (build_restore) | if (build_restore) | ||||
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||||
restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||||
} | } | ||||
var graph = ops.get_default_graph(); | |||||
var check_collection_list = graph.get_all_collection_keys(); | |||||
foreach (var collection_type in check_collection_list) | |||||
{ | |||||
foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>) | |||||
{ | |||||
} | |||||
} | |||||
return new SaverDef() | |||||
{ | |||||
FilenameTensorName = filename_tensor.name, | |||||
SaveTensorName = save_tensor.name, | |||||
RestoreOpName = restore_op.Name, | |||||
MaxToKeep = max_to_keep, | |||||
Sharded = sharded, | |||||
KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours, | |||||
Version = _write_version | |||||
}; | |||||
}); | }); | ||||
throw new NotImplementedException(""); | |||||
} | } | ||||
public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | ||||
@@ -6,7 +6,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | ||||
{ | { | ||||
public BulkSaverBuilder(int write_version = 2) : base(write_version) | |||||
public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version) | |||||
{ | { | ||||
} | } | ||||
@@ -14,7 +14,7 @@ namespace Tensorflow | |||||
bool reshape = false, | bool reshape = false, | ||||
bool sharded = false, | bool sharded = false, | ||||
int max_to_keep = 5, | int max_to_keep = 5, | ||||
double keep_checkpoint_every_n_hours = 10000, | |||||
float keep_checkpoint_every_n_hours = 10000, | |||||
string name = "", | string name = "", | ||||
bool restore_sequentially = false, | bool restore_sequentially = false, | ||||
string filename = "model", | string filename = "model", | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -13,30 +14,33 @@ namespace Tensorflow | |||||
private bool _reshape; | private bool _reshape; | ||||
private bool _sharded; | private bool _sharded; | ||||
private int _max_to_keep; | private int _max_to_keep; | ||||
private double _keep_checkpoint_every_n_hours; | |||||
private float _keep_checkpoint_every_n_hours; | |||||
private string _name; | private string _name; | ||||
private bool _restore_sequentially; | private bool _restore_sequentially; | ||||
private SaverDef _saver_def; | private SaverDef _saver_def; | ||||
private ISaverBuilder _builder; | private ISaverBuilder _builder; | ||||
private bool _allow_empty; | private bool _allow_empty; | ||||
private bool _is_built; | private bool _is_built; | ||||
private int _write_version; | |||||
private SaverDef.Types.CheckpointFormatVersion _write_version; | |||||
private bool _pad_step_number; | private bool _pad_step_number; | ||||
private string _filename; | private string _filename; | ||||
private bool _is_empty; | private bool _is_empty; | ||||
private float _next_checkpoint_time; | |||||
private bool _save_relative_paths; | |||||
private bool? _object_restore_saver; | |||||
public Saver(RefVariable[] var_list = null, | public Saver(RefVariable[] var_list = null, | ||||
bool reshape = false, | bool reshape = false, | ||||
bool sharded = false, | bool sharded = false, | ||||
int max_to_keep = 5, | int max_to_keep = 5, | ||||
double keep_checkpoint_every_n_hours = 10000, | |||||
float keep_checkpoint_every_n_hours = 10000, | |||||
string name = "", | string name = "", | ||||
bool restore_sequentially = false, | bool restore_sequentially = false, | ||||
SaverDef saver_def = null, | SaverDef saver_def = null, | ||||
ISaverBuilder builder = null, | ISaverBuilder builder = null, | ||||
bool defer_build = false, | bool defer_build = false, | ||||
bool allow_empty = false, | bool allow_empty = false, | ||||
int write_version = 2, | |||||
SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2, | |||||
bool pad_step_number = false, | bool pad_step_number = false, | ||||
bool save_relative_paths = false, | bool save_relative_paths = false, | ||||
string filename = "") | string filename = "") | ||||
@@ -56,6 +60,14 @@ namespace Tensorflow | |||||
if (!defer_build) | if (!defer_build) | ||||
build(); | build(); | ||||
if(_saver_def != null) | |||||
{ | |||||
_check_saver_def(); | |||||
_write_version = _saver_def.Version; | |||||
} | |||||
_save_relative_paths = save_relative_paths; | |||||
_object_restore_saver = null; | |||||
} | } | ||||
public void build() | public void build() | ||||
@@ -106,8 +118,56 @@ namespace Tensorflow | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
_check_saver_def(); | |||||
_next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; | |||||
} | |||||
private void _check_saver_def() | |||||
{ | |||||
if (!tf.context.executing_eagerly()) | |||||
{ | |||||
if (string.IsNullOrEmpty(_saver_def.SaveTensorName)) | |||||
throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}"); | |||||
if (string.IsNullOrEmpty(_saver_def.RestoreOpName)) | |||||
throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}"); | |||||
} | |||||
} | |||||
public string save(Session sess, | |||||
string save_path, | |||||
string global_step = "", | |||||
string meta_graph_suffix = "meta", | |||||
bool write_meta_graph = true, | |||||
bool write_state = true, | |||||
bool strip_default_attrs = false) | |||||
{ | |||||
string latest_filename = "checkpoint"; | |||||
string model_checkpoint_path = ""; | |||||
string checkpoint_file = ""; | |||||
if (!string.IsNullOrEmpty(global_step)) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
checkpoint_file = save_path; | |||||
} | |||||
var save_path_parent = Path.GetDirectoryName(save_path); | |||||
if (!_is_empty) | |||||
{ | |||||
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||||
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | |||||
});*/ | |||||
} | |||||
throw new NotImplementedException(""); | |||||
return model_checkpoint_path; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -185,18 +185,12 @@ namespace Tensorflow | |||||
/// A `Tensor` that will hold the new value of this variable after | /// A `Tensor` that will hold the new value of this variable after | ||||
/// the assignment has completed. | /// the assignment has completed. | ||||
/// </returns> | /// </returns> | ||||
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||||
where T : ITensorOrOperation | |||||
public ITensorOrOperation assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||||
{ | { | ||||
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | ||||
if (read_value) | if (read_value) | ||||
return (T)Convert.ChangeType(assign, typeof(T)); | |||||
return (T)Convert.ChangeType(assign.op, typeof(T)); | |||||
} | |||||
public Tensor assign(Tensor value, bool use_locking = false, string name = "") | |||||
{ | |||||
return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||||
return assign; | |||||
return assign.op; | |||||
} | } | ||||
public override string ToString() | public override string ToString() | ||||
@@ -292,6 +292,27 @@ namespace Tensorflow | |||||
return tf.Session(); | return tf.Session(); | ||||
} | } | ||||
public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) | |||||
{ | |||||
if (session == null) | |||||
{ | |||||
session = get_default_session(); | |||||
if (session == null) | |||||
throw new ValueError("Cannot execute operation using `run()`: No default " + | |||||
"session is registered. Use `with " + | |||||
"sess.as_default():` or pass an explicit session to " + | |||||
"`run(session=sess)`"); | |||||
} | |||||
if (session.graph != graph) | |||||
throw new ValueError("Cannot use the default session to execute operation: " + | |||||
"the operation's graph is different from the " + | |||||
"session's graph. Pass an explicit session to " + | |||||
"run(session=sess)."); | |||||
session.run(operation, feed_dict); | |||||
} | |||||
public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op) | public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op) | ||||
{ | { | ||||
if (op.inputs == null) return null; | if (op.inputs == null) return null; | ||||
@@ -27,6 +27,13 @@ namespace TensorFlowNET.UnitTest | |||||
with<Session>(tf.Session(), sess => | with<Session>(tf.Session(), sess => | ||||
{ | { | ||||
sess.run(init_op); | sess.run(init_op); | ||||
// o some work with the model. | |||||
inc_v1.op.run(); | |||||
dec_v2.op.run(); | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
}); | }); | ||||
} | } | ||||
} | } | ||||
@@ -46,6 +46,24 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public void Assign() | |||||
{ | |||||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||||
var inc_v1 = v1.assign(v1 + 1.0f); | |||||
// Add an op to initialize the variables. | |||||
var init_op = tf.global_variables_initializer(); | |||||
with<Session>(tf.Session(), sess => | |||||
{ | |||||
sess.run(init_op); | |||||
// o some work with the model. | |||||
inc_v1.op.run(); | |||||
}); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// https://databricks.com/tensorflow/variables | /// https://databricks.com/tensorflow/variables | ||||
/// </summary> | /// </summary> | ||||