@@ -25,6 +25,12 @@ namespace Tensorflow | |||
public class train_internal | |||
{ | |||
public RefVariable create_global_step(Graph graph) | |||
=> TrainingUtil.create_global_step(graph); | |||
public RefVariable get_global_step(Graph graph) | |||
=> TrainingUtil.get_global_step(graph); | |||
public Optimizer GradientDescentOptimizer(float learning_rate) | |||
=> new GradientDescentOptimizer(learning_rate); | |||
@@ -46,6 +46,7 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
object initializer = null, // IInitializer or Tensor | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
bool? use_resource = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
@@ -60,7 +61,8 @@ namespace Tensorflow | |||
use_resource: use_resource, | |||
validate_shape: validate_shape, | |||
initializer: initializer, | |||
trainable: trainable); | |||
trainable: trainable, | |||
collections: collections); | |||
} | |||
} | |||
} |
@@ -1,5 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
@@ -34,32 +36,52 @@ namespace Tensorflow.Estimators | |||
if(max_steps > 0) | |||
{ | |||
var start_step = _load_global_step_from_checkpoint_dir(_model_dir); | |||
if (max_steps <= start_step) | |||
{ | |||
Console.WriteLine("Skipping training since max_steps has already saved."); | |||
return this; | |||
} | |||
} | |||
_train_model(); | |||
_train_model(input_fn); | |||
throw new NotImplementedException(""); | |||
} | |||
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||
{ | |||
var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
// var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
// should use NewCheckpointReader (not implemented) | |||
var cp = tf.train.get_checkpoint_state(checkpoint_dir); | |||
return 0; | |||
return cp.AllModelCheckpointPaths.Count - 1; | |||
} | |||
private void _train_model() | |||
private void _train_model(Action input_fn) | |||
{ | |||
_train_model_default(); | |||
_train_model_default(input_fn); | |||
} | |||
private void _train_model_default() | |||
private void _train_model_default(Action input_fn) | |||
{ | |||
using (var g = tf.Graph().as_default()) | |||
{ | |||
var global_step_tensor = _create_and_assert_global_step(g); | |||
} | |||
} | |||
private Tensor _create_and_assert_global_step(Graph graph) | |||
{ | |||
var step = _create_global_step(graph); | |||
Debug.Assert(step == tf.train.get_global_step(graph)); | |||
Debug.Assert(step.dtype.is_integer()); | |||
return step; | |||
} | |||
private RefVariable _create_global_step(Graph graph) | |||
{ | |||
return tf.train.create_global_step(graph); | |||
} | |||
public void __init__() | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -175,6 +175,10 @@ namespace Tensorflow | |||
if (_nodes_by_name.ContainsKey(op_name)) | |||
return _nodes_by_name[op_name].outputs[out_n]; | |||
else | |||
throw new KeyError($"The name {name} refers to a Tensor which does not " + | |||
$"exist. The operation, {op_name}, does not exist in the " + | |||
"graph."); | |||
} | |||
else if (!name.Contains(":") & allow_operation) | |||
{ | |||
@@ -54,6 +54,8 @@ namespace Tensorflow | |||
return _constant_if_small(0.0D, shape, dtype, name); | |||
case TF_DataType.TF_FLOAT: | |||
return _constant_if_small(0.0F, shape, dtype, name); | |||
case TF_DataType.TF_INT64: | |||
return _constant_if_small(0l, shape, dtype, name); | |||
case TF_DataType.TF_INT32: | |||
return _constant_if_small(0, shape, dtype, name); | |||
case TF_DataType.TF_INT8: | |||
@@ -35,5 +35,6 @@ | |||
DtFloatRef = 101, // DT_FLOAT_REF | |||
DtDoubleRef = 102, // DT_DOUBLE_REF | |||
DtInt32Ref = 103, // DT_INT32_REF | |||
DtInt64Ref = 109 // DT_INT64_REF | |||
} | |||
} |
@@ -246,7 +246,8 @@ namespace Tensorflow | |||
public static bool is_integer(this TF_DataType type) | |||
{ | |||
return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 || | |||
type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64; | |||
type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 || | |||
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | |||
} | |||
public static bool is_floating(this TF_DataType type) | |||
@@ -249,7 +249,9 @@ namespace Tensorflow | |||
{ | |||
_maybe_initialize_trackable(); | |||
v = variable_scope.default_variable_creator( | |||
initial_value, name: name, trainable: false, | |||
initial_value, | |||
name: name, | |||
trainable: false, | |||
use_resource: resource_variable_ops.is_resource_variable( | |||
colocate_with)); | |||
@@ -174,8 +174,24 @@ namespace Tensorflow | |||
var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); | |||
if (File.Exists(coord_checkpoint_filename)) | |||
{ | |||
var file_content = File.ReadAllBytes(coord_checkpoint_filename); | |||
var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
var file_content = File.ReadAllLines(coord_checkpoint_filename); | |||
// https://github.com/protocolbuffers/protobuf/issues/6654 | |||
// var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
var ckpt = new CheckpointState(); | |||
var field = CheckpointState.Descriptor.FindFieldByName("model_checkpoint_path"); | |||
ckpt.ModelCheckpointPath = file_content.FirstOrDefault(x => x.StartsWith(field.Name + ":")).Substring(field.Name.Length + 2); | |||
// remove first and last quote. | |||
ckpt.ModelCheckpointPath = ckpt.ModelCheckpointPath.Substring(1, ckpt.ModelCheckpointPath.Length - 2); | |||
field = CheckpointState.Descriptor.FindFieldByName("all_model_checkpoint_paths"); | |||
file_content.Where(x => x.StartsWith(field.Name + ":")) | |||
.ToList() | |||
.ForEach(x => | |||
{ | |||
string value = x.Substring(field.Name.Length + 2); | |||
ckpt.AllModelCheckpointPaths.Add(value.Substring(1, value.Length - 2)); | |||
}); | |||
if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); | |||
// For relative model_checkpoint_path and all_model_checkpoint_paths, | |||
@@ -0,0 +1,51 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Train | |||
{ | |||
public class TrainingUtil | |||
{ | |||
public static RefVariable create_global_step(Graph graph) | |||
{ | |||
graph = graph ?? ops.get_default_graph(); | |||
if (get_global_step(graph) != null) | |||
throw new ValueError("global_step already exists."); | |||
// Create in proper graph and base name_scope. | |||
var g = graph.as_default(); | |||
g.name_scope(null); | |||
var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, | |||
initializer: tf.zeros_initializer, | |||
trainable: false, | |||
aggregation: VariableAggregation.OnlyFirstReplica, | |||
collections: new List<string> { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP }); | |||
return v; | |||
} | |||
public static RefVariable get_global_step(Graph graph) | |||
{ | |||
graph = graph ?? ops.get_default_graph(); | |||
RefVariable global_step_tensor = null; | |||
var global_step_tensors = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_STEP); | |||
if (global_step_tensors.Count == 1) | |||
{ | |||
global_step_tensor = global_step_tensors[0]; | |||
} | |||
else | |||
{ | |||
try | |||
{ | |||
global_step_tensor = graph.get_tensor_by_name("global_step:0"); | |||
} | |||
catch (KeyError) | |||
{ | |||
return null; | |||
} | |||
} | |||
return global_step_tensor; | |||
} | |||
} | |||
} |
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -50,6 +51,7 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
object initializer = null, // IInitializer or Tensor | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
bool? use_resource = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
@@ -67,6 +69,7 @@ namespace Tensorflow | |||
initializer: initializer, | |||
reuse: resue, | |||
trainable: trainable, | |||
collections: collections, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
}); | |||
@@ -42,6 +42,7 @@ namespace Tensorflow | |||
object initializer = null, // IInitializer or Tensor | |||
bool? reuse = null, | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
VariableAggregation aggregation = VariableAggregation.None) | |||
@@ -54,6 +55,7 @@ namespace Tensorflow | |||
dtype: dtype, | |||
initializer: initializer, | |||
trainable: trainable, | |||
collections: collections, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
@@ -64,6 +66,7 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
object initializer = null, | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
VariableAggregation aggregation = VariableAggregation.None) | |||
@@ -77,6 +80,7 @@ namespace Tensorflow | |||
dtype: dtype, | |||
initializer: init, | |||
trainable: trainable, | |||
collections: collections, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
@@ -112,6 +116,7 @@ namespace Tensorflow | |||
IInitializer initializer = null, | |||
bool reuse = false, | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
bool validate_shape = false, | |||
bool? use_resource = null, | |||
VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
@@ -157,6 +162,7 @@ namespace Tensorflow | |||
v = variable_scope.default_variable_creator(init_val, | |||
name: name, | |||
trainable: trainable, | |||
collections: collections, | |||
dtype: variable_dtype, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
@@ -175,6 +175,7 @@ namespace Tensorflow | |||
public static RefVariable default_variable_creator(object initial_value, | |||
string name = null, | |||
bool? trainable = null, | |||
List<string> collections = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
bool validate_shape = false, | |||
bool ? use_resource = null, | |||
@@ -199,6 +200,7 @@ namespace Tensorflow | |||
return new RefVariable(initial_value, | |||
trainable: trainable.Value, | |||
validate_shape: validate_shape, | |||
collections: collections, | |||
name: name, | |||
dtype: dtype); | |||
} | |||