You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TrainingUtil.cs 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using static Tensorflow.Binding;
  namespace Tensorflow.Train
  {
      /// <summary>
      /// Helpers for creating and looking up the global step variable of a graph.
      /// </summary>
      public class TrainingUtil
      {
          /// <summary>
          /// Creates the global step variable in <paramref name="graph"/>.
          /// </summary>
          /// <param name="graph">
          /// Graph to create the variable in; when null, <c>ops.get_default_graph()</c> is used.
          /// </param>
          /// <returns>
          /// A new non-trainable <c>int64</c> variable named <c>tf.GraphKeys.GLOBAL_STEP</c>, with an empty
          /// shape, a zeros initializer, <c>OnlyFirstReplica</c> aggregation, and membership in the
          /// <c>GLOBAL_VARIABLES</c> and <c>GLOBAL_STEP</c> collections.
          /// </returns>
          /// <exception cref="ValueError">
          /// <see cref="get_global_step"/> already finds a global step in the graph.
          /// </exception>
          public static RefVariable create_global_step(Graph graph)
          {
              graph = graph ?? ops.get_default_graph();
              if (get_global_step(graph) != null)
              {
                  throw new ValueError("global_step already exists.");
              }

              // Create in proper graph and base name_scope.
              // NOTE(review): the Python original scopes this in `with graph.as_default(), g.name_scope(None)`,
              // which restores the previous default graph and name scope afterwards. Neither call is undone
              // here, so the default graph / name scope presumably remain changed after return - confirm.
              var g = graph.as_default();
              g.name_scope(null);
              var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64,
                  initializer: tf.zeros_initializer,
                  trainable: false,
                  aggregation: VariableAggregation.OnlyFirstReplica,
                  collections: new List<string> { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP });
              return v;
          }

          /// <summary>
          /// Looks up the global step variable of <paramref name="graph"/>.
          /// </summary>
          /// <param name="graph">Graph to search; when null, <c>ops.get_default_graph()</c> is used.</param>
          /// <returns>
          /// The single entry of the <c>GLOBAL_STEP</c> collection when it holds exactly one. When that
          /// collection is empty (or null), the graph element named <c>"global_step:0"</c>, or null if no
          /// such element exists. Null when the collection holds more than one entry, because the global
          /// step is then ambiguous.
          /// </returns>
          public static RefVariable get_global_step(Graph graph)
          {
              graph = graph ?? ops.get_default_graph();
              var global_step_tensors = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_STEP);

              // A missing collection is treated the same as an empty one.
              int count = global_step_tensors?.Count ?? 0;

              if (count == 1)
              {
                  return global_step_tensors[0];
              }

              if (count > 1)
              {
                  // Several candidates are registered, so do not guess one by name.
                  // TensorFlow's Python get_global_step also returns None in this case.
                  return null;
              }

              // The collection is empty: fall back to the conventional tensor name.
              try
              {
                  RefVariable global_step_tensor = graph.get_tensor_by_name("global_step:0");
                  return global_step_tensor;
              }
              catch (KeyError)
              {
                  return null;
              }
          }
      }
  }