Browse Source

_create_and_assert_global_step

tags/v0.12
Oceania2018 6 years ago
parent
commit
9b9a25c712
13 changed files with 130 additions and 12 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +3
    -1
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  3. +29
    -7
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  4. +4
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +2
    -0
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  6. +1
    -0
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  7. +2
    -1
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  8. +3
    -1
      src/TensorFlowNET.Core/Train/Optimizer.cs
  9. +18
    -2
      src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
  10. +51
    -0
      src/TensorFlowNET.Core/Train/TrainingUtil.cs
  11. +3
    -0
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  12. +6
    -0
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  13. +2
    -0
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs

+ 6
- 0
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -25,6 +25,12 @@ namespace Tensorflow

public class train_internal
{
/// <summary>
/// Forwards to <c>TrainingUtil.create_global_step</c> for the given graph.
/// </summary>
/// <param name="graph">Graph passed through unchanged to the training utility.</param>
/// <returns>Whatever <c>TrainingUtil.create_global_step</c> returns.</returns>
public RefVariable create_global_step(Graph graph)
{
    return TrainingUtil.create_global_step(graph);
}

/// <summary>
/// Forwards to <c>TrainingUtil.get_global_step</c> for the given graph.
/// </summary>
/// <param name="graph">Graph passed through unchanged to the training utility.</param>
/// <returns>Whatever <c>TrainingUtil.get_global_step</c> returns.</returns>
public RefVariable get_global_step(Graph graph)
{
    return TrainingUtil.get_global_step(graph);
}

/// <summary>
/// Builds a new <c>GradientDescentOptimizer</c> with the given learning rate.
/// </summary>
/// <param name="learning_rate">Value handed to the optimizer's constructor.</param>
/// <returns>The freshly constructed optimizer.</returns>
public Optimizer GradientDescentOptimizer(float learning_rate)
{
    return new GradientDescentOptimizer(learning_rate);
}



+ 3
- 1
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -46,6 +46,7 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
List<string> collections = null,
bool? use_resource = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.Auto,
@@ -60,7 +61,8 @@ namespace Tensorflow
use_resource: use_resource,
validate_shape: validate_shape,
initializer: initializer,
trainable: trainable);
trainable: trainable,
collections: collections);
}
}
}

+ 29
- 7
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using static Tensorflow.Binding;

@@ -34,32 +36,52 @@ namespace Tensorflow.Estimators
if(max_steps > 0)
{
var start_step = _load_global_step_from_checkpoint_dir(_model_dir);
if (max_steps <= start_step)
{
Console.WriteLine("Skipping training since max_steps has already saved.");
return this;
}
}

_train_model();
_train_model(input_fn);
throw new NotImplementedException("");
}

private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
{
var cp = tf.train.latest_checkpoint(checkpoint_dir);
// var cp = tf.train.latest_checkpoint(checkpoint_dir);
// should use NewCheckpointReader (not implemented)
var cp = tf.train.get_checkpoint_state(checkpoint_dir);

return 0;
return cp.AllModelCheckpointPaths.Count - 1;
}

private void _train_model()
private void _train_model(Action input_fn)
{
_train_model_default();
_train_model_default(input_fn);
}

private void _train_model_default()
private void _train_model_default(Action input_fn)
{
using (var g = tf.Graph().as_default())
{
var global_step_tensor = _create_and_assert_global_step(g);
}
}

/// <summary>
/// Creates the global step for <paramref name="graph"/> and sanity-checks it.
/// </summary>
/// <param name="graph">Graph the global step is created in.</param>
/// <returns>The created global step.</returns>
private Tensor _create_and_assert_global_step(Graph graph)
{
    var global_step = _create_global_step(graph);

    // Debug.Assert is compiled out of non-DEBUG builds, so these checks
    // only run in debug configurations.
    Debug.Assert(global_step == tf.train.get_global_step(graph));
    Debug.Assert(global_step.dtype.is_integer());

    return global_step;
}

/// <summary>
/// Creates the global step variable by delegating to <c>tf.train.create_global_step</c>.
/// </summary>
/// <param name="graph">Graph passed through to <c>tf.train.create_global_step</c>.</param>
/// <returns>The variable returned by <c>tf.train.create_global_step</c>.</returns>
private RefVariable _create_global_step(Graph graph)
    => tf.train.create_global_step(graph);

public void __init__()
{
throw new NotImplementedException();


+ 4
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -175,6 +175,10 @@ namespace Tensorflow

if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];
else
throw new KeyError($"The name {name} refers to a Tensor which does not " +
$"exist. The operation, {op_name}, does not exist in the " +
"graph.");
}
else if (!name.Contains(":") & allow_operation)
{


+ 2
- 0
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -54,6 +54,8 @@ namespace Tensorflow
return _constant_if_small(0.0D, shape, dtype, name);
case TF_DataType.TF_FLOAT:
return _constant_if_small(0.0F, shape, dtype, name);
case TF_DataType.TF_INT64:
return _constant_if_small(0l, shape, dtype, name);
case TF_DataType.TF_INT32:
return _constant_if_small(0, shape, dtype, name);
case TF_DataType.TF_INT8:


+ 1
- 0
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -35,5 +35,6 @@
DtFloatRef = 101, // DT_FLOAT_REF
DtDoubleRef = 102, // DT_DOUBLE_REF
DtInt32Ref = 103, // DT_INT32_REF
DtInt64Ref = 109 // DT_INT64_REF
}
}

+ 2
- 1
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -246,7 +246,8 @@ namespace Tensorflow
public static bool is_integer(this TF_DataType type)
{
return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 ||
type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64;
type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 ||
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref;
}

public static bool is_floating(this TF_DataType type)


+ 3
- 1
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -249,7 +249,9 @@ namespace Tensorflow
{
_maybe_initialize_trackable();
v = variable_scope.default_variable_creator(
initial_value, name: name, trainable: false,
initial_value,
name: name,
trainable: false,
use_resource: resource_variable_ops.is_resource_variable(
colocate_with));



+ 18
- 2
src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs View File

@@ -174,8 +174,24 @@ namespace Tensorflow
var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename);
if (File.Exists(coord_checkpoint_filename))
{
var file_content = File.ReadAllBytes(coord_checkpoint_filename);
var ckpt = CheckpointState.Parser.ParseFrom(file_content);
var file_content = File.ReadAllLines(coord_checkpoint_filename);
// https://github.com/protocolbuffers/protobuf/issues/6654
// var ckpt = CheckpointState.Parser.ParseFrom(file_content);
var ckpt = new CheckpointState();
var field = CheckpointState.Descriptor.FindFieldByName("model_checkpoint_path");
ckpt.ModelCheckpointPath = file_content.FirstOrDefault(x => x.StartsWith(field.Name + ":")).Substring(field.Name.Length + 2);
// remove first and last quote.
ckpt.ModelCheckpointPath = ckpt.ModelCheckpointPath.Substring(1, ckpt.ModelCheckpointPath.Length - 2);

field = CheckpointState.Descriptor.FindFieldByName("all_model_checkpoint_paths");
file_content.Where(x => x.StartsWith(field.Name + ":"))
.ToList()
.ForEach(x =>
{
string value = x.Substring(field.Name.Length + 2);
ckpt.AllModelCheckpointPaths.Add(value.Substring(1, value.Length - 2));
});

if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath))
throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}");
// For relative model_checkpoint_path and all_model_checkpoint_paths,


+ 51
- 0
src/TensorFlowNET.Core/Train/TrainingUtil.cs View File

@@ -0,0 +1,51 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
/// <summary>
/// Helpers for creating and looking up the global step variable of a graph.
/// </summary>
public class TrainingUtil
{
    /// <summary>
    /// Creates the global step variable in <paramref name="graph"/>.
    /// </summary>
    /// <param name="graph">Graph to create the variable in; the default graph is used when null.</param>
    /// <returns>The newly created global step variable.</returns>
    /// <exception cref="ValueError">
    /// <see cref="get_global_step"/> already finds a global step in the graph.
    /// </exception>
    public static RefVariable create_global_step(Graph graph)
    {
        graph = graph ?? ops.get_default_graph();
        if (get_global_step(graph) != null)
            throw new ValueError("global_step already exists.");

        // Create in proper graph and base name_scope.
        var g = graph.as_default();
        g.name_scope(null);

        // int64, zero-initialized, non-trainable, and registered in both the
        // GLOBAL_VARIABLES and GLOBAL_STEP collections so get_global_step can find it.
        var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64,
            initializer: tf.zeros_initializer,
            trainable: false,
            aggregation: VariableAggregation.OnlyFirstReplica,
            collections: new List<string> { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP });
        return v;
    }

    /// <summary>
    /// Looks up the global step of <paramref name="graph"/>.
    /// </summary>
    /// <remarks>
    /// The GLOBAL_STEP collection is consulted first. When it holds exactly one
    /// entry, that entry is returned. When it is empty, the tensor named
    /// <c>global_step:0</c> is looked up instead. When it holds more than one
    /// entry the result is ambiguous and null is returned.
    /// </remarks>
    /// <param name="graph">Graph to search; the default graph is used when null.</param>
    /// <returns>The global step, or null when none (or more than one) is found.</returns>
    public static RefVariable get_global_step(Graph graph)
    {
        graph = graph ?? ops.get_default_graph();
        RefVariable global_step_tensor = null;
        var global_step_tensors = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_STEP);
        if (global_step_tensors.Count == 1)
        {
            global_step_tensor = global_step_tensors[0];
        }
        else if (global_step_tensors.Count == 0)
        {
            // Nothing registered in the collection: fall back to a lookup by name.
            try
            {
                global_step_tensor = graph.get_tensor_by_name("global_step:0");
            }
            catch (KeyError)
            {
                return null;
            }
        }
        else
        {
            // Several candidates: do not guess and do not fall back to the name
            // lookup, which could silently pick one of them.
            Console.WriteLine("Multiple tensors in global_step collection.");
            return null;
        }
        return global_step_tensor;
    }
}
}

+ 3
- 0
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -50,6 +51,7 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
List<string> collections = null,
bool? use_resource = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.Auto,
@@ -67,6 +69,7 @@ namespace Tensorflow
initializer: initializer,
reuse: resue,
trainable: trainable,
collections: collections,
synchronization: synchronization,
aggregation: aggregation);
});


+ 6
- 0
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -42,6 +42,7 @@ namespace Tensorflow
object initializer = null, // IInitializer or Tensor
bool? reuse = null,
bool? trainable = null,
List<string> collections = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None)
@@ -54,6 +55,7 @@ namespace Tensorflow
dtype: dtype,
initializer: initializer,
trainable: trainable,
collections: collections,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
@@ -64,6 +66,7 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
object initializer = null,
bool? trainable = null,
List<string> collections = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None)
@@ -77,6 +80,7 @@ namespace Tensorflow
dtype: dtype,
initializer: init,
trainable: trainable,
collections: collections,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
@@ -112,6 +116,7 @@ namespace Tensorflow
IInitializer initializer = null,
bool reuse = false,
bool? trainable = null,
List<string> collections = null,
bool validate_shape = false,
bool? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.Auto,
@@ -157,6 +162,7 @@ namespace Tensorflow
v = variable_scope.default_variable_creator(init_val,
name: name,
trainable: trainable,
collections: collections,
dtype: variable_dtype,
validate_shape: validate_shape,
synchronization: synchronization,


+ 2
- 0
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -175,6 +175,7 @@ namespace Tensorflow
public static RefVariable default_variable_creator(object initial_value,
string name = null,
bool? trainable = null,
List<string> collections = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool validate_shape = false,
bool ? use_resource = null,
@@ -199,6 +200,7 @@ namespace Tensorflow
return new RefVariable(initial_value,
trainable: trainable.Value,
validate_shape: validate_shape,
collections: collections,
name: name,
dtype: dtype);
}


Loading…
Cancel
Save