You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

UninitializedVariable.cs 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Gradients;
  5. using static Tensorflow.Binding;
  6. namespace Tensorflow.Variables
  7. {
  /// <summary>
  /// A variable with no initializer.
  /// </summary>
  public sealed class UninitializedVariable : BaseResourceVariable
  {
      // TODO: complete the arg list.
      /// <summary>
      /// Creates the resource handle for a variable from a shape and dtype only,
      /// without running any initializer, and registers it with the base class.
      /// </summary>
      /// <param name="trainable">Forwarded to <c>base.__init__</c>.</param>
      /// <param name="caching_device">Currently unused by this constructor.</param>
      /// <param name="name">Name passed to <c>ops.name_scope</c>; "Variable" is the default scope name.</param>
      /// <param name="dtype">Element type used to create the handle and, in graph mode, the read op.</param>
      /// <param name="aggregation">Currently unused by this constructor.</param>
      /// <param name="shape">Shape used to create the handle; also stored in <c>_shape</c>.</param>
      /// <param name="extra_handle_data">Forwarded to <c>variable_handle_from_shape_and_dtype</c>.</param>
      public UninitializedVariable(
          bool trainable = true,
          string caching_device = "",
          string name = null,
          TF_DataType dtype = TF_DataType.DtInvalid,
          VariableAggregation aggregation = VariableAggregation.None,
          Shape shape = null,
          Tensor extra_handle_data = null)
      {
          string unique_id = "";
          string handle_name = "";
          // The handle is created inside the nested scopes below; it is captured here so
          // that it is still reachable when base.__init__ is called after the scopes close.
          Tensor created_handle = null;

          tf_with(ops.init_scope(), (x) =>
          {
              _in_graph_mode = !tf.Context.executing_eagerly();
              // The lambda parameter is deliberately not called `name`, so it cannot be
              // confused with the constructor parameter of the same name.
              tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), scope_name =>
              {
                  handle_name = ops.name_from_scope_name(scope_name);
                  string? shared_name;
                  if (_in_graph_mode)
                  {
                      // Graph mode: the handle name doubles as shared name and unique id.
                      shared_name = handle_name;
                      unique_id = shared_name;
                  }
                  else
                  {
                      // Eager mode: no shared name; the unique id is the handle name plus a fresh uid.
                      unique_id = $"{handle_name}-{ops.uid()}";
                      shared_name = null;
                  }

                  var handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
                      shape, dtype, shared_name, scope_name, _in_graph_mode, extra_handle_data);
                  created_handle = handle;
                  // skip the assignment of `handle._parent_trackable` because of lack of API.
                  // skip the assignment of `handle._name` and `handle._unique_id` because of accessibility.

                  if (_in_graph_mode)
                  {
                      tf_with(ops.name_scope("Read"), _ =>
                      {
                          // Build the read op under the same device as the handle.
                          var value = tf_with(ops.device(handle.Device), _ =>
                          {
                              var result = gen_resource_variable_ops.read_variable_op(handle, dtype);
                              // TODO(Rinne): _maybe_set_handle_data(dtype, handle, value)
                              return result;
                          });
                          _graph_element = value;
                      });
                      ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
                  }
                  else
                  {
                      _graph_element = null;
                  }
              });
          });

          _shape = shape;
          _dtype = dtype;
          base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
      }
  }
  71. }