Browse Source

Merge branch 'master' of github.com:Wanglongzhi2001/TensorFlow.NET into dev1

tags/v0.100.5-BERT-load
Wanglongzhi2001 2 years ago
parent
commit
f4e7fd47e7
100 changed files with 19852 additions and 980 deletions
  1. +2
    -2
      TensorFlow.NET.sln
  2. +17
    -0
      src/TensorFlowNET.Core/APIs/c_api.customize.cs
  3. +18
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.cs
  4. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.io.cs
  5. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Attributes/c_api.ops.cs
  7. +1
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  8. +6
    -0
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  9. +27
    -0
      src/TensorFlowNET.Core/Buffers/TF_Buffer.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  11. +20
    -14
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  12. +6
    -9
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  13. +9
    -6
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  14. +33
    -128
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  15. +19
    -17
      src/TensorFlowNET.Core/Checkpoint/restore.cs
  16. +86
    -3
      src/TensorFlowNET.Core/Contexts/Context.Config.cs
  17. +8
    -0
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  18. +56
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  19. +4
    -2
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  20. +25
    -7
      src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
  21. +12
    -7
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  22. +1
    -0
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  23. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  24. +162
    -17
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  25. +4
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
  26. +13
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
  27. +8
    -1
      src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  28. +53
    -0
      src/TensorFlowNET.Core/Eager/backprop_util.cs
  29. +7
    -1
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  30. +4
    -0
      src/TensorFlowNET.Core/Eager/execute.cs
  31. +13
    -0
      src/TensorFlowNET.Core/Eager/forwardprop_util.cs
  32. +31
    -0
      src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs
  33. +13
    -0
      src/TensorFlowNET.Core/Extensions/NamedTuple.cs
  34. +13
    -0
      src/TensorFlowNET.Core/Extensions/OneofExtension.cs
  35. +5
    -2
      src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs
  36. +0
    -6
      src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs
  37. +22
    -0
      src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs
  38. +11
    -1
      src/TensorFlowNET.Core/Framework/c_api_util.cs
  39. +297
    -0
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  40. +67
    -8
      src/TensorFlowNET.Core/Framework/importer.cs
  41. +12
    -0
      src/TensorFlowNET.Core/Framework/versions.cs
  42. +152
    -22
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  43. +202
    -20
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  44. +4
    -5
      src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
  45. +67
    -6
      src/TensorFlowNET.Core/Functions/Function.cs
  46. +12
    -0
      src/TensorFlowNET.Core/Functions/IGenericFunction.cs
  47. +142
    -73
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  48. +84
    -0
      src/TensorFlowNET.Core/Functions/TracingCompiler.cs
  49. +5
    -1
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  50. +50
    -0
      src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs
  51. +94
    -0
      src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs
  52. +282
    -0
      src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  53. +2
    -2
      src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
  54. +27
    -8
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  55. +15
    -8
      src/TensorFlowNET.Core/Gradients/ITape.cs
  56. +2
    -2
      src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
  57. +157
    -125
      src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
  58. +31
    -32
      src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
  59. +21
    -10
      src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
  60. +10
    -10
      src/TensorFlowNET.Core/Gradients/Tape.cs
  61. +45
    -9
      src/TensorFlowNET.Core/Gradients/TapeTensor.cs
  62. +1
    -1
      src/TensorFlowNET.Core/Gradients/TensorTape.cs
  63. +14
    -0
      src/TensorFlowNET.Core/Gradients/custom_gradient.cs
  64. +52
    -0
      src/TensorFlowNET.Core/Gradients/default_gradient.cs
  65. +71
    -4
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  66. +15
    -4
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  67. +3
    -1
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  68. +364
    -13
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  69. +8
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs
  70. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  71. +58
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  72. +37
    -0
      src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs
  73. +2
    -0
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  74. +5
    -2
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  75. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
  76. +13
    -21
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  77. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
  78. +84
    -13
      src/TensorFlowNET.Core/Operations/Operation.cs
  79. +72
    -0
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  80. +46
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  81. +128
    -0
      src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
  82. +38
    -0
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  83. +60
    -0
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  84. +129
    -44
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  85. +107
    -2
      src/TensorFlowNET.Core/Protobuf/AllocationDescription.cs
  86. +464
    -7
      src/TensorFlowNET.Core/Protobuf/ApiDef.cs
  87. +338
    -4
      src/TensorFlowNET.Core/Protobuf/AttrValue.cs
  88. +84
    -2
      src/TensorFlowNET.Core/Protobuf/CheckpointState.cs
  89. +126
    -3
      src/TensorFlowNET.Core/Protobuf/Cluster.cs
  90. +2346
    -237
      src/TensorFlowNET.Core/Protobuf/Config.cs
  91. +407
    -5
      src/TensorFlowNET.Core/Protobuf/ControlFlow.cs
  92. +791
    -0
      src/TensorFlowNET.Core/Protobuf/CoordinationConfig.cs
  93. +7964
    -0
      src/TensorFlowNET.Core/Protobuf/CoordinationService.cs
  94. +486
    -6
      src/TensorFlowNET.Core/Protobuf/CostGraph.cs
  95. +296
    -5
      src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs
  96. +1041
    -0
      src/TensorFlowNET.Core/Protobuf/DataService.cs
  97. +320
    -5
      src/TensorFlowNET.Core/Protobuf/Debug.cs
  98. +378
    -11
      src/TensorFlowNET.Core/Protobuf/DeviceAttributes.cs
  99. +660
    -9
      src/TensorFlowNET.Core/Protobuf/Event.cs
  100. +340
    -0
      src/TensorFlowNET.Core/Protobuf/Executable.cs

+ 2
- 2
TensorFlow.NET.sln View File

@@ -1,7 +1,7 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.31624.102
# Visual Studio Version 17
VisualStudioVersion = 17.4.33213.308
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject


+ 17
- 0
src/TensorFlowNET.Core/APIs/c_api.customize.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}

+ 18
- 0
src/TensorFlowNET.Core/APIs/tf.compat.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System.Text;

namespace Tensorflow
@@ -45,6 +46,23 @@ namespace Tensorflow
{
return as_text(bytes_or_text, encoding);
}

public ByteString as_bytes(ByteString bytes, Encoding encoding = null)
{
return bytes;
}
public ByteString as_bytes(byte[] bytes, Encoding encoding = null)
{
return ByteString.CopyFrom(bytes);
}
public ByteString as_bytes(string text, Encoding encoding = null)
{
if(encoding is null)
{
encoding = Encoding.UTF8;
}
return ByteString.CopyFrom(encoding.GetBytes(text));
}
}

public bool executing_eagerly()


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -54,6 +54,6 @@ namespace Tensorflow
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
string name = null,
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list);
}
}

+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
@@ -79,5 +81,10 @@ namespace Tensorflow
num_split: num_split,
axis: axis,
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Attributes/c_api.ops.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);

/// <summary>
/// Set `num_dims` to -1 to represent "unknown rank".


+ 1
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -22,6 +22,7 @@ using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Operations;

namespace Tensorflow
{


+ 6
- 0
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -107,6 +107,12 @@ namespace Tensorflow
}
}

public void Release()
{
_handle.Dispose();
_handle = null;
}

public override string ToString()
=> $"0x{_handle.DangerousGetHandle():x16}";



+ 27
- 0
src/TensorFlowNET.Core/Buffers/TF_Buffer.cs View File

@@ -25,5 +25,32 @@ namespace Tensorflow
public IntPtr data;
public ulong length;
public IntPtr data_deallocator;

public unsafe Span<T> AsSpan<T>() where T: unmanaged
{
if(length > int.MaxValue)
{
throw new ValueError($"The length {length} is too large to use in the span.");
}
return new Span<T>(data.ToPointer(), (int)length);
}

public unsafe byte[] ToByteArray()
{
byte[] res = new byte[length];
if(length > int.MaxValue)
{
byte* root = (byte*)data;
for(ulong i = 0; i < length; i++)
{
res[i] = *(root++);
}
}
else
{
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
}
return res;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -161,7 +161,7 @@ public static class CheckPointUtils

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
return full_list.Where(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;


+ 20
- 14
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -1,10 +1,12 @@
using System;
using OneOf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Common.Extensions;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
@@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -104,7 +106,10 @@ namespace Tensorflow.Checkpoint
{
var td = trackable_data[i];
Debug.Assert(td.node_id == i);
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
TrackableObjectGraph.Types.TrackableObject trackable_object = new();
trackable_object.SlotVariables.AddRange(td.slot_variable_proto);
trackable_object.Children.AddRange(td.children_proto);
object_graph_proto.Nodes.Add(trackable_object);
}
return object_graph_proto;
}
@@ -117,16 +122,16 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param>
/// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
foreach(var td in tensor_trackables)
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
Trackable trackable = null;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -148,12 +153,12 @@ namespace Tensorflow.Checkpoint
return serialized_tensors;
}

private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
var trackable = trackable_data.object_to_save;

// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict;
if (call_with_mapped_captures)
{
throw new NotImplementedException();
@@ -163,8 +168,7 @@ namespace Tensorflow.Checkpoint
ret_tensor_dict = trackable.serialize_to_tensors();
}

// TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict)
{
var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -173,10 +177,12 @@ namespace Tensorflow.Checkpoint

tensor_dict[checkpoint_key] = maybe_tensor;

if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
foreach(var key in maybe_tensor.Keys)
{
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>())
{
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name;
}
}

if(object_graph_proto is not null)
@@ -200,7 +206,7 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param>
/// <returns></returns>
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, string> object_names = new();


+ 6
- 9
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Google.Protobuf;
using OneOf;

namespace Tensorflow.Checkpoint;

@@ -114,14 +115,10 @@ public static class SaveUtilV1
{
var trackable = trackable_objects[i];
Debug.Assert(node_ids[trackable] == i);
TrackableObjectGraph.Types.TrackableObject object_proto;
var object_proto = new TrackableObjectGraph.Types.TrackableObject();
if (slot_variables.TryGetValue(trackable, out var slots))
{
object_proto = new TrackableObjectGraph.Types.TrackableObject(slots);
}
else
{
object_proto = new TrackableObjectGraph.Types.TrackableObject();
object_proto.SlotVariables.AddRange(slots);
}
object_graph_proto.Nodes.Add(object_proto);
foreach (var child in graph_view.list_children(trackable))
@@ -184,13 +181,13 @@ public static class SaveUtilV1

// TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
if (maybe_saveable.TryPickT1(out var s, out var variable))
{
saveables.Add(s);
}
else
{
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key));
saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key));
}

foreach (var saveable in saveables)
@@ -222,7 +219,7 @@ public static class SaveUtilV1

public record class CheckpointFactoryData
(
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory,
string name,
string checkpoint_key
);

+ 9
- 6
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -12,6 +12,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations;
using Newtonsoft.Json;
using Tensorflow.Training;
using OneOf;

namespace Tensorflow.Checkpoint;

@@ -44,12 +45,12 @@ public class TrackableSaver
_graph_view = graph_view;
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
}

private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -70,9 +71,10 @@ public class TrackableSaver
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
if (!serialized_tensors.ContainsKey(Trackable.None))
{
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>();
}
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>();
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor);
return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}

@@ -392,6 +394,7 @@ public class CheckpointRestoreCoordinator
/// </summary>
public List<Trackable> AllTrackables => _all_trackables;
public HashSet<int> MatchedProtoIds => _matched_proto_ids;
// TODO(Rinne): change to weak ref.
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
public int RestoreUid => _restore_uid;
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
@@ -406,7 +409,7 @@ public class CheckpointRestoreCoordinator
// skip the callback.
}

public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
{
List<Operation> restore_ops = new();
foreach(var position in positions)
@@ -418,7 +421,7 @@ public class CheckpointRestoreCoordinator
Dictionary<string, BaseResourceVariable> variable_dict = new();
foreach(var item in tensor_saveables)
{
if(item.Value.TryGet<BaseResourceVariable>(out var variable))
if(item.Value.TryPickT0(out var variable, out var _))
{
variable_dict[item.Key] = variable;
}


+ 33
- 128
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -15,106 +15,14 @@ using Tensorflow.Graphs;
using System.Xml.Linq;
using System.Diagnostics;
using RestoreFunc = System.Func<object, object>;
using OneOf;

namespace Tensorflow.Checkpoint
{
public class Maybe<TA, TB>
{
private TA? _valueA = default(TA);
private TB? _valueB = default(TB);
private Type _type;
private bool _assignedTA;
public Maybe(TA value)
{
_valueA = value;
_type= typeof(TA);
_assignedTA = true;
}
public Maybe(TB value)
{
_valueB = value;
_type = typeof(TB);
_assignedTA = false;
}

public Type DataType => _type;

/// <summary>
/// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned.
/// It returns
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="res"></param>
/// <returns></returns>
public bool TryGet<T>(out T? res)
{
if(_valueA is T && _valueB is not T)
{
res = (T)(object)_valueA;
return _assignedTA;
}
else if(_valueA is not T && _valueB is T)
{
res = (T)(object)_valueB;
return !_assignedTA;
}
res = default(T);
return false;
}

public bool IsTypeOrDeriveFrom<T>()
{
if (_valueA is T && _valueB is not T)
{
return _assignedTA;
}
else if (_valueA is not T && _valueB is T)
{
return !_assignedTA;
}
else if (_valueA is T && _valueB is T)
{
return true;
}
else
{
return false;
}
}

public T GetValue<T>()
{
if (_valueA is T && _valueB is not T)
{
return (T)(object)_valueA;
}
else if (_valueA is not T && _valueB is T)
{
return (T)(object)_valueB;
}
else if (_valueA is T && _valueB is T)
{
throw new TypeError("The type is vague, this is always because TA and TB both derive from T.");
}
else
{
throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}.");
}
}

public static implicit operator Maybe<TA, TB>(TA a)
{
return new Maybe<TA, TB>(a);
}
public static implicit operator Maybe<TA, TB>(TB b)
{
return new Maybe<TA, TB>(b);
}
}
internal class SingleDeviceSaver
{
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict;
}
@@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value))
as IDictionary<string, OneOf<Tensor, SaveSpec>>);
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value))
as IDictionary<string, OneOf<Tensor, SaveSpec>>);
}
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
{
@@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
if(maybe_tensor.TryPickT1(out var spec, out var tensor))
{
var tensor_value = spec.tensor;
if (tensor_value is not null)
@@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint
}
else
{
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_names.Add(checkpoint_key);
tensors.Add(tensor);
slice_specs.Add(slice_spec);
@@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
if(maybe_tensor.TryPickT1(out var spec, out var tensor))
{
tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec);
@@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint
}
else
{
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key);
@@ -256,12 +162,12 @@ namespace Tensorflow.Checkpoint
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>();
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new();
foreach(var pair in serialized_tensors)
{
@@ -276,9 +182,9 @@ namespace Tensorflow.Checkpoint
{
restore_fn = new RestoreFunc(x =>
{
if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>)
if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>)
{
return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>);
return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>);
}
throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}.");
});
@@ -287,16 +193,7 @@ namespace Tensorflow.Checkpoint
foreach(var item in tensor_dict)
{
var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.TryGet<Tensor>(out var t))
{
spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = t;
}
else
{
spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>();
}
var spec_to_tensor = item.Value;

foreach(var spec in spec_to_tensor)
{
@@ -311,12 +208,20 @@ namespace Tensorflow.Checkpoint
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because lack of API.
var host_device = tensor.Device;
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
string host_device;
if (tensor.IsT0)
{
host_device = tensor.AsT0.Device;
}
else
{
host_device = tensor.AsT1.device;
}
host_device = saveable_object_util.set_cpu0(host_device);
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>();
}
internal_dict[checkpoint_key][slice_spec] = tensor;
}
@@ -412,7 +317,7 @@ namespace Tensorflow.Checkpoint

IDictionary<string, Operation> restore_func()
{
Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new();

@@ -433,29 +338,29 @@ namespace Tensorflow.Checkpoint
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
}
else
{
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
}
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach (var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
@@ -538,7 +443,7 @@ namespace Tensorflow.Checkpoint

public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
{
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
foreach (var saveable in saveables)
{
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable });


+ 19
- 17
src/TensorFlowNET.Core/Checkpoint/restore.cs View File

@@ -1,7 +1,9 @@
using System;
using OneOf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -49,7 +51,7 @@ public class CheckpointPosition
{
_checkpoint.AllTrackables.Add(trackable);
_checkpoint.MatchedProtoIds.Add(_proto_id);
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null)
{
// skip the `logging.warning`.
return false;
@@ -61,13 +63,13 @@ public class CheckpointPosition
}
}

public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
{
// skip the registered_saver

if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0)
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null);
}

@@ -75,7 +77,7 @@ public class CheckpointPosition

List<Operation> existing_restore_ops;
List<CheckpointPosition> positions = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables;
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables;
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories);
@@ -109,8 +111,8 @@ public class CheckpointPosition
/// Creates a saveable using the _serialize_to_tensor method.
/// </summary>
/// <param name="saveable_factories"></param>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
string suffix = SaveableCompat.get_saveable_name(this.Trackable);
suffix = suffix ?? "";
@@ -124,23 +126,23 @@ public class CheckpointPosition

var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name);
// skip the cache.
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new();
dict[saveable_name] = saveable;
return (new List<Operation>(), dict);
}

private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
// TODO(Rinne): implement it.
if(ObjectProto.Attributes is null)
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>());
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>());
}

List<Operation> existing_restore_ops = new();
HashSet<string> created_compat_names = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new();
foreach (var serialized_tensor in ObjectProto.Attributes)
{
Operation existing_op;
@@ -172,12 +174,12 @@ public class CheckpointPosition
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name);
continue;
}
named_saveables[serialized_tensor.CheckpointKey] = saveable;
named_saveables[serialized_tensor.CheckpointKey] = saveable.Value;
}
return (existing_restore_ops, named_saveables);
}

private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories,
private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories,
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names)
{
var expected_factory_name = serialized_tensor.Name;
@@ -221,7 +223,7 @@ public class CheckpointPosition
Queue<(CheckpointPosition, Trackable)> visit_queue = new();
visit_queue.Enqueue((this, this.Trackable));
List<Operation> restore_ops = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
List<CheckpointPosition> positions = new();

CheckpointPosition current_position = null;
@@ -306,7 +308,7 @@ public class CheckpointPosition
}
}

private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{
var trackable = this.Trackable;
trackable._maybe_initialize_trackable();
@@ -318,7 +320,7 @@ public class CheckpointPosition
}
else
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null);
}
}


+ 86
- 3
src/TensorFlowNET.Core/Contexts/Context.Config.cs View File

@@ -14,9 +14,11 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Contexts
{
@@ -25,12 +27,93 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
public ConfigProto Config { get; set; } = new ConfigProto
protected Device.PhysicalDevice[] _physical_devices;
protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index;
ConfigProto _config;
public ConfigProto Config
{
GpuOptions = new GPUOptions
get
{
_initialize_physical_devices();

var config = new ConfigProto();
if(_config is not null)
{
config.MergeFrom(_config);
}
config.LogDevicePlacement = _log_device_placement;

config.DeviceCount["CPU"] = 0;
config.DeviceCount["GPU"] = 0;
foreach(var dev in _physical_devices)
{
if (config.DeviceCount.ContainsKey(dev.DeviceType))
{
config.DeviceCount[dev.DeviceType] += 1;
}
else
{
config.DeviceCount[dev.DeviceType] = 1;
}
}

var gpu_options = _compute_gpu_options();
config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray());

return config;
}
set
{
_config = value;
}
}

protected void _initialize_physical_devices(bool reinitialize = false)
{
if(!reinitialize && _physical_devices is not null)
{
return;
}
var devs = list_physical_devices();
_physical_devices = devs.Select(d => new Device.PhysicalDevice()
{
DeviceName = d.DeviceName,
DeviceType = d.DeviceType
}).ToArray();
_physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i))
.ToDictionary(x => x.Key, x => x.Value);

_import_config();
}

protected void _import_config()
{
if(_config is null)
{
return;
}
if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus))
{
num_cpus = 1;
}
if(num_cpus != 1)
{
// TODO(Rinne): implement it.
}
};

var gpus = _physical_devices.Where(d => d.DeviceType == "GPU");
if(gpus.Count() == 0)
{
return;
}

if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count))
{
gpu_count = 0;
}

// TODO(Rinne): implement it.
}

ConfigProto MergeConfig()
{


+ 8
- 0
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -111,6 +111,14 @@ namespace Tensorflow.Contexts
return results.ToArray();
}

public bool is_custom_device(string device_name)
{
return false;
// TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs.
//ensure_initialized();
//return c_api.TFE_IsCustomDevice(_handle, device_name);
}

public EagerDeviceContext device(string name)
{
return new EagerDeviceContext(this, name);


+ 56
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -37,7 +37,26 @@ namespace Tensorflow.Contexts
public string ScopeName { get; set; } = "";
bool initialized = false;
ContextSwitchStack context_switches;
public FunctionCallOptions FunctionCallOptions { get; }
protected FunctionCallOptions _function_call_options;
public FunctionCallOptions FunctionCallOptions
{
get
{
if(_function_call_options is null)
{
var config = Config;
_function_call_options = new FunctionCallOptions()
{
Config = config
};
}
return _function_call_options;
}
set
{
_function_call_options = value;
}
}

SafeContextHandle _handle;

@@ -122,6 +141,11 @@ namespace Tensorflow.Contexts
name :
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";

public string anonymous_name()
{
return "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
}

public void graph_mode(bool isFunc = false)
=> context_switches.Push(false, isFunc);

@@ -158,6 +182,37 @@ namespace Tensorflow.Contexts
return has_graph_arg;
}

public bool has_function(string name)
{
ensure_initialized();
return c_api.TFE_ContextHasFunction(_handle, name);
}

public void add_function(SafeFuncGraphHandle fn)
{
ensure_initialized();
Status status = new();
c_api.TFE_ContextAddFunction(_handle, fn, status);
status.Check(true);
}

public void remove_function(string name)
{
ensure_initialized();
Status status = new();
c_api.TFE_ContextRemoveFunction(_handle, name, status);
status.Check(true);
}

public void add_function_def(FunctionDef fdef)
{
ensure_initialized();
var fdef_string = fdef.ToByteArray();
Status status = new Status();
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status);
status.Check(true);
}

public void restore_mode()
{
context_switches.Pop();


+ 4
- 2
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Google.Protobuf;
using Protobuf.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Contexts
@@ -9,10 +10,11 @@ namespace Tensorflow.Contexts
public class FunctionCallOptions
{
public ConfigProto Config { get; set; }
public string ExecutorType { get; set; }

public string config_proto_serialized()
public ByteString config_proto_serialized()
{
return Config.ToByteString().ToStringUtf8();
return Config.ToByteString();
}
}
}

+ 25
- 7
src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs View File

@@ -12,18 +12,36 @@ namespace Tensorflow.Eager
return HasGradientTape();
}

private bool ShouldRecord(Tensor[] inputs)
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
{
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
var tape_set = tf.GetTapeSet();
var input_ids = MakeTensorIDList(tensors);
var input_dtypes = MakeTensorDtypeList(tensors);
bool some_tape_watching = false;
if (tape_set is not null && tape_set.Count > 0)
{
if (tape.ShouldRecord(inputs))
foreach (var tape in tape_set)
{
should_record = true;
break;
if (tape.ShouldRecord(input_ids, input_dtypes))
{
if (tape.Persistent || some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
}
some_tape_watching = true;
}
}
}
return should_record;
// skip the forward_accumulators.

if (some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
}
else
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
}
}
}
}

+ 12
- 7
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -13,7 +13,17 @@ namespace Tensorflow.Eager
Tensor[] results,
BackwardFunction backwardFunction = null)
{
bool should_record = ShouldRecord(inputs);
var input_ids = MakeTensorIDList(inputs);
var input_dtypes = MakeTensorDtypeList(inputs);
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
{
if (tape.ShouldRecord(input_ids, input_dtypes))
{
should_record = true;
break;
}
}

if (!should_record)
{
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager
op_inputs = inputs;*/

backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);

return true;
}
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager
{
return HasGradientTape();
}

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
return tensors.Select(x => x.dtype).ToArray();
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Eager


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -358,7 +358,7 @@ namespace Tensorflow.Eager
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


+ 162
- 17
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -1,6 +1,8 @@
using System;
using OneOf.Types;
using System;
using Tensorflow.Gradients;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager
/// </summary>
public partial class EagerRunner
{
/// <summary>
///
/// </summary>
/// <param name="tape"></param>
/// <param name="target"></param>
/// <param name="sources"></param>
/// <param name="output_gradients"></param>
/// <param name="unconnected_gradients">determines the value returned if the target and
/// sources are unconnected.When 'none' the value returned is None wheras when
/// 'zero' a zero tensor in the same shape as the sources is returned.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
Tensor[] output_gradients)
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients)
{
var target_vec = target;
var sources_vec = sources;
var sources_set = sources_vec;
if (!tape.Persistent)
{
var tape_set = tf.GetTapeSet();
if (tape_set.Contains(tape))
{
throw new RuntimeError("gradient() cannot be invoked within the " +
"GradientTape context (i.e., while operations are being " +
"recorded). Either move the call to gradient() to be " +
"outside the 'with tf.GradientTape' block, or " +
"use a persistent tape: " +
"'with tf.GradientTape(persistent=true)'");
}
}

var target_vec = MakeTensorIDList(target);
var sources_vec = MakeTensorIDList(sources);
HashSet<long> sources_set = new HashSet<long>(sources_vec);
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();

int len = target.Length;
for(int i = 0; i < len; i++)
{
var target_id = target_vec[i];
if (sources_set.Contains(target_id))
{
var tensor = target[i];
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
}
}

List<Tensor> outgrad_vec = new();
if(output_gradients is not null)
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);

var seq_array = target;
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();

for (int i = 0; i < target.Length; ++i)
bool unconnected_gradients_zero = unconnected_gradients == "zero";
Tensor[] sources_obj = null;
if (unconnected_gradients_zero)
{
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
sources_obj = MakeTensorList(sources_raw);
}

if (output_gradients != null)
if (result.Length > 0)
{
throw new NotImplementedException("");
for(int i = 0; i < result.Length; i++)
{
if (result[i] is null && unconnected_gradients_zero)
{
var dtype = sources_obj[i].dtype;
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
}
}
}
else
return result;
}

Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
{
return tensors.ToArray();
}

long[] MakeTensorIDList(Tensor[] tensors)
{
int len = tensors.Length;
long[] ids = new long[len];
for(int i = 0; i < len; i++)
{
var tensor = tensors[i];
ids[i] = tensor.Id;
}
return ids;
}

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
int len = tensors.Length;
TF_DataType[] dtypes = new TF_DataType[len];
for (int i = 0; i < len; i++)
{
output_gradients = new Tensor[0];
var tensor = tensors[i];
dtypes[i] = tensor.dtype;
}
return dtypes;
}

var outgrad_vec = MakeTensorList(output_gradients);
TapeTensor TapeTensorFromTensor(Tensor tensor)
{
long id = tensor.Id;
var dtype = tensor.dtype;
if (tensor is EagerTensor)
{
var handle = tensor.EagerTensorHandle;
if (DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
}

Status status = new();
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
long[] dims = new long[num_dims];
for(int i = 0; i < num_dims; i++)
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
}
else
{
return new TapeTensor(id, dtype, tensor_shape);
}
}
var shape_tuple = tensor.shape.dims;
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, dtype, tensor);
}
long[] l = new long[shape_tuple.Length];
for(int i = 0; i < shape_tuple.Length; i++)
{
if (shape_tuple[i] < 0)
{
l[i] = 0;
}
else
{
l[i] = shape_tuple[i];
}
}
return new TapeTensor(id, dtype, new Shape(l));
}

return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
bool DTypeNeedsHandleData(TF_DataType dtype)
{
return dtype == dtypes.variant || dtype == dtypes.resource;
}

Tensor[] MakeTensorList(Tensor[] tensors)
bool ListContainNone(long[] list)
{
return tensors;
int len = list.Length;
if(len == 0)
{
return true;
}
for(int i = 0; i < len; i++)
{
if (list[i] == -1)
{
return true;
}
}
return false;
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs View File

@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
public partial class EagerRunner
{
void TapeSetRecordBackprop(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
TapeTensor[] output_info,
long[] input_ids,
TF_DataType[] input_detyps,
BackwardFunction backward_function)
{
if (!CouldBackprop())
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager

foreach (var tape in tf.GetTapeSet())
{
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function);
}
}
}


+ 13
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs View File

@@ -10,18 +10,28 @@ namespace Tensorflow.Eager
public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors,
Tensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();

var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
backward_function))
return false;

TapeSetRecordBackprop(op_type, input_tensors, output_info,
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
backward_function);

return true;
}

public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function)
{
var input_ids = MakeTensorIDList(input_tensors);
var input_dtypes = MakeTensorDtypeList(input_tensors);
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
backward_function);
}
}
}

+ 8
- 1
src/TensorFlowNET.Core/Eager/IEagerRunner.cs View File

@@ -29,7 +29,14 @@ namespace Tensorflow.Eager
Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
Tensor[] output_gradients);
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients);

void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function);

int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);

bool RecordGradient(string op_name,
Tensor[] inputs,


+ 53
- 0
src/TensorFlowNET.Core/Eager/backprop_util.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow.Eager
{
internal static class backprop_util
{
// TODO: add quantized_dtypes (after being supported).
private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[]
{
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16
});
public static bool IsTrainable(Tensor tensor)
{
var dtype = _DTypeFromTensor(tensor);
return _trainable_dtypes.Contains(dtype);
}
public static bool IsTrainable(TF_DataType dtype)
{
return _trainable_dtypes.Contains(dtype);
}

private static TF_DataType _DTypeFromTensor(Tensor tensor)
{
var dtype = tensor.dtype;
if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT)
{
CppShapeInferenceResult.Types.HandleData handle_data;
if (tensor is EagerTensor)
{
handle_data = tensor.HandleData;
}
else
{
handle_data = handle_data_util.get_resource_handle_data(tensor);
}
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null &&
handle_data.ShapeAndType.Count > 0)
{
var first_type = handle_data.ShapeAndType[0].Dtype;
if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type))
{
return first_type.as_tf_dtype();
}
}
}
return dtype;
}
}
}

+ 7
- 1
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -30,6 +30,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);

@@ -277,7 +280,7 @@ namespace Tensorflow
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status);
public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status);

/// <summary>
///
@@ -480,5 +483,8 @@ namespace Tensorflow
IntPtr[] target, int target_size,
IntPtr[] sources, int source_size,
IntPtr[] outputs, int output_size);

[DllImport(TensorFlowLibName)]
public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name);
}
}

+ 4
- 0
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -18,6 +18,10 @@ namespace Tensorflow.Eager
var types = v.Select(t => t.dtype.as_datatype_enum());
return (types.ToArray(), v.ToArray());
}
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
{
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
}
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
{
string device_name = ctx.DeviceName;


+ 13
- 0
src/TensorFlowNET.Core/Eager/forwardprop_util.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public class TangentInfo
{
// TODO(Rinne): implement it.
public object Indices { get; set; }
public object Tangents { get; set; }
}
}

+ 31
- 0
src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class DictionaryExtension
{
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second)
{
first = pair.Key;
second = pair.Value;
}
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other)
{
foreach(var (key, value) in other)
{
dic[key] = value;
}
}
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue)
{
if (dic.ContainsKey(key))
{
return dic[key];
}
return defaultValue;
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Extensions/NamedTuple.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Common.Types
{
public class NamedTuple
{
public string Name { get; set; }
public Dictionary<string, object> ValueDict { get; set; }
}
}

+ 13
- 0
src/TensorFlowNET.Core/Extensions/OneofExtension.cs View File

@@ -0,0 +1,13 @@
using OneOf;
using System;

namespace Tensorflow.Common.Extensions
{
public static class OneofExtension
{
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src)
{
return src.Value is T;
}
}
}

+ 5
- 2
src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs View File

@@ -6,8 +6,11 @@
public class DenseSpec : TypeSpec
{
protected Shape _shape;
public Shape shape => _shape;

public Shape shape
{
get { return _shape; }
set { _shape = value; }
}
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;



+ 0
- 6
src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs View File

@@ -1,6 +0,0 @@
namespace Tensorflow.Framework.Models
{
class ScopedTFFunction
{
}
}

+ 22
- 0
src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
internal class ScopedTFFunction
{
SafeFuncGraphHandle _handle;
string _name;
public ScopedTFFunction(SafeFuncGraphHandle func, string name)
{
_handle = func;
_name = name;
}

public SafeFuncGraphHandle Get()
{
return _handle;
}
}
}

+ 11
- 1
src/TensorFlowNET.Core/Framework/c_api_util.cs View File

@@ -111,7 +111,17 @@ namespace Tensorflow

public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();

public static Buffer tf_buffer(byte[] data) => new Buffer(data);
public static Buffer tf_buffer(byte[] data = null)
{
if(data is not null)
{
return new Buffer(data); ;
}
else
{
return new Buffer();
}
}

public static IEnumerable<Operation> new_tf_operations(Graph graph)
{


+ 297
- 0
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -0,0 +1,297 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow.Framework
{
public class function_def_lib
{
// TODO(Rinne): process signatures and structured outputs.
public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature,
object? structured_outputs, List<TensorShapeProto> input_shapes = null)
{
var func_graph = new FuncGraph(fdef.Signature.Name);
if(input_shapes is null)
{
if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr))
{
var raw_input_shapes = input_shapes_attr.List.Shape;
input_shapes = new List<TensorShapeProto>();
foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y)))
{
if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0)
{
input_shapes.Add(null);
}
else
{
input_shapes.Add(input_shape);
}
}
}
}

var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes);

func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);

foreach(var node in graph_def.Node)
{
if(node.Attr.TryGetValue("_output_shapes", out var output_shapes))
{
var op = func_graph.get_operation_by_name(node.Name);
foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length)))
{
op.outputs[output_index].shape = new Shape(shape);
}
}
}
Dictionary<long, string> output_names = new();
foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names))
{
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name;
}
func_graph._output_names = output_names;

func_graph.Exit();
return func_graph;
}

public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes)
{
var graph_def = new GraphDef()
{
Versions = new VersionDef()
{
Producer = versions.GRAPH_DEF_VERSION,
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
}
};

var default_graph = ops.get_default_graph();

if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count)
{
throw new ValueError($"Length of `input_shapes` must match the number " +
$"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " +
$"{fdef.Signature.InputArg.Count} `input_arg`s.");
}

foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg))
{
NodeDef node_def = new();
node_def.Name = arg_def.Name;
node_def.Op = "Placeholder";
node_def.Attr["dtype"] = new AttrValue()
{
Type = arg_def.Type
};
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null)
{
var input_shape = input_shapes[i];
// skip the condition that input_shape is not `TensorShapeProto`.
AttrValue shape = new AttrValue()
{
Shape = new TensorShapeProto()
};
shape.Shape = new TensorShapeProto(input_shape);
node_def.Attr["shape"] = shape;
}
if (!fdef.ArgAttr.ContainsKey((uint)i))
{
fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs();
}
var arg_attrs = fdef.ArgAttr[(uint)i].Attr;
foreach(var k in arg_attrs.Keys)
{
if(k == "_output_shapes")
{
if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List)
{
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]);
}
else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape)
{
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape);
}
}
else if (k.StartsWith("_"))
{
if (!node_def.Attr.ContainsKey(k))
{
node_def.Attr[k] = new AttrValue();
}
node_def.Attr[k] = new AttrValue(arg_attrs[k]);
}
}

graph_def.Node.Add(node_def);
}

graph_def.Node.AddRange(fdef.NodeDef);

Dictionary<string, string> nested_to_flat_tensor_name = new();
foreach(var arg_def in fdef.Signature.InputArg)
{
nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0";
string control_name = "^" + arg_def.Name;
nested_to_flat_tensor_name[control_name] = control_name;
}

foreach(var node_def in fdef.NodeDef)
{
var graph = default_graph;
while (true)
{
if(graph is null)
{
break;
}
var f = graph.Functions.GetOrDefault(node_def.Op, null);
if(f is not null && graph.OuterGraph is null)
{
break;
}
graph = graph.OuterGraph;
}

var op_def = default_graph.GetOpDef(node_def.Op);

foreach(var attr in op_def.Attr)
{
if(attr.Type == "func")
{
var fname = node_def.Attr[attr.Name].Func.Name;
if (!is_function(fname))
{
throw new ValueError($"Function {fname} was not found. Please make sure " +
$"the FunctionDef `fdef` is correct.");
}
}
else if(attr.Type == "list(func)")
{
foreach(var fn in node_def.Attr[attr.Name].List.Func)
{
var fname = fn.Name;
if (!is_function(fname))
{
throw new ValueError($"Function {fname} was not found. Please make " +
$"sure the FunctionDef `fdef` is correct.");
}
}
}
}

int flattened_index = 0;
foreach(var arg_def in op_def.OutputArg)
{
var num_args = _get_num_args(arg_def, node_def);
for(int i = 0; i < num_args; i++)
{
var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}";
var flat_name = $"{node_def.Name}:{flattened_index}";
nested_to_flat_tensor_name[nested_name] = flat_name;
flattened_index++;
}
}
string control_name = "^" + node_def.Name;
nested_to_flat_tensor_name[control_name] = control_name;
}

foreach(var node_def in graph_def.Node)
{
for(int i = 0; i < node_def.Input.Count; i++)
{
node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]];
}
}

return (graph_def, nested_to_flat_tensor_name);
}

private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef)
{
foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg)))
{
if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0)
{
tensor.shape = Shape.Scalar;

var shape_and_type = arg_def.HandleData[0];
var handle_data = new HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType()
{
Shape = shape_and_type.Shape,
Dtype = shape_and_type.Dtype
});
resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true);
}
}
}

private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def)
{
if (!string.IsNullOrEmpty(arg_def.NumberAttr))
{
return node_def.Attr[arg_def.NumberAttr].I;
}
else if(!string.IsNullOrEmpty(arg_def.TypeListAttr))
{
return node_def.Attr[arg_def.TypeListAttr].List.Type.Count;
}
else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid)
{
return 1;
}
else
{
throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " +
$"FunctionDef `fdef` is correct.");
}
}

public static bool is_function(string fname)
{
if (tf.Context.executing_eagerly())
{
return tf.Context.has_function(fname);
}
else
{
var graph = ops.get_default_graph();
while(graph is not null)
{
if (graph.IsFunction(fname))
{
return true;
}
if(graph.OuterGraph is not null)
{
graph = graph.OuterGraph;
}
else
{
return false;
}
}
}
throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}

+ 67
- 8
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -17,6 +17,7 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;
@@ -25,9 +26,14 @@ namespace Tensorflow
{
public class importer
{
public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null)
{
return import_graph_def(graph_def, validate_colocation_constraints: false, name: name);
}
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
bool validate_colocation_constraints = true,
string name = null,
OpList producer_op_list = null)
{
@@ -60,7 +66,7 @@ namespace Tensorflow
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
var status = new Status();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints );
// need to create a class ImportGraphDefWithResults with IDisposal
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status));
status.Check(true);
@@ -107,21 +113,36 @@ namespace Tensorflow
foreach (var new_op in graph._add_new_tf_operations())
{
var original_device = new_op.Device;
new_op._set_device(original_device);
}
}

public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
string prefix,
Dictionary<string, Tensor> input_map,
string[] return_elements)
string[] return_elements,
bool validate_colocation_constraints)
{
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true);

foreach (var input in input_map)
{
var (src_name, src_index) = _ParseTensorName(input.Key);
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output());
var input_src = tf.compat.as_str(input.Key);
var input_dst = input.Value;
if (input_src.StartsWith("^"))
{
var src_name = tf.compat.as_str(input_src.Substring(1));
var dst_op = input_dst._as_tf_output().oper;
c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op);
}
else
{
var (src_name, src_index) = _ParseTensorName(input.Key);
src_name = tf.compat.as_str(src_name);
var dst_output = input_dst._as_tf_output();
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output);
}
}

if (return_elements == null)
@@ -132,15 +153,16 @@ namespace Tensorflow
if (name.Contains(":"))
{
var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
op_name = tf.compat.as_str(op_name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index);
}
else
{
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name));
}
}

// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints);
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints);
}

private static (string, int) _ParseTensorName(string tensor_name)
@@ -173,6 +195,14 @@ namespace Tensorflow
return graph_def;
}

private static GraphDef _ProcessGraphDefParam(GraphDef graph_def)
{
var old_graph_def = graph_def;
graph_def = new GraphDef(old_graph_def);

return graph_def;
}

private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
{
foreach (var attr_def in op_def.Attr)
@@ -240,6 +270,35 @@ namespace Tensorflow
}
}

private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def)
{
var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x);

foreach (var node in graph_def.Node)
{
// Remove any default attr values that aren't in op_def.
if (producer_op_dict.ContainsKey(node.Op))
{
var op_def = op_def_registry.GetOpDef(node.Op);
if(op_def is null)
{
continue;
}
var producer_op_def = producer_op_dict[node.Op];
foreach (var key in node.Attr.Keys)
{
if (_FindAttrInOpDef(key, op_def) is null)
{
var attr_def = _FindAttrInOpDef(key, producer_op_def);
if (attr_def != null && attr_def.DefaultValue != null &&
node.Attr[key] == attr_def.DefaultValue)
node.Attr[key].ClearValue();
}
}
}
}
}

private static AttrDef _FindAttrInOpDef(string name, OpDef op_def)
{
return op_def.Attr.FirstOrDefault(x => x.Name == name);


+ 12
- 0
src/TensorFlowNET.Core/Framework/versions.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
public class versions
{
public static int GRAPH_DEF_VERSION = 1286;
public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0;
}
}

+ 152
- 22
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -1,9 +1,13 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -13,29 +17,46 @@ namespace Tensorflow.Functions
/// </summary>
public class ConcreteFunction: Trackable
{
protected IEnumerable<Tensor> _captured_inputs;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, AttrValue> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function;
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new();
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;

public string Name => func_graph?.FuncName;
public string Name => _delayed_rewrite_functions.Forward().Name;

public Tensor[] Outputs;
public Tensor[] Outputs => func_graph.Outputs;
public Type ReturnType;
public TensorSpec[] OutputStructure;
public IEnumerable<string> ArgKeywords { get; set; }
public long NumPositionArgs { get; set; }
public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition;
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;

public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
_captured_inputs = func_graph.external_captures;
_attrs= new Dictionary<string, AttrValue>();
_set_infer_function();
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null)
{
func_graph = graph;
_captured_inputs = func_graph.external_captures;

ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
_attrs = attrs;
_set_infer_function();
}

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
@@ -53,6 +74,9 @@ namespace Tensorflow.Functions
new[] { output },
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, AttrValue>();
_set_infer_function();
}

public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
@@ -73,6 +97,9 @@ namespace Tensorflow.Functions
new[] { output.variant_tensor },
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, AttrValue>();
_set_infer_function();
}

/*public ConcreteFunction(Func<Tensors, Tensors> func,
@@ -130,39 +157,56 @@ namespace Tensorflow.Functions
{
var executing_eagerly = tf.Context.executing_eagerly();
var default_graph = ops.get_default_graph();
// TODO(Rinne): deal with `default_graph.building_function`

var tempvv = func_graph.Variables;
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph)
{
foreach(var v in this.func_graph.Variables)
{
resource_variable_ops.variable_accessed(v);
}
}

var tensor_inputs = new Tensors();
foreach (var (i, arg) in enumerate(args))
{
tensor_inputs.Add(arg);
// If we're graph building, shape inference is on.
if (!executing_eagerly)
{
}
}
tensor_inputs.AddRange(captured_inputs);
if (!executing_eagerly)
{
// TODO(Rinne): add the check
}
tensor_inputs.AddRange(captured_inputs);

args = tensor_inputs.ToArray();

var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0;
var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args);
// No tape is watching; skip to running the function.
if (possible_gradient_type == 0 && executing_eagerly)
if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly)
{
var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
return _build_call_outputs(_inference_function.Call(args));
}

if (forward_backward == null)
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (executing_eagerly)
{
flat_outputs = forward_function.Call(args_with_tangents);
}
else
{
tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){
{ "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() }
}), _ =>
{
flat_outputs = forward_function.Call(args_with_tangents);
});
}
forward_backward.Record(flat_outputs);
return flat_outputs;
return _build_call_outputs(flat_outputs);
}

public void AddTograph(Graph? g = null)
@@ -171,13 +215,99 @@ namespace Tensorflow.Functions
{
g = ops.get_default_graph();
}
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
_delayed_rewrite_functions.Forward().AddToGraph(g);
}

public void SetExternalCaptures(IEnumerable<Tensor> captures)
{
_captured_inputs = captures;
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
return new ForwardBackwardCall(functions, args, tape_watching: true);
TangentInfo input_tangents;
if (executing_eagerly)
{
// TODO(Rinne): check if it needs to be implemented.
input_tangents = new TangentInfo();
}
else
{
input_tangents = new TangentInfo();
}
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
{
if(input_tangents.Indices is not null || executing_eagerly)
{
string cache_key = "first_order";
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions))
{
functions = new FirstOrderTapeGradientFunctions(func_graph, false);
_tape_functions_cache[cache_key] = functions;
}
return new ForwardBackwardCall(functions, args, tape_watching: true);
}
else
{
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true);
}
}
else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER)
{
throw new NotImplementedException();
}

// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
}

internal void set_variables(IEnumerable<IVariableV1> variables)
{
func_graph.Variables = variables;
}

internal void _set_infer_function()
{
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
_inference_function = _delayed_rewrite_functions.Forward();
}

internal void _set_function_spec(FunctionSpec spec)
{
_function_spec = null;
_pre_initialized_function_spec = spec;
_initialize_function_spec();
}

internal void _initialize_function_spec()
{
if(_pre_initialized_function_spec is null)
{
return;
}
Debug.Assert(_function_spec is null, "already initialized");
var spec = _pre_initialized_function_spec;
//var args = spec.Fullargspec.DictValue.Fields["args"];
// TODO(Rinne): self.structured_input_signature

_function_spec = new FunctionSpec()
{
Fullargspec = spec.Fullargspec,
IsMethod = spec.IsMethod,
InputSignature = spec.InputSignature
};
}

internal Func<Operation, object[], Tensor[]> _get_gradient_function()
{
return _delayed_rewrite_functions._rewrite_forward_and_call_backward;
}

private Tensors _build_call_outputs(Tensors result)
{
// TODO(Rinne): deal with `func_graph.structured_outputs`

return result;
}

public override string ToString()


+ 202
- 20
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -1,50 +1,232 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using System.Buffers;
using Tensorflow.Gradients;

namespace Tensorflow.Functions
{
public class EagerDefinedFunction
public class EagerDefinedFunction: IDisposable
{
public int _num_outputs;
public string Name => _func_graph.FuncName;
FuncGraph _graph;
FunctionDef _definition;
OpDef _signature;
string _name;
internal ScopedTFFunction _c_func;
internal Tensor[] _func_graph_outputs;
internal string _grad_func_name;
internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func;
internal EagerDefinedFunction _grad_func;
internal bool _registered_on_context = false;
public string Name => _name;
public DataType[] OutputTypes { get; protected set; }
public Shape[] OutputShapes { get; protected set; }
public FunctionDef Definition
{
get
{
if(_definition is null)
{
_definition = _get_definition();
}
return _definition;
}
}

FuncGraph _func_graph;
public EagerDefinedFunction(string name, FuncGraph graph,
public OpDef Signature
{
get
{
if( _signature is null)
{
_signature = Definition.Signature;
}
return _signature;
}
}
public unsafe EagerDefinedFunction(string name, FuncGraph graph,
Tensors inputs, Tensors outputs,
Dictionary<string, string> attrs)
Dictionary<string, AttrValue> attrs)
{
_num_outputs = outputs.Length;
var input_ops = inputs.Select(x => x.op).ToArray();
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
.Select(x => x as Operation).ToArray();
var output_names = new string[0];
var graph_output_names = graph._output_names;
string[] output_names;
if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t))))
{
output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray();
if(output_names.Distinct().Count() != output_names.Length)
{
output_names = new string[0];
}
}
else
{
output_names = new string[0];
}

_func_graph = new FuncGraph(graph, name, attrs);
_func_graph.ToGraph(operations, inputs, outputs, output_names);
Status status = new Status();
var fn = c_api.TF_GraphToFunction(graph.c_graph,
name,
false,
operations.Length,
operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(),
inputs.Length,
inputs.Select(t => t._as_tf_output()).ToArray(),
outputs.Length,
outputs.Select(t => t._as_tf_output()).ToArray(),
output_names.Length != outputs.Length ? null : output_names,
IntPtr.Zero, // warning: the control output hasbben totally ignored.
null,
status);
status.Check(true);

_c_func = new ScopedTFFunction(fn, name);

foreach(var (attr_name, attr_value) in attrs)
{
var serialized = attr_value.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status);
status.Check(true);
}

var signature = _get_definition().Signature;
_name = signature.Name;
tf_with(ops.init_scope(), s =>
{
tf.Context.add_function(fn);
_registered_on_context = true;
});

_num_outputs = signature.OutputArg.Count;
OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray();
OutputShapes = outputs.Select(x => x.shape).ToArray();
_func_graph_outputs = new List<Tensor>(outputs).ToArray();
csharp_grad_func = null;
_graph = graph;
}

public Tensors Call(Tensors args)
public unsafe Tensors Call(Tensors args)
{
// TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions;
string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons.

//if (function_call_options.config_proto_serialized().Length == 0)
//{
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
//}
//else
//{
// config = function_call_options.config_proto_serialized().ToStringUtf8();
//}

string executor_type = function_call_options.ExecutorType ?? "";
var executing_eagerly = tf.Context.executing_eagerly();

var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
"executor_type", executor_type,
"config_proto", config
};

var results = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
_func_graph.FuncName,
args,
attrs,
_num_outputs);
Tensor[] outputs;
if (executing_eagerly)
{
outputs = execute.executes(
Signature.Name,
_num_outputs,
args,
attrs,
tf.Context);
}
else
{
if(tf.GetTapeSet().Count == 0)
{
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
}
else
{
var tape = tf.GetTapeSet().Peek();
tape.StopRecord();
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
tape.StartRecord();
}
}
foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs))
{
handle_data_util.copy_handle_data(func_graph_output, outputs[i]);
}
if (executing_eagerly)
{
return outputs;
}
else
{
foreach(var (i, shape) in enumerate(OutputShapes))
{
outputs[i].shape = shape;
}
return outputs;
}
}

public void AddToGraph(Graph g = null)
{
if(g is null && tf.Context.executing_eagerly())
{
var ctx = tf.Context;
if (!ctx.has_function(this.Name))
{
ctx.add_function_def(Definition);
}
}
else
{
if (!g.IsFunction(Name))
{
g.AddFunction(this);
}
foreach(var f in _graph.Functions.Values)
{
if (!g.IsFunction(f.Name))
{
g.AddFunction(f);
}
}
}
}

return results;
private FunctionDef _get_definition()
{
var buffer = c_api_util.tf_buffer();
Status status = new();
c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status);
status.Check(true);
var proto_data = c_api.TF_GetBuffer(buffer);
return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>());
}

public void Dispose()
{
tf.Context.remove_function(Name);
}
}
}

+ 4
- 5
src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs View File

@@ -14,12 +14,11 @@ namespace Tensorflow.Functions

}

public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{
var outputs = _func_graph.Outputs;
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
= BuildFunctionsForOutputs(outputs, inference_args);
return _forward;
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray();
return BuildFunctionsForOutputs(outputs, inference_args);
}
}
}

+ 67
- 6
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -1,23 +1,84 @@
using System;
using Tensorflow.Functions;
using Tensorflow.Train;

namespace Tensorflow
{
public class Function: Trackable
public class Function: Trackable, IGenericFunction
{
#pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used

protected Func<Tensor[], Tensor[]> _csharp_function;
protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _autograph;
protected TracingCompiler _variable_creation_fn;
public string Name { get; set; }
public Function()
public Function(Func<Tensor[], Tensor[]> csharp_function,
string name, bool auto_graph = true)
{
_csharp_function = csharp_function;
Name = name;
_autograph = auto_graph;
}

public virtual Tensors Apply(Tensors inputs)
{
if (_run_functions_eagerly())
{
return _csharp_function(inputs);
}

var result = _call(inputs);
return result;
}

public ConcreteFunction get_concrete_function(params Tensor[] args)
{
return _get_concrete_function_garbage_collected(args);
}
public Function(string name)
protected virtual Tensors _call(Tensors inputs)
{
Name = name;
if(_variable_creation_fn is not null)
{
return _variable_creation_fn.Apply(inputs);
}
_initialize(inputs);

return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs);
}

protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn)
{
var name = nameof(fn);
return new TracingCompiler(fn, name, autograph: _autograph);
}

protected virtual bool _run_functions_eagerly()
{
return false;
}

protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args)
{
if(_variable_creation_fn is null)
{
_initialize(args);
// TODO(Rinne): _initialize_uninitialized_variables
}

var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
return concrete;
}

private void _initialize(Tensor[] args)
{
_variable_creation_fn = _compiler(_csharp_function);
_variable_creation_fn._name = this.Name;
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
}
}
}

+ 12
- 0
src/TensorFlowNET.Core/Functions/IGenericFunction.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Functions
{
public interface IGenericFunction
{
Tensors Apply(Tensors args);
ConcreteFunction get_concrete_function(params Tensor[] args);
}
}

+ 142
- 73
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -3,8 +3,10 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

@@ -15,17 +17,21 @@ namespace Tensorflow.Functions
/// </summary>
public abstract class TapeGradientFunctions
{
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
string _FORWARD_PREFIX = "__forward_";
string _BACKWARD_PREFIX = "__backward_";
string _INFERENCE_PREFIX = "__inference_";
protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
protected string _FORWARD_PREFIX = "__forward_";
protected string _BACKWARD_PREFIX = "__backward_";
protected string _INFERENCE_PREFIX = "__inference_";

protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward;
protected FuncGraph _forward_graph;
protected List<int> _forwardprop_input_indices;
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
protected int _num_inference_outputs;
protected int _num_outputs;
protected int _num_trainable_inference_outputs;
protected ConcreteFunction _backward;
BackwardFunction _backward_function_wrapper;

@@ -33,11 +39,25 @@ namespace Tensorflow.Functions
bool need_gradients_for_jvps)
{
_func_graph = func_graph;
_forward_graph = null;
_forward = null;
_backward = null;
_num_outputs = func_graph.Outputs.Length;
_forwardprop_output_indices = null;
_num_forwardprop_outputs = 0;
_num_inference_outputs = func_graph.Outputs.Length;
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count();
}

public EagerDefinedFunction Forward(Tensors inference_args)
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
{
return ForwardAndBackwardFunctions(inference_args);
// TODO(Rinne): add input_tangents arg.
if(_forward is null)
{
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
= ForwardAndBackwardFunctions(inference_args);
}
return _forward;
}

/// <summary>
@@ -45,11 +65,16 @@ namespace Tensorflow.Functions
/// </summary>
/// <param name="flat_outputs"></param>
/// <param name="inference_args"></param>
public void Record(Tensors flat_outputs, Tensors inference_args)
public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): add arg `input_tagents`.
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
getBackwardFunction: backward_function);
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function);
}

/// <summary>
@@ -61,66 +86,95 @@ namespace Tensorflow.Functions
/// <returns></returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs)
.ToDictionary(x => x.Item1, x => x.Item2);
var captured_inputs = backward.CapturedInputs;
var remapped_captures = captured_inputs.Select(c =>
{
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value))
{
return value;
}
else
{
return c;
}
}).ToArray();
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph))
{
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph);
throw new RuntimeError($"Failed to map all backward graph captures to " +
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}");
}

Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>();
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors();
var trainable_recorded_outputs = 0;
foreach (var (output_index, output) in enumerate(outputs))
int trainable_recorded_outputs = 0;
var skip_positions = new HashSet<int>();
var relevant_outputs = outputs;
foreach (var (output_index, output) in enumerate(relevant_outputs))
{
if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output);
if (gradients_util.IsTrainable(output))
trainable_recorded_outputs += 1;
if (backprop_util.IsTrainable(output))
trainable_recorded_outputs++;
else
skip_positions.Add(output_index);
if (output.dtype == dtypes.variant)
variant_zeros_like[output_index] = default_gradient.zeros_like(output);
}

if(_backward_function_wrapper == null)
_backward_function_wrapper = (args, unneeded_gradients) =>
{
var capture_mapping = new Dictionary<long, Tensor>();
foreach (var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach (var capture in backward.CapturedInputs)
{
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(outputs))
if(backward.Outputs is null || backward.Outputs.Length == 0)
{
if (!gradients_util.IsTrainable(output))
skip_positions.Add(output_index);
return backward.FlatStructuredOutputs;
}

_backward_function_wrapper = (args, unneeded_gradients) =>
var processed_args = new Tensors();
int input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
{
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
if (skip_positions.Contains(output_index))
continue;
if (arg is null)
{
var input_placeholder = backward.Inputs[input_index];
Tensor variant_arg;
if (input_placeholder.dtype == dtypes.variant)
{
variant_arg = variant_zeros_like[output_index];
}
else
{
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder);

variant_arg = array_ops.zeros(shape, type);
}
processed_args.Add(variant_arg);
}
else
{
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
}
input_index++;
if (input_index >= backward_function_inputs)
break;
}

tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);

foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}
foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}

return gradients;
};
}
return gradients;
};

return (_backward_function_wrapper, recorded_outputs);
}
@@ -132,51 +186,66 @@ namespace Tensorflow.Functions
var trainable_indices = new List<int>();
foreach(var (index, output) in enumerate(outputs))
{
if (gradients_util.IsTrainable(output))
if (backprop_util.IsTrainable(output))
{
trainable_outputs.Add(output);
trainable_indices.Add(index);
}
}

var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name));
backwards_graph.as_default();
var gradients_wrt_outputs = new List<Tensor>();
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
{
var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output);
var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape);
gradients_wrt_outputs.Add(gradient_placeholder);
handle_data_util.copy_handle_data(output, gradient_placeholder);
}
// TODO(Rinne): with ops.device(None)
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);

var captures_from_forward = backwards_graph.external_captures
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph)
.ToArray();
HashSet<Tensor> existing_outputs = new(_func_graph.Outputs);
foreach(var capture in captures_from_forward)
{
if (!_func_graph.Outputs.Contains(capture))
if (!existing_outputs.Contains(capture))
{
existing_outputs.Add(capture);
_func_graph.Outputs.Add(capture);
}
}
backwards_graph.Exit();

var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
gradients_wrt_outputs.append(backwards_graph.internal_captures);
backwards_graph.Inputs = gradients_wrt_outputs;
backwards_graph.Outputs = gradients_wrt_inputs;
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));

var (wrapped_forward_function, wrapped_backward_function) =
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
//var backward_function_attr = new Dictionary<string, string>();
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;

var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
//var backward_function = new ConcreteFunction(backwards_graph,
// monomorphic_function_utils._parse_func_attrs(backward_function_attr));
var forward_function_attr = new Dictionary<string, string>();
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
//var forward_function_attr = new Dictionary<string, string>();
//forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
//var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
// _func_graph.Inputs, _func_graph.Outputs,
// monomorphic_function_utils._parse_func_attrs(forward_function_attr));
return (forward_function, _func_graph, backward_function, null, 0);
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0);
}

public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{
throw new NotImplementedException("");
}


+ 84
- 0
src/TensorFlowNET.Core/Functions/TracingCompiler.cs View File

@@ -0,0 +1,84 @@
using System;
using System.Collections.Generic;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Tensorflow.Graphs;

namespace Tensorflow.Functions
{
public class TracingCompiler
{
Func<Tensor[], Tensor[]> _csharp_function;
//FunctionSpec _function_spec;
internal string _name;
bool _autograph;
Dictionary<string, ConcreteFunction> _function_cache;
Dictionary<string, AttrValue> _function_attributes;
int _tracing_count;
public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null,
Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null,
bool reduce_retracing = false, bool capture_by_value = false)
{
_csharp_function = csharp_function;
bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME);
_name = name;
_autograph = autograph;
_function_attributes = attributes ?? new Dictionary<string, AttrValue>();
_function_cache = new Dictionary<string, ConcreteFunction>();
_tracing_count = 0;
}

public Tensor[] Apply(Tensor[] inputs)
{
// TODO(Rinne): add lock here.
var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs);
return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs);
}

internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args)
{
var (concrete_function, _) = _maybe_define_concrete_function(args);
return concrete_function;
}

private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args)
{
return _maybe_define_function(args);
}

private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args)
{
var lookup_func_key = make_cache_key(args);
if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function))
{
return (concrete_function, args);
}
concrete_function = _create_concrete_function(args);
_function_cache[lookup_func_key] = concrete_function;
return (concrete_function, args);
}

private ConcreteFunction _create_concrete_function(Tensor[] args)
{
_tracing_count++;
int arglen = args.Length;
var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func(
_name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()),
args, new Dictionary<string, object>(), autograph: _autograph
), _function_attributes);
return concrete_function;
}

private static string make_cache_key(Tensor[] inputs)
{
//string res = "";
//foreach (var input in inputs)
//{
// res += $"{input.name}_{input.Id}";
//}
return inputs.Length.ToString();
}
}
}

+ 5
- 1
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using Tensorflow.Functions;

namespace Tensorflow
{
@@ -54,6 +55,9 @@ namespace Tensorflow
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func);

[DllImport(TensorFlowLibName)]
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status);
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status);
}
}

+ 50
- 0
src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs View File

@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Util;

namespace Tensorflow.Functions
{
internal static class composite_tensor_utils
{
public static List<object> flatten_with_variables(object inputs)
{
List<object> flat_inputs = new();
foreach(var value in nest.flatten(inputs))
{
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value))
{
throw new NotImplementedException("The composite tensor has not been fully supported.");
}
else
{
flat_inputs.Add(value);
}
}
return flat_inputs;
}
public static List<object> flatten_with_variables_or_variable_specs(object arg)
{
List<object> flat_inputs = new();
foreach(var value in nest.flatten(arg))
{
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value))
{
throw new NotImplementedException("The composite tensor has not been fully supported.");
}
// TODO(Rinne): deal with `VariableSpec`.
else if(value is TypeSpec type_spec && value is not TensorSpec)
{
throw new NotImplementedException("The TypeSpec has not been fully supported.");
}
else
{
flat_inputs.Add(value);
}
}
return flat_inputs;
}
}
}

+ 94
- 0
src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs View File

@@ -0,0 +1,94 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Variables;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
{
public static class function_saved_model_utils
{
/// <summary>
///
/// </summary>
/// <param name="concrete_function"></param>
/// <param name="inputs">a list tensors or other objects (such as variables) which
/// contain tensors that were originally captured by the function</param>
public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs)
{
var bound_inputs = inputs?.Select(obj =>
{
if(obj is Tensor tensor)
{
return get_tensor_from_node(tensor);
}
else if(obj is IVariableV1 variable)
{
return get_tensor_from_node(variable);
}
else
{
throw new TypeError("Encountered an type error, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
});
var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x);

List<Tensor> captured_inputs_list = new();
concrete_function.set_variables(bound_variables);

if (bound_inputs is not null)
{
foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count())))
{
if(hasattr(bound_input, "__tf_experimental_restore_capture__"))
{
throw new NotImplementedException();
}
else
{
captured_inputs_list.Add(bound_input);
concrete_function.func_graph.replace_capture(bound_input, internal_capture);
if(internal_capture.dtype == dtypes.resource)
{
if (resource_variable_ops.is_resource_variable(bound_input))
{
handle_data_util.copy_handle_data(bound_input.Handle, internal_capture);
}
else
{
handle_data_util.copy_handle_data(bound_input, internal_capture);
}
}
concrete_function.func_graph.capture(bound_input);
}
}
}

if(captured_inputs_list.Any(inp => inp is null))
{
// TODO(Rinne): add warnings.
}
concrete_function.SetExternalCaptures(captured_inputs_list);
}

public static Tensor get_tensor_from_node(Tensor node)
{
return node;
}
public static Tensor get_tensor_from_node(IVariableV1 node)
{
if (resource_variable_ops.is_resource_variable(node))
{
return node.Handle;
}
else
{
throw new TypeError("Encountered an type error, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}
}

+ 282
- 0
src/TensorFlowNET.Core/Functions/monomorphic_function.cs View File

@@ -0,0 +1,282 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using Tensorflow.Operations;
using Tensorflow.Framework;
using static Tensorflow.Binding;
using System.Diagnostics;

namespace Tensorflow.Functions
{
internal static class monomorphic_function_utils
{
internal static string _FORWARD_PREFIX = "__forward_";
internal static string _BACKWARD_PREFIX = "__backward_";
internal static string _INFERENCE_PREFIX = "__inference_";
internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements";
internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
public static string _inference_name(string name)
{
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}";
}
public static string _forward_name(string name)
{
return $"{_FORWARD_PREFIX}{name}_{ops.uid()}";
}
public static string _backward_name(string name)
{
return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}";
}

public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs,
FuncGraph forward_graph, FuncGraph backwards_graph)
{
string forward_function_name = _forward_name(forward_graph.Name);
Dictionary<string, AttrValue> common_attributes;
if(attrs is null)
{
common_attributes = new Dictionary<string, AttrValue>();
}
else
{
common_attributes = new Dictionary<string, AttrValue>(attrs);
}

if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME))
{
common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME);
}
var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>()
{
{FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name }
});
backward_function_attr.Update(common_attributes);
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>()
{
{BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name }
});
forward_function_attr.Update(common_attributes);
var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph,
forward_graph.Inputs, forward_graph.Outputs, forward_function_attr);
return (forward_function, backward_function);
}

public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes)
{
Dictionary<string, AttrValue> attrs = new();
foreach(var item in attributes)
{
var key = item.Key;
var value = item.Value;
if (value is AttrValue attr_value)
{
attrs[key] = attr_value;
}
else if (value is bool b)
{
attrs[key] = new AttrValue() { B = b };
}
else if (value is int i)
{
attrs[key] = new AttrValue() { I = i };
}
else if (value is float f)
{
attrs[key] = new AttrValue() { F = f };
}
else if(value is string s)
{
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) };
}
else if (value is byte[] bytes)
{
attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) };
}
else
{
throw new ValueError($"Attribute {key} must be bool, int, float, string, or " +
$"AttrValue. Got {value.GetType()}.");
}
}
return attrs;
}

public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes)
{
Dictionary<string, AttrValue> attrs = new();
foreach (var item in attributes)
{
var key = item.Key;
var value = item.Value;
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) };
}
return attrs;
}
}
public class DelayedRewriteGradientFunctions : TapeGradientFunctions
{
EagerDefinedFunction _inference_function;
Dictionary<string, AttrValue> _attrs;
int _num_inference_outputs;
Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new();
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs)
: base(func_graph, false)
{
_func_graph = func_graph;
_inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name),
_func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs);
_attrs = attrs;
_num_inference_outputs = _func_graph.Outputs.Length;
}

public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null)
{
if (input_tangents is not null)
{
throw new InvalidArgumentError($"unexpectedly got forwardprop information in " +
$"a class that does not support forwardprop.");
}
return _inference_function;
}

public override void Record(Tensors flat_outputs, Tensors inference_args)
{
var (backward_function, to_record) = _backward(flat_outputs);
foreach(var tape in tf.GetTapeSet())
{
tape.RecordOperation(_inference_function.Signature.Name, to_record,
inference_args, backward_function);
}
}

public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2)
{
if(num_doutputs == -2)
{
num_doutputs = _num_inference_outputs;
}
if(_cached_function_pairs.TryGetValue(num_doutputs, out var target))
{
return target;
}
var (forward, backward) = _construct_forward_backward(num_doutputs);
_cached_function_pairs[num_doutputs] = (forward, backward);
return (forward, backward);

}

private (BackwardFunction, Tensors) _backward(Tensors outputs)
{
Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients)
{
var call_op = outputs[0].op;
return _rewrite_forward_and_call_backward(call_op, args);
}
return (backward_function, outputs);
}

internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs)
{
var (forward_function, backward_function) = forward_backward(doutputs.Length);
if(backward_function.Outputs is null || backward_function.Outputs.Length == 0)
{
return backward_function.FlatStructuredOutputs;
}
forward_function.AddToGraph(op.graph);

op._set_func_attr("f", forward_function.Name);
op._set_type_list_attr("Tout", forward_function.OutputTypes);
op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()).
Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray()
);
for(int i = 0; i < op.outputs.Length; i++)
{
var func_graph_output = forward_function._func_graph_outputs[i];
handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]);
}

var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs).
ToDictionary(x => x.Item1, x => x.Item2);
var remapped_captures = backward_function.CapturedInputs.Select(
x => capture_mapping.GetOrDefault(ops.tensor_id(x), x)
);

List<Tensor> cleaned_doutputs = new();
foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs))
{
if (backprop_util.IsTrainable(placeholder))
{
if(doutput is IndexedSlices)
{
cleaned_doutputs.Add(ops.convert_to_tensor(doutput));
}
else if(doutput is null)
{
cleaned_doutputs.Add(default_gradient.zeros_like(placeholder));
}
else if(doutput is Tensor tensor)
{
cleaned_doutputs.Add(tensor);
}
else
{
throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward");
}
}
}

return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray());
}

private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs)
{
var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x));

List<TensorSpec> signature = new();
foreach(var t in trainable_outputs)
{
var (shape, dtype) = default_gradient.shape_and_dtype(t);
signature.Add(new TensorSpec(shape, dtype));
}

Tensor[] _backprop_function(Tensor[] grad_ys)
{
return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs,
grad_ys, src_graph: _func_graph);
}

_func_graph.as_default();
FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name));
FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y =>
{
Debug.Assert(y is Tensor);
return (Tensor)y;
}).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph);
var backwards_graph_captures = backwards_graph.external_captures;
var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph);
HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs);
foreach(var capture in captures_from_forward)
{
if (!existing_outputs.Contains(capture))
{
existing_outputs.Add(capture);
_func_graph.Outputs.Add(capture);
}
}

var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(
_attrs, _func_graph, backwards_graph);
_func_graph.Exit();
return (forward_function, backward_function);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients
/// Map from tensor to how many references still exist for this tensor in
/// the tape.
/// </summary>
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
public UnorderedMap<long, long> tensor_usage_counts { get; set; }
/// <summary>
/// Maps from op ID to how many output tensors of this op still need to have
/// their gradients computed.
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients
public BackpropInitialState()
{
op_tape = new OpTape();
tensor_usage_counts = new UnorderedMap<Tensor, long>();
tensor_usage_counts = new UnorderedMap<long, long>();
op_missing_tensor = new UnorderedMap<long, long>();
}
}


+ 27
- 8
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients
/// <param name="target"></param>
/// <param name="source"></param>
/// <returns></returns>
public Tensor gradient(Tensor target, Tensor source)
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
if(_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
ITape tape = stop_recording();

var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
new[] { source },
null);
output_gradients,
new[] { source },
unconnected_gradients);

return results[0];
}

public Tensor gradient(Tensor target, ResourceVariable source)
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
var results = gradient(target, new List<IVariableV1> { source });
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients);

return results[0];
}

public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients);

return (results[0], results[1]);
}

public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
if (_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
var tape = stop_recording();

var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
sources.Select(x => x.Handle).ToArray(),
null);
output_gradients,
sources.Select(x => x.Handle).ToArray(),
unconnected_gradients);

if (!tape.Persistent)
{


+ 15
- 8
src/TensorFlowNET.Core/Gradients/ITape.cs View File

@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients
public interface ITape
{
void SetTapeId(int id);
bool ShouldRecord(Tensor[] tensors);
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes);
void StartRecord();
void StopRecord();
bool Persistent { get; }
void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function);

void VariableAccessed(ResourceVariable variable);
void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function);

void VariableAccessed(IVariableV1 variable);

void Watch(Tensor x);

ResourceVariable[] WatchedVariables();
IVariableV1[] WatchedVariables();

Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients);
Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads);
}
}

+ 2
- 2
src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs View File

@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients
{
public string op_type { get; set; }
public TapeTensor[] output_tensor_info { get; set; }
public Tensor[] input_tensor_id { get; set; }
public long[] input_tensor_id { get; set; }
public BackwardFunction backward_function { get; set; }
public override string ToString()
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
}
}

+ 157
- 125
src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs View File

@@ -2,235 +2,246 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
public partial class Tape
{
// int kMinAggregateCount = 4;
// int kMinAggregateBytes = 128 * 1024 * 1024;
static readonly int kMinAggregateCount = 4;
static readonly int kMinAggregateBytes = 128 * 1024 * 1024;
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap;

public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients)
static Tape()
{
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
// var gradients_size = new UnorderedMap<Tensor, long>();
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap();
var state = PrepareBackprop(
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
output_gradients,
tensor_tape_,
state.op_tape);
_functionsAcceptingNoneForIndicesMap = new();
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
}

while (!op_stack.empty())
public Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads)
{
UnorderedSet<long> sources_set = new(source_tensor_ids);
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape);
UnorderedMap<long, long> gradients_size = new();
while(op_stack.Count > 0)
{
var op = op_stack.Dequeue();
if (!state.op_tape.find(op, out var trace))
long op = op_stack.Dequeue();
if(!state.op_tape.TryGetValue(op, out var op_it))
{
continue;

// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
}
var trace = op_it;
state.op_tape.erase(op);

var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
var unneeded_gradients = new List<long>();
for (int i = 0; i < trace.input_tensor_id.Length; i++)
List<Tensor> out_gradients = new();
List<long> unneeded_gradients = new();
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++)
{
var in_tensor_id = trace.input_tensor_id[i];
if (!tensor_tape_.find(in_tensor_id) &&
!sources_set.find(in_tensor_id))
long in_tensor_id = trace.input_tensor_id[i];
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id))
{
unneeded_gradients.Add(i);
}
}

bool any_gradient_nonzero = false;
var zero_indices = new List<int>();
for (int i = 0; i < trace.output_tensor_info.Length; ++i)
List<int> zero_indices = new();
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++)
{
var id = trace.output_tensor_info[i].GetTensor();
if (!gradients.find(id, out var grad_it))
long id = trace.output_tensor_info[i].GetID();
if(!gradients.TryGetValue(id, out var grad_it))
{
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) &&
func_name_it.find(i))
out_gradients.Add(null);
if (build_default_zeros_grads)
{
out_gradients.Add(null);
}
else
{
out_gradients.Add(null);
zero_indices.Add(i);
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) ||
!func_name_it.find(i))
{
zero_indices.Add(i);
}
}
}
else
{
any_gradient_nonzero = true;
var new_gradients = grad_it.Count == 1 ?
grad_it[0] :
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients

Tensor new_gradients;
if (grad_it.Count == 1)
{
new_gradients = grad_it[0];
}
else
{
new_gradients = AggregateGradients(grad_it);
}
if (!sources_set.find(id))
{
gradients.Remove(id);
}
else
{
// grad_it.Clear();
// grad_it.Add(new_gradients);
// vspace.MarkAsResult(new_gradients);
grad_it.Clear();
grad_it.Add(new_gradients);
// MarkAsResult
}
out_gradients.Add(new_gradients);
}
}

Tensor[] in_gradients;
Tensor[] in_gradients = new Tensor[0];
if (any_gradient_nonzero)
{
// foreach (var i in zero_indices)
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray());

if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length)
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
if (!_persistent)
foreach(var i in zero_indices)
{
// trace.backward_function_deleter(trace.backward_function);
trace.backward_function = null;
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
}
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients);
}
else
{
in_gradients = new Tensor[trace.input_tensor_id.Length];
out_gradients.Clear();
}

bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length;
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k)
for(int i = 0, end = in_gradients.Length; i < end; i++)
{
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k;
var id = trace.input_tensor_id[k];
if (in_gradients[i] != null)
long id = trace.input_tensor_id[i];
if (in_gradients[i] is not null)
{
var unaggregated_grads = gradients[id];
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>());
unaggregated_grads.Add(in_gradients[i]);
/*if (unaggregated_grads.Count > kMinAggregateCount)
if(unaggregated_grads.Count > kMinAggregateCount)
{
if (!gradients_size.find(id, out var size))
if(!gradients_size.TryGetValue(id, out var size))
{
size = (long)unaggregated_grads[0].size;
size = NumElements(unaggregated_grads[0]);
gradients_size.emplace(id, size);
}

if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
{
throw new NotImplementedException("");
Tensor grad = AggregateGradients(unaggregated_grads);
unaggregated_grads.Clear();
unaggregated_grads.Add(grad);
}
}*/
}
}
if (!state.tensor_usage_counts.find(id))
if(!state.tensor_usage_counts.find(id))
{
continue;
}
state.tensor_usage_counts[id]--;
if (state.tensor_usage_counts[id] > 0)
if(state.tensor_usage_counts[id] > 0)
{
continue;
if (!tensor_tape_.find(id, out var tape_it))
}
if (!tensor_tape_.TryGetValue(id, out var tape_it))
{
if (gradients.find(id, out var grad_it))
if (gradients.find(id))
{
// foreach (var g in grad_it)
// DeleteGradient(g);
gradients.erase(id);
}
continue;
}
var op_id = tape_it;
if (op_id == -1)
long op_id = tape_it;
if(op_id == -1)
{
continue;
if (state.op_missing_tensor.find(op_id, out var missing_it))
}
if(state.op_missing_tensor.find(op_id))
{
state.op_missing_tensor[op_id]--;
if (state.op_missing_tensor[op_id] == 0)
if(state.op_missing_tensor[op_id] == 0)
{
op_stack.Enqueue(op_id);
}
}
}
}

if (state.op_tape.Count > 0)
if(state.op_tape.Count > 0)
{
throw new RuntimeError("Invalid tape state.");

var result = new Tensor[source_tensor_ids.Length];
var j = 0;
foreach (var id in source_tensor_ids)
}
Tensor[] result = new Tensor[source_tensor_ids.Length];
for(int i = 0; i < source_tensor_ids.Length; i++)
{
if (gradients.find(id, out var grad_it))
long tensor_id = source_tensor_ids[i];
if(!gradients.TryGetValue(tensor_id, out var grad_it))
{
if (grad_it.Count > 1)
result[j] = gen_math_ops.add_n(grad_it.ToArray());
else
result[j] = grad_it[0];
result[i] = null;
}
else
{
if(grad_it.Count > 1)
{
Tensor grad = AggregateGradients(grad_it);
grad_it.Clear();
grad_it.Add(grad);
}
result[i] = grad_it[0];
}
j++;
}

return result;
}

UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
{
var m = new UnorderedMap<string, UnorderedSet<int>>();
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
return m;
return _functionsAcceptingNoneForIndicesMap;
}

UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients,
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
TensorTape tensor_tape,
OpTape op_tape)
{
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
for (int i = 0; i < target_tensor_ids.Length; ++i)
var result = new UnorderedMap<long, List<Tensor>>();
for(int i = 0, end = target_tensor_ids.Length; i < end; i++)
{
var id = target_tensor_ids[i];
if (output_gradients.Length == 0 || output_gradients[i] == null)
long id = target_tensor_ids[i];
if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null)
{
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1)
{
if (!op_tape.find(tensor_tape[id], out var op_it))
if(!op_tape.TryGetValue(tensor_it, out var op_it))
{
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"failed to find operation producing a tensor");
"failed to find operation producing a tensor.");
}
bool found = false;
for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
for(int j = 0; j < op_it.output_tensor_info.Length; j++)
{
if (op_it.output_tensor_info[j].GetTensor() == id)
if (op_it.output_tensor_info[j].GetID() == id)
{
found = true;
var ones = op_it.output_tensor_info[j].OnesLike();
result[id].Add(ones);
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
break;
}
}

if (!found)
{
throw new ValueError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor");
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor.");
}
}
else
{
if (sources_that_are_targets.find(id, out var source_tensor))
result[id].Add(source_tensor.OnesLike());
if(sources_that_are_targets.TryGetValue(id, out var source_tensor))
{
Tensor ones_like = BuildOnesLike(source_tensor);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
}
}
}
else
{
result[id].Add(output_gradients[i]);
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]);
}
}

@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients
}
return result;
}

Tensor BuildOnesLike(TapeTensor t)
{
return t.OnesLike();
}

Tensor AggregateGradients(List<Tensor> gradient_tensors)
{
if(gradient_tensors.Count == 0)
{
return gradient_tensors[0];
}
return tf.add_n(gradient_tensors.ToArray());
}

void DeleteGradient(Tensor gradient)
{
// Do not do anything here. Because GC will collect it when it has no reference.
}

long NumElements(Tensor tensor) => 1;
}
}

+ 31
- 32
src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs View File

@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients
{
public partial class Tape
{
public BackpropInitialState PrepareBackprop(Tensor[] target,
public BackpropInitialState PrepareBackprop(long[] target,
TensorTape tensor_tape,
OpTape op_tape,
UnorderedSet<Tensor> sources_set,
UnorderedSet<long> sources_set,
bool persistent_tape)
{
Stack<long> tensor_stack = new Stack<long>();
foreach(var t in target)
{
tensor_stack.Push(t);
}
BackpropInitialState result = new BackpropInitialState();
var tensor_stack = new Queue<Tensor>(target);
while (tensor_stack.Count > 0)
while(tensor_stack.Count > 0)
{
var tensor_id = tensor_stack.Dequeue();
if (!tensor_tape.find(tensor_id, out var op_id))
long tensor_id = tensor_stack.Pop();
if(!tensor_tape.TryGetValue(tensor_id, out var op_id))
{
continue;
if (op_id == -1 ||
!op_tape.find(op_id, out var op_it) ||
result.op_tape.find(op_id, out var result_op_it))
}
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it)
|| result.op_tape.find(op_id))
{
continue;
}
result.op_tape.emplace(op_id, op_it);

foreach (var it in op_it.input_tensor_id)
foreach(var it in op_it.input_tensor_id)
{
if (result.tensor_usage_counts.find(it))
if(result.tensor_usage_counts.find(it))
{
result.tensor_usage_counts[it]++;
}
else
{
result.tensor_usage_counts[it] = 1;
if (tensor_tape.find(it))
tensor_stack.Enqueue(it);
{
tensor_stack.Push(it);
}
}
}

if (!persistent_tape)
op_tape.Remove(op_id);
{
op_tape.erase(op_id);
}
}

foreach (var pair in result.tensor_usage_counts)
foreach(var pair in result.tensor_usage_counts)
{
if (tensor_tape.find(pair.Key, out var it) && it != -1)
result.op_missing_tensor[it] += 1;
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1)
{
result.op_missing_tensor[it]++;
}
}

if (!persistent_tape)
{
// Call destructors for all unneeded gradient functions and
// clear the op_tape. We can clear the tape because ownership of
// backward functions that will be used for gradient computation
// has been transferred to `result`.
/*for (const auto&op_pair : *op_tape) {
op_pair.second.backward_function_deleter(
op_pair.second.backward_function);
}*/
op_tape.Clear();
}

return result;
}
}


+ 21
- 10
src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs View File

@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients
public partial class Tape
{
long next_op_id_ = 0;
UnorderedMap<Tensor, long> tensor_usage_;
UnorderedMap<long, long> tensor_usage_;

public void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
if (!ShouldRecord(input_tensors))
if (!ShouldRecord(input_tensor_id, input_dtypes))
return;

var op_id = next_op_id_++;
foreach (var i in input_tensors)
foreach (var i in input_tensor_id)
{
tensor_usage_[i]++;

}
long op_id = next_op_id_++;
foreach (var o in output_tensors)
{
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
tensor_tape_[o.GetTensor()] = op_id;
tensor_usage_[o.GetTensor()] = 1;
tensor_tape_[o.GetID()] = op_id;
tensor_usage_[o.GetID()] = 1;
}

op_tape_[op_id] = new OpTapeEntry
{
op_type = op_type,
output_tensor_info = output_tensors,
input_tensor_id = input_tensors,
output_tensor_info = output_tensors.ToArray(),
input_tensor_id = input_tensor_id.ToArray(),
backward_function = backward_function
};
}

public void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function)
{
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function);
}
}
}

+ 10
- 10
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients
_created_eagerly = tf.Context.executing_eagerly();
tensor_tape_ = new TensorTape();
op_tape_ = new OpTape();
tensor_usage_ = new UnorderedMap<Tensor, long>();
tensor_usage_ = new UnorderedMap<long, long>();
if(_created_eagerly)
tf.Context.start_step();
// nesting_id = ++tape_nesting_id_counter;
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients
public void Watch(Tensor x)
{
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
tensor_tape_.emplace(x, -1);
tensor_tape_.emplace(x.Id, -1);
}

public bool ShouldRecord(Tensor[] tensors)
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes)
{
var dtypes = tensors.Select(x => x.dtype).ToArray();
for (int i = 0; i < tensors.Length; ++i)
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length);
for (int i = 0; i < tensor_ids.Length; ++i)
{
if (tensor_tape_.find(tensors[i]))
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i]))
{
if (IsDtypeTrainable(dtypes[i]))
return true;
return true;
}
}
return false;
}

public void VariableAccessed(ResourceVariable variable)
public void VariableAccessed(IVariableV1 variable)
{
Watch(variable.Handle);
}

public ResourceVariable[] WatchedVariables()
public IVariableV1[] WatchedVariables()
{
return null;
}


+ 45
- 9
src/TensorFlowNET.Core/Gradients/TapeTensor.cs View File

@@ -1,27 +1,63 @@
using static Tensorflow.Binding;
using OneOf;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
public class TapeTensor
{
Tensor tensor;
long id => tensor.Id;
TF_DataType dtype => tensor.dtype;
Shape shape => tensor.shape;
internal Tensor tensor;
internal long id;
internal TF_DataType dtype;
internal OneOf<Shape, Tensor> shape;

public TapeTensor(long id, TF_DataType dtype, Shape shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}

public TapeTensor(long id, TF_DataType dtype, Tensor shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}

public TapeTensor(Tensor tensor)
{
this.id = tensor.Id;
this.dtype = tensor.dtype;
this.shape = tensor.shape;
this.tensor = tensor;
}

public long GetID() => tensor.Id;
public Tensor GetTensor() => tensor;
public long GetID() => id;

public Tensor ZerosLike()
=> tf.zeros(shape: shape, dtype: dtype);
{
if(dtype == dtypes.resource)
{
return null;
}
if(shape.Index == 1)
{
return tf.zeros_like(shape.AsT1);
}
return tf.zeros(shape.AsT0, dtype);
}

public Tensor OnesLike()
=> tf.ones(shape: shape, dtype: dtype);
{
if (shape.Index == 1)
{
return tf.ones_like(shape.AsT1);
}
return tf.ones(shape.AsT0, dtype);
}

//public Tensor OnesLike()
// => tf.ones(shape: shape, dtype: dtype);

public override string ToString()
=> $"{id}, {shape}, {dtype.as_numpy_name()}";


+ 1
- 1
src/TensorFlowNET.Core/Gradients/TensorTape.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients
/// produced this tensor. A value of -1 means that the tensor was directly
/// watched and not the result of any operation in the tape.
/// </summary>
public class TensorTape : UnorderedMap<Tensor, long>
public class TensorTape : UnorderedMap<long, long>
{

}


+ 14
- 0
src/TensorFlowNET.Core/Gradients/custom_gradient.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
public class custom_gradient
{
public static string generate_name()
{
return $"CustomGradient-{ops.uid()}";
}
}
}

+ 52
- 0
src/TensorFlowNET.Core/Gradients/default_gradient.cs View File

@@ -0,0 +1,52 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
internal static class default_gradient
{
public static (Shape, TF_DataType) shape_and_dtype(Tensor t)
{
if(t.dtype == dtypes.resource)
{
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t);
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1)
{
throw new ValueError($"Internal error: Tried to take gradients (or similar) " +
$"of a variable without handle data:\n{t}");
}
return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype());
}
return (t.shape, t.dtype);
}

public static Tensor zeros_like(Tensor t)
{
if(t.dtype == dtypes.resource)
{
var (shape, dtype) = shape_and_dtype(t);
return array_ops.zeros(shape, dtype);
}
else
{
return array_ops.zeros_like(t);
}
}

public static TF_DataType get_zeros_dtype(Tensor t)
{
if(t.dtype == dtypes.resource)
{
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t);
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1)
{
throw new ValueError($"Internal error: Tried to take gradients (or similar) " +
$"of a variable without handle data:\n{t}");
}
return handle_data.ShapeAndType[0].Dtype.as_tf_dtype();
}
return t.dtype;
}
}
}

+ 71
- 4
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -14,10 +14,15 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding;

@@ -25,6 +30,11 @@ namespace Tensorflow
{
public class gradients_util
{
// Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
// unfortunately too slow to use here.
public static int POSSIBLE_GRADIENT_TYPES_NONE = 0;
public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1;
public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2;
public static Tensor[] _GradientsHelper(Tensor[] ys,
Tensor[] xs,
Tensor[] grad_ys = null,
@@ -143,7 +153,7 @@ namespace Tensorflow
Tensor[] in_grads = null;
Func<Operation, Tensor[], Tensor[]> grad_fn = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{
@@ -157,14 +167,41 @@ namespace Tensorflow
{
if (is_func_call)
{
EagerDefinedFunction func_call = null;
if (is_partitioned_call)
{

var func_attr = op.get_attr("f");
Debug.Assert(func_attr is NameAttrList);
var func_name = ((NameAttrList)func_attr).Name;
func_call = src_graph._get_function(func_name);
if(func_call is null && src_graph.OuterGraph is not null)
{
var graph = src_graph.OuterGraph;
while(graph is not null)
{
func_call = graph._get_function(func_name);
if(func_call is not null)
{
break;
}
if(graph.OuterGraph is not null)
{
graph = graph.OuterGraph;
}
else
{
break;
}
}
}
}
else
{

func_call = src_graph._get_function(op.type);
}
// skip the following codes:
// `func_call = getattr(op, "__defun", func_call)`
grad_fn = func_call.csharp_grad_func;
}
else
{
@@ -208,6 +245,8 @@ namespace Tensorflow
}
else
{
in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
null, (x, y) => _SymGrad(x, y));
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
}
_VerifyGeneratedGradients(in_grads, op);
@@ -663,6 +702,11 @@ namespace Tensorflow
dtypes.resource, dtypes.variant}.Contains(dtype);
}

public static int PossibleTapeGradientTypes(Tensor[] tensors)
{
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors);
}

/// <summary>
/// Return true if op has real gradient.
/// </summary>
@@ -683,7 +727,7 @@ namespace Tensorflow

private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
// scope = scope.TrimEnd('/').Replace('/', '_');
return grad_fn(op, out_grads);
}

@@ -696,5 +740,28 @@ namespace Tensorflow
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
$"inputs {op.inputs._inputs.Count()}");
}

private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads)
{
var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray();
var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray();
NameAttrList f = new();
if (_IsPartitionedCall(op))
{
var func_attr = op.get_attr("f");
Debug.Assert(func_attr is NameAttrList);
f.Name = ((NameAttrList)func_attr).Name;
}
else
{
f.Name = op.type;
}
foreach(var k in op.node_def.Attr.Keys)
{
f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray());
}
var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f);
return in_grads;
}
}
}

+ 15
- 4
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -98,12 +98,23 @@ namespace Tensorflow
{
if (op.inputs == null) return null;

RegisterFromAssembly();
var gradient_function = op._gradient_function;
if(gradient_function is null)
{
RegisterFromAssembly();

if (!gradientFunctions.ContainsKey(op.type))
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");

if (!gradientFunctions.ContainsKey(op.type))
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");
return gradientFunctions[op.type];
}

return gradientFunctions[op.type];
Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args)
{
return gradient_function(operation, args);
}
// TODO(Rinne): check if this needs to be registered.
return wrapped_gradient_function;
}
}
}

+ 3
- 1
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -1,6 +1,7 @@
using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Functions;
@@ -22,7 +23,7 @@ namespace Tensorflow.Graphs
public override void OnEntry(MethodExecutionArgs args)
{
// TODO: func_name can be cache in FullName + Args
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}";

if (functions.ContainsKey(func_name))
{
@@ -91,6 +92,7 @@ namespace Tensorflow.Graphs

// cache function.
function.ReturnType = args.ReturnValue.GetType();
function._set_infer_function();
functions[func_name] = function;

// run function


+ 364
- 13
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -1,6 +1,15 @@
using Google.Protobuf;
using System;
using System.Buffers;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Exceptions;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs;
@@ -10,12 +19,66 @@ namespace Tensorflow.Graphs;
/// </summary>
public class FuncGraph : Graph, IDisposable
{
SafeFuncGraphHandle _func_graph_handle;
internal SafeFuncGraphHandle _func_graph_handle;
internal HashSet<Tensor> _resource_tensor_inputs;
internal HashSet<WeakReference<IVariableV1>> _watched_variables;
internal IEnumerable<WeakReference<IVariableV1>> _weak_variables;
internal object[] _structured_outputs;
internal Dictionary<long, string> _output_names;
public string FuncName => _graph_key;

public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; }
public Tensors FlatStructuredOutputs
{
get
{
List<Tensor> res = new();
foreach(var obj in _structured_outputs)
{
if(obj is Tensor tensor)
{
res.Add(tensor);
}
else if(obj is IEnumerable<Tensor> tensors)
{
res.AddRange(tensors);
}
else
{
throw new TypeError("The structured outputs member should be tensor or tensors.");
}
}
return res;
}
}
public string Name { get; set; }
public IEnumerable<IVariableV1> Variables
{
get
{
return _weak_variables.Select(v =>
{
if (v.TryGetTarget(out var target))
{
return target;
}
else
{
throw new AssertionError("Called a function referencing variables which have been deleted. " +
"This likely means that function-local variables were created and " +
"not referenced elsewhere in the program. This is generally a " +
"mistake; consider storing variables in an object attribute on first call.");
}
});
}
internal set
{
_weak_variables = value.Select(x => new WeakReference<IVariableV1>(x));
}
}
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();
@@ -39,31 +102,42 @@ public class FuncGraph : Graph, IDisposable
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
_graph_key = Name = name;
building_function = true;
_weak_variables = new List<WeakReference<IVariableV1>>();
_resource_tensor_inputs = new HashSet<Tensor>();
_watched_variables = new HashSet<WeakReference<IVariableV1>>();
}

public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base()
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base()
{
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
_graph_key = Name = name;
building_function = true;
Attrs = attrs;
// Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle);
_handle = handle;
_weak_variables = new List<WeakReference<IVariableV1>>();
_resource_tensor_inputs = new HashSet<Tensor>();
_watched_variables = new HashSet<WeakReference<IVariableV1>>();
}

public void ToGraph(Operation[] opers,
public void replace_capture(Tensor tensor, Tensor placeholder)
{
_captures[tensor.Id] = (tensor, placeholder);
}

public unsafe void ToGraph(Operation[] opers,
Tensor[] inputs, Tensor[] outputs,
string[] output_names)
{
var status = new Status();
if (output_names != null && output_names.Length == 0)
if (output_names is null)
{
output_names = null;
output_names = new string[0];
};

_func_graph_handle = c_api.TF_GraphToFunction(_handle,
@@ -75,7 +149,7 @@ public class FuncGraph : Graph, IDisposable
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
outputs.Length,
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
output_names,
output_names.Length != outputs.Length ? null : output_names,
IntPtr.Zero,
null,
status);
@@ -141,6 +215,16 @@ public class FuncGraph : Graph, IDisposable
return tensor;
}

public void watch_variable(IVariableV1 v)
{
if (_resource_tensor_inputs.Contains(v.Handle))
{
return;
}
_watched_variables.Add(new WeakReference<IVariableV1>(v));
//this = this.outer_graph;
}

Tensor capture_eager_tensor(Tensor tensor, string name)
{
Tensor graph_const = null;
@@ -205,6 +289,19 @@ public class FuncGraph : Graph, IDisposable
Inputs.Add(placeholder);
}

Tensor pop_capture(Tensor tensor)
{
if(_captures.TryGetValue(tensor.Id, out var capture))
{
_captures.Remove(tensor.Id);
return capture.Item2;
}
else
{
return null;
}
}

Tensor _create_substitute_placeholder(Tensor value,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -228,10 +325,7 @@ public class FuncGraph : Graph, IDisposable

foreach (var (_name, attr_value) in enumerate(Attrs))
{
var serialized = new AttrValue
{
S = ByteString.CopyFromUtf8(attr_value)
}.ToByteArray();
var serialized = attr_value.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status);
tf.Status.Check(true);
}
@@ -254,4 +348,261 @@ public class FuncGraph : Graph, IDisposable
{
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status);
}

public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func,
object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null,
FuncGraph func_graph = null, bool autograph = false, object autograph_options = null,
bool add_control_dependencies = true, string[] arg_names = null,
Tensor op_return_value = null, bool capture_by_value = false,
bool acd_record_initial_resource_uses = false)
{
if(func_graph is null)
{
func_graph = new FuncGraph(name);
}

// TODO(Rinne): deal with control dependencies.

func_graph.as_default();
var current_scope = variable_scope.get_variable_scope();
var default_use_resource = current_scope.use_resource;
current_scope.use_resource = true;

if(signature is not null)
{
args = signature;
kwargs = new Dictionary<string, object>();
}
var func_args = _get_defun_inputs_from_args(args, arg_names);
var func_kwargs = _get_defun_inputs_from_kwargs(kwargs);

if(func_kwargs is not null && func_kwargs.Count > 0)
{
throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`.");
}

foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs }))
{
if(arg is Tensor tensor && tensor.dtype == dtypes.resource)
{
func_graph._resource_tensor_inputs.Add(tensor);
}
else if (arg is ResourceVariable variable)
{
func_graph._resource_tensor_inputs.Add(variable.Handle);
}
}

// skip the assignment of `func_graph.structured_input_signature`.

var flat_func_args = nest.flatten(func_args as object);
var flat_func_kwargs = nest.flatten(func_kwargs as object);
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
.Where(x => x is Tensor).Select(x => (Tensor)x));

//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);

Tensor convert(object x)
{
if (x is null) return null;
Tensor res = null;
if(op_return_value is not null && x is Operation)
{
tf_with(ops.control_dependencies(new object[] { x }), _ =>
{
res = array_ops.identity(op_return_value);
});
}
else if(x is not TensorArray)
{
Debug.Assert(x is Tensor);
res = ops.convert_to_tensor_or_composite(x as Tensor);
}
else
{
throw new NotImplementedException($"The `TensorArray` is not supported here currently.");
}
if (add_control_dependencies)
{
// TODO(Rinne): `x = deps_ctx.mark_as_return(x)`.
}
return res;
}

if (autograph)
{
throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported.");
}

var func_outputs = func(func_args);
func_outputs = variable_utils.convert_variables_to_tensors(func_outputs);
func_outputs = func_outputs.Select(x => convert(x)).ToArray();
// TODO(Rinne): `check_func_mutation`.

current_scope.use_resource = default_use_resource;

var graph_variables = func_graph._watched_variables.ToList();
HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>();
List<Tensor> inputs = new();
foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args))
{
if(arg is BaseResourceVariable variable)
{
var resource_placeholder = func_graph.pop_capture(variable.Handle);
if(resource_placeholder is null)
{
continue;
}
Debug.Assert(variable is IVariableV1);
arg_variables.Add(variable as IVariableV1);
inputs.Add(resource_placeholder);
}
else if(arg is Tensor tensor)
{
inputs.Add(tensor);
}
}
var variables = graph_variables.Select(v =>
{
if (v.TryGetTarget(out var target))
{
return target;
}
else
{
return null;
}
}).Where(v => v is not null && !arg_variables.Contains(v));
func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray();
func_graph._structured_outputs = func_outputs;
func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null)
.Select(x => func_graph.capture(x)));

func_graph.Variables = variables;

func_graph.Exit();

if (add_control_dependencies)
{
// TODO(Rinne): implement it.
}
return func_graph;
}

private static object[] _get_defun_inputs_from_args(object[] args, string[] names)
{
return _get_defun_inputs(args, names, args) as object[];
}

private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs)
{
// TODO(Rinne): implement it.
Debug.Assert(kwargs is null || kwargs.Count == 0);
return kwargs;
//string[] names;
//object[] args;
//if(kwargs is not null && kwargs.Count > 0)
//{
// var sorted_kwargs = kwargs.OrderBy(x => x.Key);
// names = sorted_kwargs.Select(x => x.Key).ToArray();
// args = sorted_kwargs.Select(x => x.Value).ToArray();
//}
//else
//{
// names = new string[0];
// args = new object[0];
//}
//return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>;
}

private static object _get_defun_inputs(object[] args, string[] names, object structured_args)
{
List<object> function_inputs = new();
if(names is null)
{
names = new string[args.Length];
}

foreach(var (arg_value, name) in zip(args, names))
{
foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value))
{
function_inputs.Add(_get_defun_input(val, name));
}
}
return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true);
}

private static object _get_defun_input(object arg, string name)
{
var func_graph = ops.get_default_graph() as FuncGraph;
Debug.Assert(func_graph is not null);
if (arg is Tensor tensor)
{
Tensor placeholder;
try
{
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
}
catch (ValueError)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(tensor.dtype, tensor.shape);
}
handle_data_util.copy_handle_data(tensor, placeholder);
if (name is not null)
{
placeholder.op._set_attr("_user_specified_name", new AttrValue()
{
S = tf.compat.as_bytes(name)
});
}
return placeholder;
}
else if (arg is TensorSpec spec)
{
string requested_name;
if (!string.IsNullOrEmpty(spec.name))
{
requested_name = spec.name;
}
else
{
requested_name = name;
}
Tensor placeholder;
try
{
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
}
catch (ValueError)
{
// TODO(Rinne): Add warning here.
placeholder = tf.placeholder(spec.dtype, spec.shape);
}
if (name is not null)
{
placeholder.op._set_attr("_user_specified_name", new AttrValue()
{
S = tf.compat.as_bytes(requested_name)
});
}
return placeholder;
}
else if (arg is BaseResourceVariable variable)
{
var placeholder = func_graph.capture(variable.Handle, name);
placeholder.op._set_attr("_user_specified_name", new AttrValue()
{
S = tf.compat.as_bytes(name)
});
return arg;
}
// TODO(Rinne): deal with `VariableSpec`.
else
{
return arg;
}
}
}

+ 8
- 1
src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs View File

@@ -1,4 +1,6 @@
namespace Tensorflow
using Tensorflow.Graphs;

namespace Tensorflow
{
public partial class Graph
{
@@ -6,5 +8,10 @@
{

}

internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map)
{
return new GraphOverrideGradientContext(this, gradient_function_map);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -118,7 +118,7 @@ namespace Tensorflow
/// <param name="compute_device">(Optional.) If True, device functions will be executed
/// to compute the device property of the Operation.</param>
/// <returns>An `Operation` object.</returns>
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null)
{
var ret = new Operation(c_op, this);
_add_op(ret);


+ 58
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -19,6 +19,9 @@ using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Common.Extensions;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

@@ -86,6 +89,13 @@ namespace Tensorflow
private int _next_id_counter;
private List<Operation> _unfetchable_ops = new List<Operation>();
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
private Dictionary<string, EagerDefinedFunction> _functions = new();
internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new();
private VersionDef _graph_def_versions = new VersionDef()
{
Producer = versions.GRAPH_DEF_VERSION,
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
};

public string _name_stack = "";
protected string _graph_key;
@@ -121,6 +131,8 @@ namespace Tensorflow

protected Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;

public Graph()
{
@@ -147,6 +159,44 @@ namespace Tensorflow
return ops.set_default_graph(this);
}

public bool IsFunction(string name)
{
return _functions.ContainsKey(tf.compat.as_str(name));
}

internal void AddFunction(EagerDefinedFunction function)
{
_check_not_finalized();

var name = function.Name;
if(function._grad_func_name is not null && function.csharp_grad_func is not null)
{
throw new ValueError($"Gradient defined twice for function {name}");
}

var c_graph = this.c_graph;
var func = function._c_func.Get();
Status status = new();
if (function._grad_func is not null)
{
var gradient = function._grad_func._c_func.Get();
c_api.TF_GraphCopyFunction(c_graph, func, gradient, status);
status.Check(true);
}
else
{
c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status);
status.Check(true);
}

_functions[tf.compat.as_str(name)] = function;

if(_graph_def_versions.MinConsumer < 12)
{
_graph_def_versions.MinConsumer = 12;
}
}

private Tensor _as_graph_element(object obj)
{
if (obj is RefVariable var)
@@ -308,6 +358,9 @@ namespace Tensorflow

private void _create_op_helper(Operation op, bool compute_device = true)
{
// high priority
// TODO(Rinne): complete the implementation.
op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null);
_record_op_seen_by_control_dependencies(op);
}

@@ -524,6 +577,11 @@ namespace Tensorflow
ops.pop_graph();
}

internal EagerDefinedFunction _get_function(string name)
{
return _functions.GetOrDefault(name, null);
}

string debugString = string.Empty;
public override string ToString()
{


+ 37
- 0
src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs View File

@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Graphs
{
internal class GraphOverrideGradientContext: ITensorFlowObject
{
Graph _graph;
Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map;
public GraphOverrideGradientContext(Graph graph,
Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map)
{
_graph = graph;
_new_gradient_function_map = new_gradient_function_map;
}

[DebuggerStepThrough]
public void __enter__()
{
Debug.Assert(_graph._gradient_function_map.Count == 0);
_graph._gradient_function_map = _new_gradient_function_map;
}

[DebuggerStepThrough]
public void __exit__()
{
_graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>();
}

public void Dispose()
{

}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions
_handle = c_api.TF_NewImportGraphDefOptions();
}

public SafeImportGraphDefOptionsHandle Options => _handle;

public void AddReturnOutput(string name, int index)
{
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);


+ 5
- 2
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -185,6 +185,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name);

[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints);

/// <summary>
/// Add an output in `graph_def` to be returned via the `return_outputs` output
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input
@@ -246,7 +249,7 @@ namespace Tensorflow
/// <param name="ops">TF_ImportGraphDefOptions*</param>
/// <param name="uniquify_prefix">unsigned char</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix);
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix);

/// <summary>
/// Fetches the return operations requested via
@@ -308,7 +311,7 @@ namespace Tensorflow
/// <param name="types">const TF_DataType*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output,
public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output,
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types,
SafeStatusHandle status);



+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// This class has nothing but the attributes different from `LayerArgs`.
/// It's used to serialize the model to `tf` format.
/// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`,
/// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`.
/// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`.
/// </summary>
public class AutoSerializeLayerArgs: LayerArgs
{


+ 13
- 21
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -7,6 +7,11 @@ using System.Text;

namespace Tensorflow.Keras.Common
{
class ShapeInfoFromPython
{
public string class_name { get; set; }
public long?[] items { get; set; }
}
public class CustomizedShapeJsonConverter: JsonConverter
{
public override bool CanConvert(Type objectType)
@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common
dims[i] = shape.dims[i];
}
}
var token = JToken.FromObject(dims);
var token = JToken.FromObject(new ShapeInfoFromPython()
{
class_name = "__tuple__",
items = dims
});
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
long?[] dims;
try
{
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
catch (JsonSerializationException ex)
{
if (reader.Value.Equals("class_name"))
{
reader.Read();
reader.Read();
reader.Read();
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
else
{
throw ex;
}
}
if (dims is null)
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
{
return null;
}
long ?[]dims = shape_info_from_python.items;
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs View File

@@ -4,10 +4,10 @@ public interface IOptimizer
{
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
Tensor[] clip_gradients(Tensor[] grads);
void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
}

+ 84
- 13
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -20,6 +20,9 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using System.Diagnostics;

namespace Tensorflow
{
@@ -47,6 +50,8 @@ namespace Tensorflow

private readonly Graph _graph;

internal Func<Operation, object[], Tensor[]> _gradient_function;

public string type => OpType;

public Graph graph => _graph;
@@ -61,7 +66,7 @@ namespace Tensorflow

public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle));

// OperationDescription _opDesc;
//private OperationDescription _op_desc;

public NodeDef node_def => GetNodeDef();

@@ -216,21 +221,19 @@ namespace Tensorflow

var x = AttrValue.Parser.ParseFrom(buf.ToArray());

string oneof_value = x.ValueCase.ToString();
if (string.IsNullOrEmpty(oneof_value))
return null;
var oneof_value = x.ValueCase;
if (oneof_value == AttrValue.ValueOneofCase.None)
return new object[0];

switch (oneof_value.ToLower())
if(oneof_value == AttrValue.ValueOneofCase.List)
{
case "list":
throw new NotImplementedException($"Unsupported field type in {oneof_value}");
case "type":
return x.Type;
case "s":
return x.S.ToStringUtf8();
default:
return x.GetType().GetProperty(oneof_value).GetValue(x);
throw new NotImplementedException($"Unsupported field type in {oneof_value}");
}
if(oneof_value == AttrValue.ValueOneofCase.Type)
{
return dtypes.as_tf_dtype(x.Type);
}
return ProtoUtils.GetSingleAttrValue(x, oneof_value);
}

public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
@@ -238,6 +241,19 @@ namespace Tensorflow
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
}

[Obsolete("The implementation is not complete.")]
internal void _set_device_from_string(string device_str)
{
// TODO(Rinne): complete it with new C API `SetRequestedDevice`.
//c_api.TF_SetDevice(_handle, device_str);
}

[Obsolete("The implementation is not complete.")]
internal void _set_device(string device)
{
_set_device_from_string(device);
}

private NodeDef GetNodeDef()
{
var buffer = new Buffer();
@@ -296,5 +312,60 @@ namespace Tensorflow
}

public NDArray numpy() => throw new NotImplementedException("");

internal void _add_outputs(TF_DataType[] types, Shape[] shapes)
{
Debug.Assert(types.Length == shapes.Length);
int orig_num_outputs = this.outputs.Length;
var new_outputs = new List<Tensor>(_outputs);

// Since the `_outputs` is defined as `Array`, when we add new output, we
// have to create a new array, which brings some performance concerns.
// In the future maybe the type of `outputs` should be reconsidered.
for(int i = 0; i < types.Length; i++)
{
var t = new Tensor(this, orig_num_outputs + i, types[i]);
t.shape = shapes[i];
new_outputs.Add(t);
}
_outputs = new_outputs.ToArray();
}

internal void _set_func_attr(string attr_name, string func_name)
{
var func = new NameAttrList() { Name = func_name };
_set_attr(attr_name, new AttrValue() { Func = func });
}

internal void _set_type_list_attr(string attr_name, DataType[] types)
{
if(types is null || types.Length == 0)
{
return;
}
var type_list = new AttrValue.Types.ListValue();
type_list.Type.AddRange(types);
_set_attr(attr_name, new AttrValue() { List = type_list });
}

internal void _set_attr(string attr_name, AttrValue attr_value)
{
var buffer = new Buffer(attr_value.ToByteArray());
try
{
_set_attr_with_buf(attr_name, buffer);
}
finally
{
buffer.Release();
}
}

internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
{
Status status = new();
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
status.Check(true);
}
}
}

+ 72
- 0
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -14,10 +14,14 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding;

@@ -25,6 +29,74 @@ namespace Tensorflow
{
public class functional_ops
{
public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout,
bool executing_eagerly, string config, string executor_type)
{
if (tout is null)
{
throw new NotImplementedException();
}

if (config is null)
{
config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
}

if (executor_type is null)
{
executor_type = "";
}

if (executing_eagerly)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}

var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray();
AttrValue tin_attr = new()
{
List = new AttrValue.Types.ListValue()
};
tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum()));
AttrValue tout_attr = new()
{
List = new AttrValue.Types.ListValue()
};
tout_attr.List.Type.AddRange(tout);
AttrValue func_attr = new()
{
Func = new NameAttrList()
};
func_attr.Func.Name = f.Name;
AttrValue executor_type_attr = new AttrValue()
{
S = tf.compat.as_bytes(executor_type)
};
AttrValue config_proto = new AttrValue()
{
S = ByteString.CopyFromUtf8(executor_type)
};

var graph = ops.get_default_graph();
f.AddToGraph(graph);
// TODO(Rinne): complete it with `f.stateful`
var op_name = "PartitionedCall";
string xla_compile_attr = "_XlaMustCompile";
Dictionary<string, AttrValue> op_attrs = new();
op_attrs["Tin"] = tin_attr;
op_attrs["Tout"] = tout_attr;
op_attrs["f"] = func_attr;
op_attrs["config_proto"] = config_proto;
op_attrs["executor_type"] = executor_type_attr;
// TODO(Rinne): deal with `f.definition`.
var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(),
name: op_name, attrs: op_attrs);
var outputs = op.outputs;
// TODO(Rinne): deal with `f.graph`.
return outputs;
}
public static Tensor scan(
Func<Tensor, Tensor, Tensor> fn,
Tensor elems,


+ 46
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
return _result[0];
}
catch (Exception)
{
}
try
{
return fill_eager_fallback(dims, value as Tensor, name, ctx);
}
catch (Exception)
{

}
}
Dictionary<string, object> attrs = new Dictionary<string, object>();
attrs["dims"] = dims;
attrs["value"] = value;
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result.output;
}

public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
{
object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);

if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return _result[0];
}
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));

/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.


+ 128
- 0
src/TensorFlowNET.Core/Operations/gen_functional_ops.cs View File

@@ -0,0 +1,128 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
public class gen_functional_ops
{
public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f,
string config = "", string config_proto = "", string executor_type = "", string name = null)
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name,
args, tout, f, config, config_proto, executor_type));
}
catch (Exception)
{

}
}

if (config is null)
{
config = "";
}
if (config_proto is null)
{
config_proto = "";
}
if (executor_type is null)
{
executor_type = "";
}
Dictionary<string, object> kwargs = new();
kwargs["args"] = args;
kwargs["Tout"] = tout;
kwargs["f"] = f;
kwargs["config"] = config;
kwargs["config_proto"] = config_proto;
kwargs["executor_type"] = executor_type;
var output = tf.OpDefLib._apply_op_helper("PartitionedCall",
name, kwargs);
var result = output.outputs;
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result;
}

public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f,
string config, string config_proto, string executor_type, string name, Context ctx)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
if(config is null)
{
config = "";
}
if(config_proto is null)
{
config_proto = "";
}
if(executor_type is null)
{
executor_type = "";
}
object[] attrs = new object[]
{

};
}

public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null)
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
"SymbolicGradient", name, input, Tout, f));
return _result;
}
catch (Exception)
{

}

try
{
return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx);
}
catch (Exception)
{

}
}
var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f });
var result = op.outputs;
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result;
}

public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx)
{
object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f };
var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result;
}
}
}

+ 38
- 0
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape));
return _result[0];
}
catch (Exception)
{
}
try
{
return ensure_shape_eager_fallback(input, shape, name, ctx);
}
catch (Exception)
{
}
}

var dict = new Dictionary<string, object>();
dict["input"] = input;
dict["shape"] = shape;
var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return op.output;
}

public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx)
{
object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() };
var _result = execute.executes("EnsureShape", 1, new Tensor[] { input },
attrs, ctx, name);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return _result[0];
}

/// <summary>
/// Creates or finds a child frame, and makes <c>data</c> available to the child frame.
/// </summary>


+ 60
- 0
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -0,0 +1,60 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow.Operations
{
public static class handle_data_util
{
public static void copy_handle_data(Tensor source_t, Tensor target_t)
{
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant)
{
HandleData handle_data;
if(source_t is EagerTensor)
{
handle_data = source_t.HandleData;
}
else
{
handle_data = ops.get_resource_handle_data(source_t);
}
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null
&& handle_data.ShapeAndType.Count > 0)
{
set_handle_data(target_t, handle_data);
}
}
}

public static HandleData create_handle_data(Shape shape, TF_DataType dtype)
{
HandleData handle_data = new();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType()
{
Shape = shape.as_proto(),
Dtype = dtype.as_datatype_enum()
});
return handle_data;
}

public static void set_handle_data(Tensor target_t, HandleData handle_data)
{
if(target_t is EagerTensor)
{
target_t.HandleData = handle_data;
return;
}
Status status = new();
var proto = handle_data.ToByteArray();
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
}

public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
}
}

+ 129
- 44
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -21,6 +21,11 @@ using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;
using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;
using Tensorflow.Eager;
using Tensorflow.Graphs;

namespace Tensorflow
{
@@ -31,18 +36,14 @@ namespace Tensorflow
{
public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null)
{
// TODO(Rinne): deal with `_handle_graph`.
var value_tensor = ops.convert_to_tensor(value);
return gen_resource_variable_ops.assign_variable_op(handle,
value_tensor,
name: name);
}

public static bool is_resource_variable(IVariableV1 var)
{
return var is ResourceVariable;
}
public static bool is_resource_variable(Trackable var)
public static bool is_resource_variable(object var)
{
return var is BaseResourceVariable;
}
@@ -78,6 +79,18 @@ namespace Tensorflow
string shared_name, string name, bool graph_mode, Tensor initial_value = null)
{
var container = ops.get_default_graph().Container;
if(container is null)
{
container = "";
}
if (!graph_mode)
{
if(shared_name is not null)
{
throw new Exception("Using an explicit shared_name is not allowed when executing eagerly.");
}
shared_name = tf.Context.anonymous_name();
}
var handle = gen_resource_variable_ops.var_handle_op(shape: shape,
dtype: dtype,
shared_name: shared_name,
@@ -95,26 +108,20 @@ namespace Tensorflow
}
else
{
// We do not want two distinct ResourceVariable objects for the same
// underlying resource in the runtime.
// When in eager mode, explicitly ensure so here. When in graph mode, it's
// ensured by always generating different variable names.
var exists = gen_resource_variable_ops.var_is_initialized_op(handle);

// We create an assert Op instead of checking right away in order to be
// compatible with ASYNC execution mode. Further, since not all devices
// support string tensors, we encode the assertion string in the Op name
/*gen_logging_ops.assert(gen_math_ops.logical_not(exists),
new[] { exists },
name: "EagerVariableNameReuse");*/

var handle_data = new HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType
var handle_data = handle_data_util.create_handle_data(shape, dtype);
if (initial_value is not null && initial_value.dtype == dtypes.variant)
{
Dtype = dtype.as_datatype_enum(),
Shape = shape.as_proto()
});
var extra_handle_data = get_eager_safe_handle_data(initial_value);
if (extra_handle_data is not null && extra_handle_data.IsSet)
{
if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1)
{
throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " +
$"but saw: '{handle_data}'");
}
handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType);
}
}
_set_handle_shapes_and_types(handle, handle_data, graph_mode);
return handle;
}
@@ -126,7 +133,7 @@ namespace Tensorflow
/// <param name="handle"></param>
/// <param name="handle_data"></param>
/// <param name="graph_mode"></param>
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
{
if (!graph_mode)
return;
@@ -144,6 +151,47 @@ namespace Tensorflow
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count;
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray();
}

//tensor.HandleData = handle_data;
//if (!graph_mode)
// return;

//var shapes = handle_data.ShapeAndType.Select(x => x.Shape);
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray();
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray();
//var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s =>
//{
// if (!s.UnknownRank)
// {
// return s.Dim.Select(d => (int)d.Size).ToArray();
// }
// else
// {
// return Memory<int>.Empty;
// }
//}).ToArray();

//List<MemoryHandle> handles = new();
//IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length];
//foreach(var (i, m) in enumerate(converted_shapes))
//{
// if(m.IsEmpty)
// {
// shapes_with_ptr[i] = IntPtr.Zero;
// }
// else
// {
// var handle = m.Pin();
// handles.Add(handle);
// shapes_with_ptr[i] = new IntPtr(handle.Pointer);
// }
//}

//Status status = new();
//// TODO(Rinne): enable it.
//c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(),
// shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status);
//handles = null;
}

/// <summary>
@@ -162,24 +210,6 @@ namespace Tensorflow
throw new NotImplementedException("");
}

private static HandleData get_eager_safe_handle_data(Tensor handle)
{
if (handle.Handle == null)
{
var data = new HandleData();
data.ShapeAndType.Add(new HandleShapeAndType
{
Shape = handle.shape.as_shape_proto(),
Dtype = handle.dtype.as_datatype_enum()
});
return data;
}
else
{
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
}

/// <summary>
/// Copies an existing variable to a new graph, with no initializer.
/// </summary>
@@ -231,5 +261,60 @@ namespace Tensorflow
}
}
}

public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor)
{
if(dtype == dtypes.variant)
{
var handle_data = get_eager_safe_handle_data(handle);
if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1)
{
tensor.HandleData = new HandleData()
{
IsSet = true
};
tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1));
}
}
}

public static HandleData get_eager_safe_handle_data(Tensor handle)
{
if (handle.Handle == null)
{
var data = new HandleData();
data.ShapeAndType.Add(new HandleShapeAndType
{
Shape = handle.shape.as_shape_proto(),
Dtype = handle.dtype.as_datatype_enum()
});
return data;
}
else
{
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
//if(handle is EagerTensor)
//{
// return handle.HandleData;
//}
//else
//{
// return handle_data_util.get_resource_handle_data(handle);
//}
}

public static void variable_accessed(IVariableV1 variable)
{
if (ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable);
}
}
}
}

+ 107
- 2
src/TensorFlowNET.Core/Protobuf/AllocationDescription.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/framework/allocation_description.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -43,23 +43,31 @@ namespace Tensorflow {

}
#region Messages
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> {
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<AllocationDescription> _parser = new pb::MessageParser<AllocationDescription>(() => new AllocationDescription());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<AllocationDescription> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AllocationDescription() {
OnConstruction();
}
@@ -67,6 +75,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AllocationDescription(AllocationDescription other) : this() {
requestedBytes_ = other.requestedBytes_;
allocatedBytes_ = other.allocatedBytes_;
@@ -78,6 +87,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AllocationDescription Clone() {
return new AllocationDescription(this);
}
@@ -89,6 +99,7 @@ namespace Tensorflow {
/// Total number of bytes requested
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long RequestedBytes {
get { return requestedBytes_; }
set {
@@ -103,6 +114,7 @@ namespace Tensorflow {
/// Total number of bytes allocated if known
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long AllocatedBytes {
get { return allocatedBytes_; }
set {
@@ -117,6 +129,7 @@ namespace Tensorflow {
/// Name of the allocator used
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string AllocatorName {
get { return allocatorName_; }
set {
@@ -131,6 +144,7 @@ namespace Tensorflow {
/// Identifier of the allocated buffer if known
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long AllocationId {
get { return allocationId_; }
set {
@@ -145,6 +159,7 @@ namespace Tensorflow {
/// Set if this tensor only has one remaining reference
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool HasSingleReference {
get { return hasSingleReference_; }
set {
@@ -159,6 +174,7 @@ namespace Tensorflow {
/// Address of the allocation.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ulong Ptr {
get { return ptr_; }
set {
@@ -167,11 +183,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as AllocationDescription);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(AllocationDescription other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -189,6 +207,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode();
@@ -204,12 +223,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (RequestedBytes != 0L) {
output.WriteRawTag(8);
output.WriteInt64(RequestedBytes);
@@ -237,9 +261,45 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (RequestedBytes != 0L) {
output.WriteRawTag(8);
output.WriteInt64(RequestedBytes);
}
if (AllocatedBytes != 0L) {
output.WriteRawTag(16);
output.WriteInt64(AllocatedBytes);
}
if (AllocatorName.Length != 0) {
output.WriteRawTag(26);
output.WriteString(AllocatorName);
}
if (AllocationId != 0L) {
output.WriteRawTag(32);
output.WriteInt64(AllocationId);
}
if (HasSingleReference != false) {
output.WriteRawTag(40);
output.WriteBool(HasSingleReference);
}
if (Ptr != 0UL) {
output.WriteRawTag(48);
output.WriteUInt64(Ptr);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (RequestedBytes != 0L) {
@@ -267,6 +327,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(AllocationDescription other) {
if (other == null) {
return;
@@ -293,7 +354,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -326,7 +391,47 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 8: {
RequestedBytes = input.ReadInt64();
break;
}
case 16: {
AllocatedBytes = input.ReadInt64();
break;
}
case 26: {
AllocatorName = input.ReadString();
break;
}
case 32: {
AllocationId = input.ReadInt64();
break;
}
case 40: {
HasSingleReference = input.ReadBool();
break;
}
case 48: {
Ptr = input.ReadUInt64();
break;
}
}
}
}
#endif

}



+ 464
- 7
src/TensorFlowNET.Core/Protobuf/ApiDef.cs
File diff suppressed because it is too large
View File


+ 338
- 4
src/TensorFlowNET.Core/Protobuf/AttrValue.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/framework/attr_value.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -63,23 +63,31 @@ namespace Tensorflow {
/// Comment indicates the corresponding attr type. Only the field matching the
/// attr type may be filled.
/// </summary>
public sealed partial class AttrValue : pb::IMessage<AttrValue> {
public sealed partial class AttrValue : pb::IMessage<AttrValue>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<AttrValue> _parser = new pb::MessageParser<AttrValue>(() => new AttrValue());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<AttrValue> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AttrValue() {
OnConstruction();
}
@@ -87,6 +95,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AttrValue(AttrValue other) : this() {
switch (other.ValueCase) {
case ValueOneofCase.S:
@@ -125,6 +134,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public AttrValue Clone() {
return new AttrValue(this);
}
@@ -135,6 +145,7 @@ namespace Tensorflow {
/// "string"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pb::ByteString S {
get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; }
set {
@@ -149,6 +160,7 @@ namespace Tensorflow {
/// "int"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long I {
get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; }
set {
@@ -163,6 +175,7 @@ namespace Tensorflow {
/// "float"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public float F {
get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; }
set {
@@ -177,6 +190,7 @@ namespace Tensorflow {
/// "bool"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool B {
get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; }
set {
@@ -191,6 +205,7 @@ namespace Tensorflow {
/// "type"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.DataType Type {
get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; }
set {
@@ -205,6 +220,7 @@ namespace Tensorflow {
/// "shape"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.TensorShapeProto Shape {
get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; }
set {
@@ -219,6 +235,7 @@ namespace Tensorflow {
/// "tensor"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.TensorProto Tensor {
get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; }
set {
@@ -233,6 +250,7 @@ namespace Tensorflow {
/// any "list(...)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.AttrValue.Types.ListValue List {
get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; }
set {
@@ -250,6 +268,7 @@ namespace Tensorflow {
/// that attr in the instantiation.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.NameAttrList Func {
get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; }
set {
@@ -270,6 +289,7 @@ namespace Tensorflow {
/// given the value "bar".
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Placeholder {
get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; }
set {
@@ -295,22 +315,26 @@ namespace Tensorflow {
}
private ValueOneofCase valueCase_ = ValueOneofCase.None;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ValueOneofCase ValueCase {
get { return valueCase_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void ClearValue() {
valueCase_ = ValueOneofCase.None;
value_ = null;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as AttrValue);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(AttrValue other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -333,6 +357,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode();
@@ -353,12 +378,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (valueCase_ == ValueOneofCase.List) {
output.WriteRawTag(10);
output.WriteMessage(List);
@@ -402,9 +432,61 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (valueCase_ == ValueOneofCase.List) {
output.WriteRawTag(10);
output.WriteMessage(List);
}
if (valueCase_ == ValueOneofCase.S) {
output.WriteRawTag(18);
output.WriteBytes(S);
}
if (valueCase_ == ValueOneofCase.I) {
output.WriteRawTag(24);
output.WriteInt64(I);
}
if (valueCase_ == ValueOneofCase.F) {
output.WriteRawTag(37);
output.WriteFloat(F);
}
if (valueCase_ == ValueOneofCase.B) {
output.WriteRawTag(40);
output.WriteBool(B);
}
if (valueCase_ == ValueOneofCase.Type) {
output.WriteRawTag(48);
output.WriteEnum((int) Type);
}
if (valueCase_ == ValueOneofCase.Shape) {
output.WriteRawTag(58);
output.WriteMessage(Shape);
}
if (valueCase_ == ValueOneofCase.Tensor) {
output.WriteRawTag(66);
output.WriteMessage(Tensor);
}
if (valueCase_ == ValueOneofCase.Placeholder) {
output.WriteRawTag(74);
output.WriteString(Placeholder);
}
if (valueCase_ == ValueOneofCase.Func) {
output.WriteRawTag(82);
output.WriteMessage(Func);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (valueCase_ == ValueOneofCase.S) {
@@ -444,6 +526,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(AttrValue other) {
if (other == null) {
return;
@@ -497,7 +580,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -567,32 +654,118 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue();
if (valueCase_ == ValueOneofCase.List) {
subBuilder.MergeFrom(List);
}
input.ReadMessage(subBuilder);
List = subBuilder;
break;
}
case 18: {
S = input.ReadBytes();
break;
}
case 24: {
I = input.ReadInt64();
break;
}
case 37: {
F = input.ReadFloat();
break;
}
case 40: {
B = input.ReadBool();
break;
}
case 48: {
value_ = input.ReadEnum();
valueCase_ = ValueOneofCase.Type;
break;
}
case 58: {
global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto();
if (valueCase_ == ValueOneofCase.Shape) {
subBuilder.MergeFrom(Shape);
}
input.ReadMessage(subBuilder);
Shape = subBuilder;
break;
}
case 66: {
global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto();
if (valueCase_ == ValueOneofCase.Tensor) {
subBuilder.MergeFrom(Tensor);
}
input.ReadMessage(subBuilder);
Tensor = subBuilder;
break;
}
case 74: {
Placeholder = input.ReadString();
break;
}
case 82: {
global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList();
if (valueCase_ == ValueOneofCase.Func) {
subBuilder.MergeFrom(Func);
}
input.ReadMessage(subBuilder);
Func = subBuilder;
break;
}
}
}
}
#endif

#region Nested types
/// <summary>Container for nested types declared in the AttrValue message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static partial class Types {
/// <summary>
/// LINT.IfChange
/// </summary>
public sealed partial class ListValue : pb::IMessage<ListValue> {
public sealed partial class ListValue : pb::IMessage<ListValue>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<ListValue> _parser = new pb::MessageParser<ListValue>(() => new ListValue());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<ListValue> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ListValue() {
OnConstruction();
}
@@ -600,6 +773,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ListValue(ListValue other) : this() {
s_ = other.s_.Clone();
i_ = other.i_.Clone();
@@ -613,6 +787,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ListValue Clone() {
return new ListValue(this);
}
@@ -626,6 +801,7 @@ namespace Tensorflow {
/// "list(string)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<pb::ByteString> S {
get { return s_; }
}
@@ -639,6 +815,7 @@ namespace Tensorflow {
/// "list(int)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<long> I {
get { return i_; }
}
@@ -652,6 +829,7 @@ namespace Tensorflow {
/// "list(float)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<float> F {
get { return f_; }
}
@@ -665,6 +843,7 @@ namespace Tensorflow {
/// "list(bool)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<bool> B {
get { return b_; }
}
@@ -678,6 +857,7 @@ namespace Tensorflow {
/// "list(type)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.DataType> Type {
get { return type_; }
}
@@ -691,6 +871,7 @@ namespace Tensorflow {
/// "list(shape)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.TensorShapeProto> Shape {
get { return shape_; }
}
@@ -704,6 +885,7 @@ namespace Tensorflow {
/// "list(tensor)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.TensorProto> Tensor {
get { return tensor_; }
}
@@ -717,16 +899,19 @@ namespace Tensorflow {
/// "list(attr)"
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.NameAttrList> Func {
get { return func_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as ListValue);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(ListValue other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -746,6 +931,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= s_.GetHashCode();
@@ -763,12 +949,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
s_.WriteTo(output, _repeated_s_codec);
i_.WriteTo(output, _repeated_i_codec);
f_.WriteTo(output, _repeated_f_codec);
@@ -780,9 +971,29 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
s_.WriteTo(ref output, _repeated_s_codec);
i_.WriteTo(ref output, _repeated_i_codec);
f_.WriteTo(ref output, _repeated_f_codec);
b_.WriteTo(ref output, _repeated_b_codec);
type_.WriteTo(ref output, _repeated_type_codec);
shape_.WriteTo(ref output, _repeated_shape_codec);
tensor_.WriteTo(ref output, _repeated_tensor_codec);
func_.WriteTo(ref output, _repeated_func_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += s_.CalculateSize(_repeated_s_codec);
@@ -800,6 +1011,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(ListValue other) {
if (other == null) {
return;
@@ -816,7 +1028,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -861,7 +1077,59 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 18: {
s_.AddEntriesFrom(ref input, _repeated_s_codec);
break;
}
case 26:
case 24: {
i_.AddEntriesFrom(ref input, _repeated_i_codec);
break;
}
case 34:
case 37: {
f_.AddEntriesFrom(ref input, _repeated_f_codec);
break;
}
case 42:
case 40: {
b_.AddEntriesFrom(ref input, _repeated_b_codec);
break;
}
case 50:
case 48: {
type_.AddEntriesFrom(ref input, _repeated_type_codec);
break;
}
case 58: {
shape_.AddEntriesFrom(ref input, _repeated_shape_codec);
break;
}
case 66: {
tensor_.AddEntriesFrom(ref input, _repeated_tensor_codec);
break;
}
case 74: {
func_.AddEntriesFrom(ref input, _repeated_func_codec);
break;
}
}
}
}
#endif

}

@@ -874,23 +1142,31 @@ namespace Tensorflow {
/// A list of attr names and their values. The whole list is attached
/// with a string name. E.g., MatMul[T=float].
/// </summary>
public sealed partial class NameAttrList : pb::IMessage<NameAttrList> {
public sealed partial class NameAttrList : pb::IMessage<NameAttrList>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<NameAttrList> _parser = new pb::MessageParser<NameAttrList>(() => new NameAttrList());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<NameAttrList> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public NameAttrList() {
OnConstruction();
}
@@ -898,6 +1174,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public NameAttrList(NameAttrList other) : this() {
name_ = other.name_;
attr_ = other.attr_.Clone();
@@ -905,6 +1182,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public NameAttrList Clone() {
return new NameAttrList(this);
}
@@ -913,6 +1191,7 @@ namespace Tensorflow {
public const int NameFieldNumber = 1;
private string name_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Name {
get { return name_; }
set {
@@ -926,16 +1205,19 @@ namespace Tensorflow {
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18);
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr {
get { return attr_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as NameAttrList);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(NameAttrList other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -949,6 +1231,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
@@ -960,12 +1243,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
@@ -974,9 +1262,26 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
attr_.WriteTo(ref output, _map_attr_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
@@ -990,6 +1295,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(NameAttrList other) {
if (other == null) {
return;
@@ -1002,7 +1308,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -1019,7 +1329,31 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
Name = input.ReadString();
break;
}
case 18: {
attr_.AddEntriesFrom(ref input, _map_attr_codec);
break;
}
}
}
}
#endif

}



+ 84
- 2
src/TensorFlowNET.Core/Protobuf/CheckpointState.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/python/training/checkpoint_state.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -43,23 +43,31 @@ namespace Tensorflow {
/// <summary>
/// Protocol buffer representing the checkpoint state.
/// </summary>
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> {
public sealed partial class CheckpointState : pb::IMessage<CheckpointState>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CheckpointState() {
OnConstruction();
}
@@ -67,6 +75,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CheckpointState(CheckpointState other) : this() {
modelCheckpointPath_ = other.modelCheckpointPath_;
allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone();
@@ -76,6 +85,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CheckpointState Clone() {
return new CheckpointState(this);
}
@@ -87,6 +97,7 @@ namespace Tensorflow {
/// Path to the most-recent model checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string ModelCheckpointPath {
get { return modelCheckpointPath_; }
set {
@@ -106,6 +117,7 @@ namespace Tensorflow {
/// this list.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> AllModelCheckpointPaths {
get { return allModelCheckpointPaths_; }
}
@@ -120,6 +132,7 @@ namespace Tensorflow {
/// when each checkpoint was created.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<double> AllModelCheckpointTimestamps {
get { return allModelCheckpointTimestamps_; }
}
@@ -132,6 +145,7 @@ namespace Tensorflow {
/// checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public double LastPreservedTimestamp {
get { return lastPreservedTimestamp_; }
set {
@@ -140,11 +154,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CheckpointState);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CheckpointState other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -160,6 +176,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode();
@@ -173,12 +190,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (ModelCheckpointPath.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ModelCheckpointPath);
@@ -192,9 +214,31 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (ModelCheckpointPath.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ModelCheckpointPath);
}
allModelCheckpointPaths_.WriteTo(ref output, _repeated_allModelCheckpointPaths_codec);
allModelCheckpointTimestamps_.WriteTo(ref output, _repeated_allModelCheckpointTimestamps_codec);
if (LastPreservedTimestamp != 0D) {
output.WriteRawTag(33);
output.WriteDouble(LastPreservedTimestamp);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (ModelCheckpointPath.Length != 0) {
@@ -212,6 +256,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CheckpointState other) {
if (other == null) {
return;
@@ -228,7 +273,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -254,7 +303,40 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
ModelCheckpointPath = input.ReadString();
break;
}
case 18: {
allModelCheckpointPaths_.AddEntriesFrom(ref input, _repeated_allModelCheckpointPaths_codec);
break;
}
case 26:
case 25: {
allModelCheckpointTimestamps_.AddEntriesFrom(ref input, _repeated_allModelCheckpointTimestamps_codec);
break;
}
case 33: {
LastPreservedTimestamp = input.ReadDouble();
break;
}
}
}
}
#endif

}



+ 126
- 3
src/TensorFlowNET.Core/Protobuf/Cluster.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/protobuf/cluster.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -47,23 +47,31 @@ namespace Tensorflow {
/// <summary>
/// Defines a single job in a TensorFlow cluster.
/// </summary>
public sealed partial class JobDef : pb::IMessage<JobDef> {
public sealed partial class JobDef : pb::IMessage<JobDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<JobDef> _parser = new pb::MessageParser<JobDef>(() => new JobDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<JobDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public JobDef() {
OnConstruction();
}
@@ -71,6 +79,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public JobDef(JobDef other) : this() {
name_ = other.name_;
tasks_ = other.tasks_.Clone();
@@ -78,6 +87,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public JobDef Clone() {
return new JobDef(this);
}
@@ -89,6 +99,7 @@ namespace Tensorflow {
/// The name of this job.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Name {
get { return name_; }
set {
@@ -109,16 +120,19 @@ namespace Tensorflow {
/// "/job:worker/task:7" will be assigned to "example.org:2222".
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::MapField<int, string> Tasks {
get { return tasks_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as JobDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(JobDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -132,6 +146,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
@@ -143,12 +158,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
@@ -157,9 +177,26 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
tasks_.WriteTo(ref output, _map_tasks_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
@@ -173,6 +210,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(JobDef other) {
if (other == null) {
return;
@@ -185,7 +223,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -202,30 +244,62 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
Name = input.ReadString();
break;
}
case 18: {
tasks_.AddEntriesFrom(ref input, _map_tasks_codec);
break;
}
}
}
}
#endif

}

/// <summary>
/// Defines a TensorFlow cluster as a set of jobs.
/// </summary>
public sealed partial class ClusterDef : pb::IMessage<ClusterDef> {
public sealed partial class ClusterDef : pb::IMessage<ClusterDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<ClusterDef> _parser = new pb::MessageParser<ClusterDef>(() => new ClusterDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<ClusterDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ClusterDef() {
OnConstruction();
}
@@ -233,12 +307,14 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ClusterDef(ClusterDef other) : this() {
job_ = other.job_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ClusterDef Clone() {
return new ClusterDef(this);
}
@@ -252,16 +328,19 @@ namespace Tensorflow {
/// The jobs that comprise the cluster.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.JobDef> Job {
get { return job_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as ClusterDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(ClusterDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -274,6 +353,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= job_.GetHashCode();
@@ -284,19 +364,37 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
job_.WriteTo(output, _repeated_job_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
job_.WriteTo(ref output, _repeated_job_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += job_.CalculateSize(_repeated_job_codec);
@@ -307,6 +405,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(ClusterDef other) {
if (other == null) {
return;
@@ -316,7 +415,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -329,7 +432,27 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
job_.AddEntriesFrom(ref input, _repeated_job_codec);
break;
}
}
}
}
#endif

}



+ 2346
- 237
src/TensorFlowNET.Core/Protobuf/Config.cs
File diff suppressed because it is too large
View File


+ 407
- 5
src/TensorFlowNET.Core/Protobuf/ControlFlow.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/protobuf/control_flow.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -64,23 +64,31 @@ namespace Tensorflow {
/// <summary>
/// Protocol buffer representing the values in ControlFlowContext.
/// </summary>
public sealed partial class ValuesDef : pb::IMessage<ValuesDef> {
public sealed partial class ValuesDef : pb::IMessage<ValuesDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<ValuesDef> _parser = new pb::MessageParser<ValuesDef>(() => new ValuesDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<ValuesDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ValuesDef() {
OnConstruction();
}
@@ -88,6 +96,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ValuesDef(ValuesDef other) : this() {
values_ = other.values_.Clone();
externalValues_ = other.externalValues_.Clone();
@@ -95,6 +104,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ValuesDef Clone() {
return new ValuesDef(this);
}
@@ -108,6 +118,7 @@ namespace Tensorflow {
/// Value names that have been seen in this context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> Values {
get { return values_; }
}
@@ -121,16 +132,19 @@ namespace Tensorflow {
/// Value names referenced by but external to this context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::MapField<string, string> ExternalValues {
get { return externalValues_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as ValuesDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(ValuesDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -144,6 +158,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= values_.GetHashCode();
@@ -155,20 +170,39 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
values_.WriteTo(output, _repeated_values_codec);
externalValues_.WriteTo(output, _map_externalValues_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
values_.WriteTo(ref output, _repeated_values_codec);
externalValues_.WriteTo(ref output, _map_externalValues_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += values_.CalculateSize(_repeated_values_codec);
@@ -180,6 +214,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(ValuesDef other) {
if (other == null) {
return;
@@ -190,7 +225,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -207,7 +246,31 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
values_.AddEntriesFrom(ref input, _repeated_values_codec);
break;
}
case 18: {
externalValues_.AddEntriesFrom(ref input, _map_externalValues_codec);
break;
}
}
}
}
#endif

}

@@ -215,23 +278,31 @@ namespace Tensorflow {
/// Container for any kind of control flow context. Any other control flow
/// contexts that are added below should also be added here.
/// </summary>
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> {
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<ControlFlowContextDef> _parser = new pb::MessageParser<ControlFlowContextDef>(() => new ControlFlowContextDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<ControlFlowContextDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ControlFlowContextDef() {
OnConstruction();
}
@@ -239,6 +310,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ControlFlowContextDef(ControlFlowContextDef other) : this() {
switch (other.CtxtCase) {
case CtxtOneofCase.CondCtxt:
@@ -253,6 +325,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ControlFlowContextDef Clone() {
return new ControlFlowContextDef(this);
}
@@ -260,6 +333,7 @@ namespace Tensorflow {
/// <summary>Field number for the "cond_ctxt" field.</summary>
public const int CondCtxtFieldNumber = 1;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.CondContextDef CondCtxt {
get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; }
set {
@@ -271,6 +345,7 @@ namespace Tensorflow {
/// <summary>Field number for the "while_ctxt" field.</summary>
public const int WhileCtxtFieldNumber = 2;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.WhileContextDef WhileCtxt {
get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; }
set {
@@ -288,22 +363,26 @@ namespace Tensorflow {
}
private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CtxtOneofCase CtxtCase {
get { return ctxtCase_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void ClearCtxt() {
ctxtCase_ = CtxtOneofCase.None;
ctxt_ = null;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as ControlFlowContextDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(ControlFlowContextDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -318,6 +397,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode();
@@ -330,12 +410,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (ctxtCase_ == CtxtOneofCase.CondCtxt) {
output.WriteRawTag(10);
output.WriteMessage(CondCtxt);
@@ -347,9 +432,29 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (ctxtCase_ == CtxtOneofCase.CondCtxt) {
output.WriteRawTag(10);
output.WriteMessage(CondCtxt);
}
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) {
output.WriteRawTag(18);
output.WriteMessage(WhileCtxt);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (ctxtCase_ == CtxtOneofCase.CondCtxt) {
@@ -365,6 +470,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(ControlFlowContextDef other) {
if (other == null) {
return;
@@ -388,7 +494,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -415,30 +525,72 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef();
if (ctxtCase_ == CtxtOneofCase.CondCtxt) {
subBuilder.MergeFrom(CondCtxt);
}
input.ReadMessage(subBuilder);
CondCtxt = subBuilder;
break;
}
case 18: {
global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef();
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) {
subBuilder.MergeFrom(WhileCtxt);
}
input.ReadMessage(subBuilder);
WhileCtxt = subBuilder;
break;
}
}
}
}
#endif

}

/// <summary>
/// Protocol buffer representing a CondContext object.
/// </summary>
public sealed partial class CondContextDef : pb::IMessage<CondContextDef> {
public sealed partial class CondContextDef : pb::IMessage<CondContextDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CondContextDef> _parser = new pb::MessageParser<CondContextDef>(() => new CondContextDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CondContextDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CondContextDef() {
OnConstruction();
}
@@ -446,6 +598,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CondContextDef(CondContextDef other) : this() {
contextName_ = other.contextName_;
predName_ = other.predName_;
@@ -457,6 +610,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CondContextDef Clone() {
return new CondContextDef(this);
}
@@ -468,6 +622,7 @@ namespace Tensorflow {
/// Name of the context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string ContextName {
get { return contextName_; }
set {
@@ -482,6 +637,7 @@ namespace Tensorflow {
/// Name of the pred tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PredName {
get { return predName_; }
set {
@@ -496,6 +652,7 @@ namespace Tensorflow {
/// Name of the pivot tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PivotName {
get { return pivotName_; }
set {
@@ -510,6 +667,7 @@ namespace Tensorflow {
/// Branch prediction. 0 or 1.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int Branch {
get { return branch_; }
set {
@@ -524,6 +682,7 @@ namespace Tensorflow {
/// Values and external values in control flow context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.ValuesDef ValuesDef {
get { return valuesDef_; }
set {
@@ -540,16 +699,19 @@ namespace Tensorflow {
/// Contexts contained inside this context (e.g. nested conds).
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts {
get { return nestedContexts_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CondContextDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CondContextDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -567,6 +729,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode();
@@ -582,12 +745,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (ContextName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ContextName);
@@ -612,9 +780,42 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (ContextName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ContextName);
}
if (PredName.Length != 0) {
output.WriteRawTag(18);
output.WriteString(PredName);
}
if (PivotName.Length != 0) {
output.WriteRawTag(26);
output.WriteString(PivotName);
}
if (Branch != 0) {
output.WriteRawTag(32);
output.WriteInt32(Branch);
}
if (valuesDef_ != null) {
output.WriteRawTag(42);
output.WriteMessage(ValuesDef);
}
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (ContextName.Length != 0) {
@@ -640,6 +841,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CondContextDef other) {
if (other == null) {
return;
@@ -667,7 +869,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -703,30 +909,81 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
ContextName = input.ReadString();
break;
}
case 18: {
PredName = input.ReadString();
break;
}
case 26: {
PivotName = input.ReadString();
break;
}
case 32: {
Branch = input.ReadInt32();
break;
}
case 42: {
if (valuesDef_ == null) {
ValuesDef = new global::Tensorflow.ValuesDef();
}
input.ReadMessage(ValuesDef);
break;
}
case 50: {
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec);
break;
}
}
}
}
#endif

}

/// <summary>
/// Protocol buffer representing a WhileContext object.
/// </summary>
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> {
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<WhileContextDef> _parser = new pb::MessageParser<WhileContextDef>(() => new WhileContextDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<WhileContextDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public WhileContextDef() {
OnConstruction();
}
@@ -734,6 +991,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public WhileContextDef(WhileContextDef other) : this() {
contextName_ = other.contextName_;
parallelIterations_ = other.parallelIterations_;
@@ -751,6 +1009,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public WhileContextDef Clone() {
return new WhileContextDef(this);
}
@@ -762,6 +1021,7 @@ namespace Tensorflow {
/// Name of the context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string ContextName {
get { return contextName_; }
set {
@@ -776,6 +1036,7 @@ namespace Tensorflow {
/// The number of iterations allowed to run in parallel.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int ParallelIterations {
get { return parallelIterations_; }
set {
@@ -790,6 +1051,7 @@ namespace Tensorflow {
/// Whether backprop is enabled for this while loop.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool BackProp {
get { return backProp_; }
set {
@@ -804,6 +1066,7 @@ namespace Tensorflow {
/// Whether GPU-CPU memory swap is enabled for this loop.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool SwapMemory {
get { return swapMemory_; }
set {
@@ -818,6 +1081,7 @@ namespace Tensorflow {
/// Name of the pivot tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PivotName {
get { return pivotName_; }
set {
@@ -832,6 +1096,7 @@ namespace Tensorflow {
/// Name of the pivot_for_pred tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PivotForPredName {
get { return pivotForPredName_; }
set {
@@ -846,6 +1111,7 @@ namespace Tensorflow {
/// Name of the pivot_for_body tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PivotForBodyName {
get { return pivotForBodyName_; }
set {
@@ -862,6 +1128,7 @@ namespace Tensorflow {
/// List of names for exit tensors.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> LoopExitNames {
get { return loopExitNames_; }
}
@@ -875,6 +1142,7 @@ namespace Tensorflow {
/// List of names for enter tensors.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> LoopEnterNames {
get { return loopEnterNames_; }
}
@@ -886,6 +1154,7 @@ namespace Tensorflow {
/// Values and external values in control flow context.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.ValuesDef ValuesDef {
get { return valuesDef_; }
set {
@@ -900,6 +1169,7 @@ namespace Tensorflow {
/// Optional name of the maximum_iterations tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string MaximumIterationsName {
get { return maximumIterationsName_; }
set {
@@ -916,16 +1186,19 @@ namespace Tensorflow {
/// Contexts contained inside this context (e.g. nested whiles).
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts {
get { return nestedContexts_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as WhileContextDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(WhileContextDef other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -949,6 +1222,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode();
@@ -970,12 +1244,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (ContextName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ContextName);
@@ -1018,9 +1297,60 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (ContextName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ContextName);
}
if (ParallelIterations != 0) {
output.WriteRawTag(16);
output.WriteInt32(ParallelIterations);
}
if (BackProp != false) {
output.WriteRawTag(24);
output.WriteBool(BackProp);
}
if (SwapMemory != false) {
output.WriteRawTag(32);
output.WriteBool(SwapMemory);
}
if (PivotName.Length != 0) {
output.WriteRawTag(42);
output.WriteString(PivotName);
}
if (PivotForPredName.Length != 0) {
output.WriteRawTag(50);
output.WriteString(PivotForPredName);
}
if (PivotForBodyName.Length != 0) {
output.WriteRawTag(58);
output.WriteString(PivotForBodyName);
}
loopExitNames_.WriteTo(ref output, _repeated_loopExitNames_codec);
if (valuesDef_ != null) {
output.WriteRawTag(74);
output.WriteMessage(ValuesDef);
}
loopEnterNames_.WriteTo(ref output, _repeated_loopEnterNames_codec);
if (MaximumIterationsName.Length != 0) {
output.WriteRawTag(90);
output.WriteString(MaximumIterationsName);
}
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (ContextName.Length != 0) {
@@ -1060,6 +1390,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(WhileContextDef other) {
if (other == null) {
return;
@@ -1101,7 +1432,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -1161,7 +1496,74 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
ContextName = input.ReadString();
break;
}
case 16: {
ParallelIterations = input.ReadInt32();
break;
}
case 24: {
BackProp = input.ReadBool();
break;
}
case 32: {
SwapMemory = input.ReadBool();
break;
}
case 42: {
PivotName = input.ReadString();
break;
}
case 50: {
PivotForPredName = input.ReadString();
break;
}
case 58: {
PivotForBodyName = input.ReadString();
break;
}
case 66: {
loopExitNames_.AddEntriesFrom(ref input, _repeated_loopExitNames_codec);
break;
}
case 74: {
if (valuesDef_ == null) {
ValuesDef = new global::Tensorflow.ValuesDef();
}
input.ReadMessage(ValuesDef);
break;
}
case 82: {
loopEnterNames_.AddEntriesFrom(ref input, _repeated_loopEnterNames_codec);
break;
}
case 90: {
MaximumIterationsName = input.ReadString();
break;
}
case 98: {
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec);
break;
}
}
}
}
#endif

}



+ 791
- 0
src/TensorFlowNET.Core/Protobuf/CoordinationConfig.cs View File

@@ -0,0 +1,791 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/protobuf/coordination_config.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/coordination_config.proto</summary>
public static partial class CoordinationConfigReflection {

#region Descriptor
/// <summary>File descriptor for tensorflow/core/protobuf/coordination_config.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static CoordinationConfigReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjJ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX2NvbmZp",
"Zy5wcm90bxIKdGVuc29yZmxvdyIxCg5Db29yZGluYXRlZEpvYhIMCgRuYW1l",
"GAEgASgJEhEKCW51bV90YXNrcxgCIAEoBSLdAgoZQ29vcmRpbmF0aW9uU2Vy",
"dmljZUNvbmZpZxIUCgxzZXJ2aWNlX3R5cGUYASABKAkSFgoOc2VydmljZV9s",
"ZWFkZXIYAiABKAkSGwoTZW5hYmxlX2hlYWx0aF9jaGVjaxgDIAEoCBImCh5j",
"bHVzdGVyX3JlZ2lzdGVyX3RpbWVvdXRfaW5fbXMYBCABKAMSHwoXaGVhcnRi",
"ZWF0X3RpbWVvdXRfaW5fbXMYBSABKAMSOAoUY29vcmRpbmF0ZWRfam9iX2xp",
"c3QYCiADKAsyGi50ZW5zb3JmbG93LkNvb3JkaW5hdGVkSm9iEiYKHnNodXRk",
"b3duX2JhcnJpZXJfdGltZW91dF9pbl9tcxgHIAEoAxIqCiJhZ2VudF9kZXN0",
"cnVjdGlvbl93aXRob3V0X3NodXRkb3duGAggASgIEhgKEHJlY292ZXJhYmxl",
"X2pvYnMYCSADKAlKBAgGEAdCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl",
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl",
"X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedJob), global::Tensorflow.CoordinatedJob.Parser, new[]{ "Name", "NumTasks" }, null, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceConfig), global::Tensorflow.CoordinationServiceConfig.Parser, new[]{ "ServiceType", "ServiceLeader", "EnableHealthCheck", "ClusterRegisterTimeoutInMs", "HeartbeatTimeoutInMs", "CoordinatedJobList", "ShutdownBarrierTimeoutInMs", "AgentDestructionWithoutShutdown", "RecoverableJobs" }, null, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Represents a job type and the number of tasks under this job.
/// For example, ("worker", 20) implies that there will be 20 worker tasks.
/// </summary>
public sealed partial class CoordinatedJob : pb::IMessage<CoordinatedJob>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CoordinatedJob> _parser = new pb::MessageParser<CoordinatedJob>(() => new CoordinatedJob());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CoordinatedJob> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinatedJob() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinatedJob(CoordinatedJob other) : this() {
name_ = other.name_;
numTasks_ = other.numTasks_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinatedJob Clone() {
return new CoordinatedJob(this);
}

/// <summary>Field number for the "name" field.</summary>
public const int NameFieldNumber = 1;
private string name_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Name {
get { return name_; }
set {
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "num_tasks" field.</summary>
public const int NumTasksFieldNumber = 2;
private int numTasks_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int NumTasks {
get { return numTasks_; }
set {
numTasks_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CoordinatedJob);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CoordinatedJob other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Name != other.Name) return false;
if (NumTasks != other.NumTasks) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
if (NumTasks != 0) hash ^= NumTasks.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
if (NumTasks != 0) {
output.WriteRawTag(16);
output.WriteInt32(NumTasks);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
if (NumTasks != 0) {
output.WriteRawTag(16);
output.WriteInt32(NumTasks);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name);
}
if (NumTasks != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumTasks);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CoordinatedJob other) {
if (other == null) {
return;
}
if (other.Name.Length != 0) {
Name = other.Name;
}
if (other.NumTasks != 0) {
NumTasks = other.NumTasks;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
Name = input.ReadString();
break;
}
case 16: {
NumTasks = input.ReadInt32();
break;
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
Name = input.ReadString();
break;
}
case 16: {
NumTasks = input.ReadInt32();
break;
}
}
}
}
#endif

}

/// <summary>
/// Coordination service configuration parameters.
/// The system picks appropriate values for fields that are not set.
/// </summary>
public sealed partial class CoordinationServiceConfig : pb::IMessage<CoordinationServiceConfig>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CoordinationServiceConfig> _parser = new pb::MessageParser<CoordinationServiceConfig>(() => new CoordinationServiceConfig());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CoordinationServiceConfig> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinationServiceConfig() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinationServiceConfig(CoordinationServiceConfig other) : this() {
serviceType_ = other.serviceType_;
serviceLeader_ = other.serviceLeader_;
enableHealthCheck_ = other.enableHealthCheck_;
clusterRegisterTimeoutInMs_ = other.clusterRegisterTimeoutInMs_;
heartbeatTimeoutInMs_ = other.heartbeatTimeoutInMs_;
coordinatedJobList_ = other.coordinatedJobList_.Clone();
shutdownBarrierTimeoutInMs_ = other.shutdownBarrierTimeoutInMs_;
agentDestructionWithoutShutdown_ = other.agentDestructionWithoutShutdown_;
recoverableJobs_ = other.recoverableJobs_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CoordinationServiceConfig Clone() {
return new CoordinationServiceConfig(this);
}

/// <summary>Field number for the "service_type" field.</summary>
public const int ServiceTypeFieldNumber = 1;
private string serviceType_ = "";
/// <summary>
/// Type of coordination service implementation to enable.
/// For example, setting the service type as "standalone" starts a service
/// instance on the leader task to provide the coordination services such as
/// heartbeats and consistent key-value store.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string ServiceType {
get { return serviceType_; }
set {
serviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "service_leader" field.</summary>
public const int ServiceLeaderFieldNumber = 2;
private string serviceLeader_ = "";
/// <summary>
/// Address where the coordination service instance is hosted.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string ServiceLeader {
get { return serviceLeader_; }
set {
serviceLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "enable_health_check" field.</summary>
public const int EnableHealthCheckFieldNumber = 3;
private bool enableHealthCheck_;
/// <summary>
/// Whether to enable the health check mechanism.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool EnableHealthCheck {
get { return enableHealthCheck_; }
set {
enableHealthCheck_ = value;
}
}

/// <summary>Field number for the "cluster_register_timeout_in_ms" field.</summary>
public const int ClusterRegisterTimeoutInMsFieldNumber = 4;
private long clusterRegisterTimeoutInMs_;
/// <summary>
/// Maximum wait time for all members in the cluster to be registered.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long ClusterRegisterTimeoutInMs {
get { return clusterRegisterTimeoutInMs_; }
set {
clusterRegisterTimeoutInMs_ = value;
}
}

/// <summary>Field number for the "heartbeat_timeout_in_ms" field.</summary>
public const int HeartbeatTimeoutInMsFieldNumber = 5;
private long heartbeatTimeoutInMs_;
/// <summary>
/// Heartbeat timeout, if a task does not record heartbeat in this time
/// window, it will be considered disconnected.
/// Note: This is also used as a grace period to accept any heartbeats after
/// the agent has disconnected, to account for the lag time between the service
/// recording the state change and the agent stopping heartbeats.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long HeartbeatTimeoutInMs {
get { return heartbeatTimeoutInMs_; }
set {
heartbeatTimeoutInMs_ = value;
}
}

/// <summary>Field number for the "coordinated_job_list" field.</summary>
public const int CoordinatedJobListFieldNumber = 10;
private static readonly pb::FieldCodec<global::Tensorflow.CoordinatedJob> _repeated_coordinatedJobList_codec
= pb::FieldCodec.ForMessage(82, global::Tensorflow.CoordinatedJob.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.CoordinatedJob> coordinatedJobList_ = new pbc::RepeatedField<global::Tensorflow.CoordinatedJob>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.CoordinatedJob> CoordinatedJobList {
get { return coordinatedJobList_; }
}

/// <summary>Field number for the "shutdown_barrier_timeout_in_ms" field.</summary>
public const int ShutdownBarrierTimeoutInMsFieldNumber = 7;
private long shutdownBarrierTimeoutInMs_;
/// <summary>
/// Denotes how long to wait for all coordination agents to reach the barriers
/// (after the first shutdown request) before disconnecting together. If
/// set to 0, no barrier is imposed upon shutdown and each worker can
/// disconnect individually.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long ShutdownBarrierTimeoutInMs {
get { return shutdownBarrierTimeoutInMs_; }
set {
shutdownBarrierTimeoutInMs_ = value;
}
}

/// <summary>Field number for the "agent_destruction_without_shutdown" field.</summary>
public const int AgentDestructionWithoutShutdownFieldNumber = 8;
private bool agentDestructionWithoutShutdown_;
/// <summary>
/// If set, agents do not make an explicit Shutdown() call. Service will only
/// find out about the disconnecte agent via stale heartbeats. Used for
/// testing.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool AgentDestructionWithoutShutdown {
get { return agentDestructionWithoutShutdown_; }
set {
agentDestructionWithoutShutdown_ = value;
}
}

/// <summary>Field number for the "recoverable_jobs" field.</summary>
public const int RecoverableJobsFieldNumber = 9;
private static readonly pb::FieldCodec<string> _repeated_recoverableJobs_codec
= pb::FieldCodec.ForString(74);
private readonly pbc::RepeatedField<string> recoverableJobs_ = new pbc::RepeatedField<string>();
/// <summary>
/// The list of jobs which are recoverable. If a task in this list fails,
/// it will not propagate error to other tasks.
/// If empty, no jobs will be recoverable and every task failure will cause
/// error propagation to other tasks.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> RecoverableJobs {
get { return recoverableJobs_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CoordinationServiceConfig);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CoordinationServiceConfig other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (ServiceType != other.ServiceType) return false;
if (ServiceLeader != other.ServiceLeader) return false;
if (EnableHealthCheck != other.EnableHealthCheck) return false;
if (ClusterRegisterTimeoutInMs != other.ClusterRegisterTimeoutInMs) return false;
if (HeartbeatTimeoutInMs != other.HeartbeatTimeoutInMs) return false;
if(!coordinatedJobList_.Equals(other.coordinatedJobList_)) return false;
if (ShutdownBarrierTimeoutInMs != other.ShutdownBarrierTimeoutInMs) return false;
if (AgentDestructionWithoutShutdown != other.AgentDestructionWithoutShutdown) return false;
if(!recoverableJobs_.Equals(other.recoverableJobs_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (ServiceType.Length != 0) hash ^= ServiceType.GetHashCode();
if (ServiceLeader.Length != 0) hash ^= ServiceLeader.GetHashCode();
if (EnableHealthCheck != false) hash ^= EnableHealthCheck.GetHashCode();
if (ClusterRegisterTimeoutInMs != 0L) hash ^= ClusterRegisterTimeoutInMs.GetHashCode();
if (HeartbeatTimeoutInMs != 0L) hash ^= HeartbeatTimeoutInMs.GetHashCode();
hash ^= coordinatedJobList_.GetHashCode();
if (ShutdownBarrierTimeoutInMs != 0L) hash ^= ShutdownBarrierTimeoutInMs.GetHashCode();
if (AgentDestructionWithoutShutdown != false) hash ^= AgentDestructionWithoutShutdown.GetHashCode();
hash ^= recoverableJobs_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (ServiceType.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ServiceType);
}
if (ServiceLeader.Length != 0) {
output.WriteRawTag(18);
output.WriteString(ServiceLeader);
}
if (EnableHealthCheck != false) {
output.WriteRawTag(24);
output.WriteBool(EnableHealthCheck);
}
if (ClusterRegisterTimeoutInMs != 0L) {
output.WriteRawTag(32);
output.WriteInt64(ClusterRegisterTimeoutInMs);
}
if (HeartbeatTimeoutInMs != 0L) {
output.WriteRawTag(40);
output.WriteInt64(HeartbeatTimeoutInMs);
}
if (ShutdownBarrierTimeoutInMs != 0L) {
output.WriteRawTag(56);
output.WriteInt64(ShutdownBarrierTimeoutInMs);
}
if (AgentDestructionWithoutShutdown != false) {
output.WriteRawTag(64);
output.WriteBool(AgentDestructionWithoutShutdown);
}
recoverableJobs_.WriteTo(output, _repeated_recoverableJobs_codec);
coordinatedJobList_.WriteTo(output, _repeated_coordinatedJobList_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (ServiceType.Length != 0) {
output.WriteRawTag(10);
output.WriteString(ServiceType);
}
if (ServiceLeader.Length != 0) {
output.WriteRawTag(18);
output.WriteString(ServiceLeader);
}
if (EnableHealthCheck != false) {
output.WriteRawTag(24);
output.WriteBool(EnableHealthCheck);
}
if (ClusterRegisterTimeoutInMs != 0L) {
output.WriteRawTag(32);
output.WriteInt64(ClusterRegisterTimeoutInMs);
}
if (HeartbeatTimeoutInMs != 0L) {
output.WriteRawTag(40);
output.WriteInt64(HeartbeatTimeoutInMs);
}
if (ShutdownBarrierTimeoutInMs != 0L) {
output.WriteRawTag(56);
output.WriteInt64(ShutdownBarrierTimeoutInMs);
}
if (AgentDestructionWithoutShutdown != false) {
output.WriteRawTag(64);
output.WriteBool(AgentDestructionWithoutShutdown);
}
recoverableJobs_.WriteTo(ref output, _repeated_recoverableJobs_codec);
coordinatedJobList_.WriteTo(ref output, _repeated_coordinatedJobList_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (ServiceType.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceType);
}
if (ServiceLeader.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceLeader);
}
if (EnableHealthCheck != false) {
size += 1 + 1;
}
if (ClusterRegisterTimeoutInMs != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClusterRegisterTimeoutInMs);
}
if (HeartbeatTimeoutInMs != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatTimeoutInMs);
}
size += coordinatedJobList_.CalculateSize(_repeated_coordinatedJobList_codec);
if (ShutdownBarrierTimeoutInMs != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownBarrierTimeoutInMs);
}
if (AgentDestructionWithoutShutdown != false) {
size += 1 + 1;
}
size += recoverableJobs_.CalculateSize(_repeated_recoverableJobs_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CoordinationServiceConfig other) {
if (other == null) {
return;
}
if (other.ServiceType.Length != 0) {
ServiceType = other.ServiceType;
}
if (other.ServiceLeader.Length != 0) {
ServiceLeader = other.ServiceLeader;
}
if (other.EnableHealthCheck != false) {
EnableHealthCheck = other.EnableHealthCheck;
}
if (other.ClusterRegisterTimeoutInMs != 0L) {
ClusterRegisterTimeoutInMs = other.ClusterRegisterTimeoutInMs;
}
if (other.HeartbeatTimeoutInMs != 0L) {
HeartbeatTimeoutInMs = other.HeartbeatTimeoutInMs;
}
coordinatedJobList_.Add(other.coordinatedJobList_);
if (other.ShutdownBarrierTimeoutInMs != 0L) {
ShutdownBarrierTimeoutInMs = other.ShutdownBarrierTimeoutInMs;
}
if (other.AgentDestructionWithoutShutdown != false) {
AgentDestructionWithoutShutdown = other.AgentDestructionWithoutShutdown;
}
recoverableJobs_.Add(other.recoverableJobs_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
ServiceType = input.ReadString();
break;
}
case 18: {
ServiceLeader = input.ReadString();
break;
}
case 24: {
EnableHealthCheck = input.ReadBool();
break;
}
case 32: {
ClusterRegisterTimeoutInMs = input.ReadInt64();
break;
}
case 40: {
HeartbeatTimeoutInMs = input.ReadInt64();
break;
}
case 56: {
ShutdownBarrierTimeoutInMs = input.ReadInt64();
break;
}
case 64: {
AgentDestructionWithoutShutdown = input.ReadBool();
break;
}
case 74: {
recoverableJobs_.AddEntriesFrom(input, _repeated_recoverableJobs_codec);
break;
}
case 82: {
coordinatedJobList_.AddEntriesFrom(input, _repeated_coordinatedJobList_codec);
break;
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
ServiceType = input.ReadString();
break;
}
case 18: {
ServiceLeader = input.ReadString();
break;
}
case 24: {
EnableHealthCheck = input.ReadBool();
break;
}
case 32: {
ClusterRegisterTimeoutInMs = input.ReadInt64();
break;
}
case 40: {
HeartbeatTimeoutInMs = input.ReadInt64();
break;
}
case 56: {
ShutdownBarrierTimeoutInMs = input.ReadInt64();
break;
}
case 64: {
AgentDestructionWithoutShutdown = input.ReadBool();
break;
}
case 74: {
recoverableJobs_.AddEntriesFrom(ref input, _repeated_recoverableJobs_codec);
break;
}
case 82: {
coordinatedJobList_.AddEntriesFrom(ref input, _repeated_coordinatedJobList_codec);
break;
}
}
}
}
#endif

}

#endregion

}

#endregion Designer generated code

+ 7964
- 0
src/TensorFlowNET.Core/Protobuf/CoordinationService.cs
File diff suppressed because it is too large
View File


+ 486
- 6
src/TensorFlowNET.Core/Protobuf/CostGraph.cs
File diff suppressed because it is too large
View File


+ 296
- 5
src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/python/framework/cpp_shape_inference.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -55,23 +55,31 @@ namespace Tensorflow {

}
#region Messages
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> {
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CppShapeInferenceResult> _parser = new pb::MessageParser<CppShapeInferenceResult>(() => new CppShapeInferenceResult());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CppShapeInferenceResult> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceResult() {
OnConstruction();
}
@@ -79,6 +87,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceResult(CppShapeInferenceResult other) : this() {
shape_ = other.shape_ != null ? other.shape_.Clone() : null;
handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null;
@@ -86,6 +95,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceResult Clone() {
return new CppShapeInferenceResult(this);
}
@@ -94,6 +104,7 @@ namespace Tensorflow {
public const int ShapeFieldNumber = 1;
private global::Tensorflow.TensorShapeProto shape_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.TensorShapeProto Shape {
get { return shape_; }
set {
@@ -105,6 +116,7 @@ namespace Tensorflow {
public const int HandleDataFieldNumber = 4;
private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData {
get { return handleData_; }
set {
@@ -113,11 +125,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CppShapeInferenceResult);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CppShapeInferenceResult other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -131,6 +145,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (shape_ != null) hash ^= Shape.GetHashCode();
@@ -142,12 +157,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (shape_ != null) {
output.WriteRawTag(10);
output.WriteMessage(Shape);
@@ -159,9 +179,29 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (shape_ != null) {
output.WriteRawTag(10);
output.WriteMessage(Shape);
}
if (handleData_ != null) {
output.WriteRawTag(34);
output.WriteMessage(HandleData);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (shape_ != null) {
@@ -177,6 +217,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CppShapeInferenceResult other) {
if (other == null) {
return;
@@ -197,7 +238,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -220,29 +265,68 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
if (shape_ == null) {
Shape = new global::Tensorflow.TensorShapeProto();
}
input.ReadMessage(Shape);
break;
}
case 34: {
if (handleData_ == null) {
HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData();
}
input.ReadMessage(HandleData);
break;
}
}
}
}
#endif

#region Nested types
/// <summary>Container for nested types declared in the CppShapeInferenceResult message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static partial class Types {
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> {
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<HandleShapeAndType> _parser = new pb::MessageParser<HandleShapeAndType>(() => new HandleShapeAndType());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<HandleShapeAndType> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleShapeAndType() {
OnConstruction();
}
@@ -250,6 +334,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleShapeAndType(HandleShapeAndType other) : this() {
shape_ = other.shape_ != null ? other.shape_.Clone() : null;
dtype_ = other.dtype_;
@@ -258,6 +343,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleShapeAndType Clone() {
return new HandleShapeAndType(this);
}
@@ -266,6 +352,7 @@ namespace Tensorflow {
public const int ShapeFieldNumber = 1;
private global::Tensorflow.TensorShapeProto shape_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.TensorShapeProto Shape {
get { return shape_; }
set {
@@ -277,6 +364,7 @@ namespace Tensorflow {
public const int DtypeFieldNumber = 2;
private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.DataType Dtype {
get { return dtype_; }
set {
@@ -288,6 +376,7 @@ namespace Tensorflow {
public const int TypeFieldNumber = 4;
private global::Tensorflow.FullTypeDef type_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.FullTypeDef Type {
get { return type_; }
set {
@@ -296,11 +385,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as HandleShapeAndType);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(HandleShapeAndType other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -315,6 +406,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (shape_ != null) hash ^= Shape.GetHashCode();
@@ -327,12 +419,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (shape_ != null) {
output.WriteRawTag(10);
output.WriteMessage(Shape);
@@ -348,9 +445,33 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (shape_ != null) {
output.WriteRawTag(10);
output.WriteMessage(Shape);
}
if (Dtype != global::Tensorflow.DataType.DtInvalid) {
output.WriteRawTag(16);
output.WriteEnum((int) Dtype);
}
if (type_ != null) {
output.WriteRawTag(34);
output.WriteMessage(Type);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (shape_ != null) {
@@ -369,6 +490,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(HandleShapeAndType other) {
if (other == null) {
return;
@@ -392,7 +514,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -419,27 +545,69 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
if (shape_ == null) {
Shape = new global::Tensorflow.TensorShapeProto();
}
input.ReadMessage(Shape);
break;
}
case 16: {
Dtype = (global::Tensorflow.DataType) input.ReadEnum();
break;
}
case 34: {
if (type_ == null) {
Type = new global::Tensorflow.FullTypeDef();
}
input.ReadMessage(Type);
break;
}
}
}
}
#endif

}

public sealed partial class HandleData : pb::IMessage<HandleData> {
public sealed partial class HandleData : pb::IMessage<HandleData>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<HandleData> _parser = new pb::MessageParser<HandleData>(() => new HandleData());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<HandleData> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleData() {
OnConstruction();
}
@@ -447,6 +615,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleData(HandleData other) : this() {
isSet_ = other.isSet_;
shapeAndType_ = other.shapeAndType_.Clone();
@@ -454,6 +623,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public HandleData Clone() {
return new HandleData(this);
}
@@ -462,6 +632,7 @@ namespace Tensorflow {
public const int IsSetFieldNumber = 1;
private bool isSet_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool IsSet {
get { return isSet_; }
set {
@@ -478,16 +649,19 @@ namespace Tensorflow {
/// Only valid if &lt;is_set>.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType {
get { return shapeAndType_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as HandleData);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(HandleData other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -501,6 +675,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (IsSet != false) hash ^= IsSet.GetHashCode();
@@ -512,12 +687,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (IsSet != false) {
output.WriteRawTag(8);
output.WriteBool(IsSet);
@@ -526,9 +706,26 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (IsSet != false) {
output.WriteRawTag(8);
output.WriteBool(IsSet);
}
shapeAndType_.WriteTo(ref output, _repeated_shapeAndType_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (IsSet != false) {
@@ -542,6 +739,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(HandleData other) {
if (other == null) {
return;
@@ -554,7 +752,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -571,8 +773,32 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 8: {
IsSet = input.ReadBool();
break;
}
case 18: {
shapeAndType_.AddEntriesFrom(ref input, _repeated_shapeAndType_codec);
break;
}
}
}
}
#endif

}

}
@@ -580,23 +806,31 @@ namespace Tensorflow {

}

public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> {
public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<CppShapeInferenceInputsNeeded> _parser = new pb::MessageParser<CppShapeInferenceInputsNeeded>(() => new CppShapeInferenceInputsNeeded());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<CppShapeInferenceInputsNeeded> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceInputsNeeded() {
OnConstruction();
}
@@ -604,6 +838,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() {
inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone();
inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone();
@@ -611,6 +846,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public CppShapeInferenceInputsNeeded Clone() {
return new CppShapeInferenceInputsNeeded(this);
}
@@ -621,6 +857,7 @@ namespace Tensorflow {
= pb::FieldCodec.ForInt32(10);
private readonly pbc::RepeatedField<int> inputTensorsNeeded_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<int> InputTensorsNeeded {
get { return inputTensorsNeeded_; }
}
@@ -631,16 +868,19 @@ namespace Tensorflow {
= pb::FieldCodec.ForInt32(18);
private readonly pbc::RepeatedField<int> inputTensorsAsShapesNeeded_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<int> InputTensorsAsShapesNeeded {
get { return inputTensorsAsShapesNeeded_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as CppShapeInferenceInputsNeeded);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(CppShapeInferenceInputsNeeded other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -654,6 +894,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= inputTensorsNeeded_.GetHashCode();
@@ -665,20 +906,39 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec);
inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
inputTensorsNeeded_.WriteTo(ref output, _repeated_inputTensorsNeeded_codec);
inputTensorsAsShapesNeeded_.WriteTo(ref output, _repeated_inputTensorsAsShapesNeeded_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec);
@@ -690,6 +950,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(CppShapeInferenceInputsNeeded other) {
if (other == null) {
return;
@@ -700,7 +961,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -719,7 +984,33 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10:
case 8: {
inputTensorsNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsNeeded_codec);
break;
}
case 18:
case 16: {
inputTensorsAsShapesNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsAsShapesNeeded_codec);
break;
}
}
}
}
#endif

}



+ 1041
- 0
src/TensorFlowNET.Core/Protobuf/DataService.cs
File diff suppressed because it is too large
View File


+ 320
- 5
src/TensorFlowNET.Core/Protobuf/Debug.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/protobuf/debug.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -55,23 +55,31 @@ namespace Tensorflow {
/// <summary>
/// Option for watching a node in TensorFlow Debugger (tfdbg).
/// </summary>
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> {
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DebugTensorWatch> _parser = new pb::MessageParser<DebugTensorWatch>(() => new DebugTensorWatch());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DebugTensorWatch> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugTensorWatch() {
OnConstruction();
}
@@ -79,6 +87,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugTensorWatch(DebugTensorWatch other) : this() {
nodeName_ = other.nodeName_;
outputSlot_ = other.outputSlot_;
@@ -89,6 +98,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugTensorWatch Clone() {
return new DebugTensorWatch(this);
}
@@ -102,6 +112,7 @@ namespace Tensorflow {
/// general.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string NodeName {
get { return nodeName_; }
set {
@@ -120,6 +131,7 @@ namespace Tensorflow {
/// errors currently.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int OutputSlot {
get { return outputSlot_; }
set {
@@ -138,6 +150,7 @@ namespace Tensorflow {
/// e.g., {"DebugIdentity", "DebugNanCount"}
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> DebugOps {
get { return debugOps_; }
}
@@ -170,6 +183,7 @@ namespace Tensorflow {
/// TODO(cais): More visible documentation of this in g3docs.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> DebugUrls {
get { return debugUrls_; }
}
@@ -182,6 +196,7 @@ namespace Tensorflow {
/// incompatibility). Instead, just log the failure.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool TolerateDebugOpCreationFailures {
get { return tolerateDebugOpCreationFailures_; }
set {
@@ -190,11 +205,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DebugTensorWatch);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DebugTensorWatch other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -211,6 +228,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (NodeName.Length != 0) hash ^= NodeName.GetHashCode();
@@ -225,12 +243,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (NodeName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(NodeName);
@@ -248,9 +271,35 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (NodeName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(NodeName);
}
if (OutputSlot != 0) {
output.WriteRawTag(16);
output.WriteInt32(OutputSlot);
}
debugOps_.WriteTo(ref output, _repeated_debugOps_codec);
debugUrls_.WriteTo(ref output, _repeated_debugUrls_codec);
if (TolerateDebugOpCreationFailures != false) {
output.WriteRawTag(40);
output.WriteBool(TolerateDebugOpCreationFailures);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (NodeName.Length != 0) {
@@ -271,6 +320,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DebugTensorWatch other) {
if (other == null) {
return;
@@ -290,7 +340,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -319,30 +373,74 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
NodeName = input.ReadString();
break;
}
case 16: {
OutputSlot = input.ReadInt32();
break;
}
case 26: {
debugOps_.AddEntriesFrom(ref input, _repeated_debugOps_codec);
break;
}
case 34: {
debugUrls_.AddEntriesFrom(ref input, _repeated_debugUrls_codec);
break;
}
case 40: {
TolerateDebugOpCreationFailures = input.ReadBool();
break;
}
}
}
}
#endif

}

/// <summary>
/// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg).
/// </summary>
public sealed partial class DebugOptions : pb::IMessage<DebugOptions> {
public sealed partial class DebugOptions : pb::IMessage<DebugOptions>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DebugOptions> _parser = new pb::MessageParser<DebugOptions>(() => new DebugOptions());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DebugOptions> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugOptions() {
OnConstruction();
}
@@ -350,6 +448,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugOptions(DebugOptions other) : this() {
debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone();
globalStep_ = other.globalStep_;
@@ -358,6 +457,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebugOptions Clone() {
return new DebugOptions(this);
}
@@ -371,6 +471,7 @@ namespace Tensorflow {
/// Debugging options
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.DebugTensorWatch> DebugTensorWatchOpts {
get { return debugTensorWatchOpts_; }
}
@@ -384,6 +485,7 @@ namespace Tensorflow {
/// step count.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long GlobalStep {
get { return globalStep_; }
set {
@@ -401,6 +503,7 @@ namespace Tensorflow {
/// are cleaned up from the disk after each Session.run.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool ResetDiskByteUsage {
get { return resetDiskByteUsage_; }
set {
@@ -409,11 +512,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DebugOptions);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DebugOptions other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -428,6 +533,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= debugTensorWatchOpts_.GetHashCode();
@@ -440,12 +546,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec);
if (GlobalStep != 0L) {
output.WriteRawTag(80);
@@ -458,9 +569,30 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
debugTensorWatchOpts_.WriteTo(ref output, _repeated_debugTensorWatchOpts_codec);
if (GlobalStep != 0L) {
output.WriteRawTag(80);
output.WriteInt64(GlobalStep);
}
if (ResetDiskByteUsage != false) {
output.WriteRawTag(88);
output.WriteBool(ResetDiskByteUsage);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec);
@@ -477,6 +609,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DebugOptions other) {
if (other == null) {
return;
@@ -492,7 +625,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -513,27 +650,63 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 34: {
debugTensorWatchOpts_.AddEntriesFrom(ref input, _repeated_debugTensorWatchOpts_codec);
break;
}
case 80: {
GlobalStep = input.ReadInt64();
break;
}
case 88: {
ResetDiskByteUsage = input.ReadBool();
break;
}
}
}
}
#endif

}

public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> {
public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DebuggedSourceFile> _parser = new pb::MessageParser<DebuggedSourceFile>(() => new DebuggedSourceFile());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DebuggedSourceFile> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFile() {
OnConstruction();
}
@@ -541,6 +714,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFile(DebuggedSourceFile other) : this() {
host_ = other.host_;
filePath_ = other.filePath_;
@@ -551,6 +725,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFile Clone() {
return new DebuggedSourceFile(this);
}
@@ -562,6 +737,7 @@ namespace Tensorflow {
/// The host name on which a source code file is located.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Host {
get { return host_; }
set {
@@ -576,6 +752,7 @@ namespace Tensorflow {
/// Path to the source code file.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string FilePath {
get { return filePath_; }
set {
@@ -590,6 +767,7 @@ namespace Tensorflow {
/// The timestamp at which the source code file is last modified.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long LastModified {
get { return lastModified_; }
set {
@@ -604,6 +782,7 @@ namespace Tensorflow {
/// Byte size of the file.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long Bytes {
get { return bytes_; }
set {
@@ -620,16 +799,19 @@ namespace Tensorflow {
/// Line-by-line content of the source code file.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<string> Lines {
get { return lines_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DebuggedSourceFile);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DebuggedSourceFile other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -646,6 +828,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (Host.Length != 0) hash ^= Host.GetHashCode();
@@ -660,12 +843,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (Host.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Host);
@@ -686,9 +874,38 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (Host.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Host);
}
if (FilePath.Length != 0) {
output.WriteRawTag(18);
output.WriteString(FilePath);
}
if (LastModified != 0L) {
output.WriteRawTag(24);
output.WriteInt64(LastModified);
}
if (Bytes != 0L) {
output.WriteRawTag(32);
output.WriteInt64(Bytes);
}
lines_.WriteTo(ref output, _repeated_lines_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (Host.Length != 0) {
@@ -711,6 +928,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DebuggedSourceFile other) {
if (other == null) {
return;
@@ -732,7 +950,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -761,27 +983,71 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
Host = input.ReadString();
break;
}
case 18: {
FilePath = input.ReadString();
break;
}
case 24: {
LastModified = input.ReadInt64();
break;
}
case 32: {
Bytes = input.ReadInt64();
break;
}
case 42: {
lines_.AddEntriesFrom(ref input, _repeated_lines_codec);
break;
}
}
}
}
#endif

}

public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> {
public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DebuggedSourceFiles> _parser = new pb::MessageParser<DebuggedSourceFiles>(() => new DebuggedSourceFiles());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DebuggedSourceFiles> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFiles() {
OnConstruction();
}
@@ -789,12 +1055,14 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFiles(DebuggedSourceFiles other) : this() {
sourceFiles_ = other.sourceFiles_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DebuggedSourceFiles Clone() {
return new DebuggedSourceFiles(this);
}
@@ -808,16 +1076,19 @@ namespace Tensorflow {
/// A collection of source code files.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.DebuggedSourceFile> SourceFiles {
get { return sourceFiles_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DebuggedSourceFiles);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DebuggedSourceFiles other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -830,6 +1101,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= sourceFiles_.GetHashCode();
@@ -840,19 +1112,37 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
sourceFiles_.WriteTo(ref output, _repeated_sourceFiles_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec);
@@ -863,6 +1153,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DebuggedSourceFiles other) {
if (other == null) {
return;
@@ -872,7 +1163,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -885,7 +1180,27 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
sourceFiles_.AddEntriesFrom(ref input, _repeated_sourceFiles_codec);
break;
}
}
}
}
#endif

}



+ 378
- 11
src/TensorFlowNET.Core/Protobuf/DeviceAttributes.cs View File

@@ -2,7 +2,7 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/framework/device_attributes.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
@@ -30,44 +30,53 @@ namespace Tensorflow {
"OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl",
"cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo",
"BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm",
"bG93LkxvY2FsTGlua3MirAEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB",
"bG93LkxvY2FsTGlua3MiwwEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB",
"IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB",
"KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs",
"aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k",
"ZXNjGAcgASgJQpEBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCFkRldmlj",
"ZUF0dHJpYnV0ZXNQcm90b3NQAVpYZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl",
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay9kZXZpY2Vf",
"YXR0cmlidXRlc19nb19wcm90b/gBAWIGcHJvdG8z"));
"ZXNjGAcgASgJEhUKDXhsYV9nbG9iYWxfaWQYCCABKANCkQEKGG9yZy50ZW5z",
"b3JmbG93LmZyYW1ld29ya0IWRGV2aWNlQXR0cmlidXRlc1Byb3Rvc1ABWlhn",
"aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv",
"L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVzX2dvX3Byb3Rv+AEB",
"YgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc" }, null, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc", "XlaGlobalId" }, null, null, null, null)
}));
}
#endregion

}
#region Messages
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> {
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<InterconnectLink> _parser = new pb::MessageParser<InterconnectLink>(() => new InterconnectLink());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<InterconnectLink> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public InterconnectLink() {
OnConstruction();
}
@@ -75,6 +84,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public InterconnectLink(InterconnectLink other) : this() {
deviceId_ = other.deviceId_;
type_ = other.type_;
@@ -83,6 +93,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public InterconnectLink Clone() {
return new InterconnectLink(this);
}
@@ -91,6 +102,7 @@ namespace Tensorflow {
public const int DeviceIdFieldNumber = 1;
private int deviceId_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int DeviceId {
get { return deviceId_; }
set {
@@ -102,6 +114,7 @@ namespace Tensorflow {
public const int TypeFieldNumber = 2;
private string type_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Type {
get { return type_; }
set {
@@ -113,6 +126,7 @@ namespace Tensorflow {
public const int StrengthFieldNumber = 3;
private int strength_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int Strength {
get { return strength_; }
set {
@@ -121,11 +135,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as InterconnectLink);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(InterconnectLink other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -140,6 +156,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (DeviceId != 0) hash ^= DeviceId.GetHashCode();
@@ -152,12 +169,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (DeviceId != 0) {
output.WriteRawTag(8);
output.WriteInt32(DeviceId);
@@ -173,9 +195,33 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (DeviceId != 0) {
output.WriteRawTag(8);
output.WriteInt32(DeviceId);
}
if (Type.Length != 0) {
output.WriteRawTag(18);
output.WriteString(Type);
}
if (Strength != 0) {
output.WriteRawTag(24);
output.WriteInt32(Strength);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (DeviceId != 0) {
@@ -194,6 +240,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(InterconnectLink other) {
if (other == null) {
return;
@@ -211,7 +258,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -232,27 +283,63 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 8: {
DeviceId = input.ReadInt32();
break;
}
case 18: {
Type = input.ReadString();
break;
}
case 24: {
Strength = input.ReadInt32();
break;
}
}
}
}
#endif

}

public sealed partial class LocalLinks : pb::IMessage<LocalLinks> {
public sealed partial class LocalLinks : pb::IMessage<LocalLinks>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<LocalLinks> _parser = new pb::MessageParser<LocalLinks>(() => new LocalLinks());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<LocalLinks> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public LocalLinks() {
OnConstruction();
}
@@ -260,12 +347,14 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public LocalLinks(LocalLinks other) : this() {
link_ = other.link_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public LocalLinks Clone() {
return new LocalLinks(this);
}
@@ -276,16 +365,19 @@ namespace Tensorflow {
= pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.InterconnectLink> link_ = new pbc::RepeatedField<global::Tensorflow.InterconnectLink>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public pbc::RepeatedField<global::Tensorflow.InterconnectLink> Link {
get { return link_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as LocalLinks);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(LocalLinks other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -298,6 +390,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
hash ^= link_.GetHashCode();
@@ -308,19 +401,37 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
link_.WriteTo(output, _repeated_link_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
link_.WriteTo(ref output, _repeated_link_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
size += link_.CalculateSize(_repeated_link_codec);
@@ -331,6 +442,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(LocalLinks other) {
if (other == null) {
return;
@@ -340,7 +452,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -353,27 +469,55 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
link_.AddEntriesFrom(ref input, _repeated_link_codec);
break;
}
}
}
}
#endif

}

public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> {
public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DeviceLocality> _parser = new pb::MessageParser<DeviceLocality>(() => new DeviceLocality());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DeviceLocality> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceLocality() {
OnConstruction();
}
@@ -381,6 +525,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceLocality(DeviceLocality other) : this() {
busId_ = other.busId_;
numaNode_ = other.numaNode_;
@@ -389,6 +534,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceLocality Clone() {
return new DeviceLocality(this);
}
@@ -401,6 +547,7 @@ namespace Tensorflow {
/// no specific locality. Specific localities are indexed from 1.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int BusId {
get { return busId_; }
set {
@@ -415,6 +562,7 @@ namespace Tensorflow {
/// Optional NUMA locality of device.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int NumaNode {
get { return numaNode_; }
set {
@@ -429,6 +577,7 @@ namespace Tensorflow {
/// Optional local interconnect links to other devices.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.LocalLinks Links {
get { return links_; }
set {
@@ -437,11 +586,13 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DeviceLocality);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DeviceLocality other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -456,6 +607,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (BusId != 0) hash ^= BusId.GetHashCode();
@@ -468,12 +620,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (BusId != 0) {
output.WriteRawTag(8);
output.WriteInt32(BusId);
@@ -489,9 +646,33 @@ namespace Tensorflow {
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (BusId != 0) {
output.WriteRawTag(8);
output.WriteInt32(BusId);
}
if (NumaNode != 0) {
output.WriteRawTag(16);
output.WriteInt32(NumaNode);
}
if (links_ != null) {
output.WriteRawTag(26);
output.WriteMessage(Links);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (BusId != 0) {
@@ -510,6 +691,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DeviceLocality other) {
if (other == null) {
return;
@@ -530,7 +712,11 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -554,27 +740,66 @@ namespace Tensorflow {
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 8: {
BusId = input.ReadInt32();
break;
}
case 16: {
NumaNode = input.ReadInt32();
break;
}
case 26: {
if (links_ == null) {
Links = new global::Tensorflow.LocalLinks();
}
input.ReadMessage(Links);
break;
}
}
}
}
#endif

}

public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> {
public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<DeviceAttributes> _parser = new pb::MessageParser<DeviceAttributes>(() => new DeviceAttributes());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<DeviceAttributes> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceAttributes() {
OnConstruction();
}
@@ -582,6 +807,7 @@ namespace Tensorflow {
partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceAttributes(DeviceAttributes other) : this() {
name_ = other.name_;
deviceType_ = other.deviceType_;
@@ -589,10 +815,12 @@ namespace Tensorflow {
locality_ = other.locality_ != null ? other.locality_.Clone() : null;
incarnation_ = other.incarnation_;
physicalDeviceDesc_ = other.physicalDeviceDesc_;
xlaGlobalId_ = other.xlaGlobalId_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public DeviceAttributes Clone() {
return new DeviceAttributes(this);
}
@@ -604,6 +832,7 @@ namespace Tensorflow {
/// Fully specified name of the device within a cluster.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string Name {
get { return name_; }
set {
@@ -618,6 +847,7 @@ namespace Tensorflow {
/// String representation of device_type.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string DeviceType {
get { return deviceType_; }
set {
@@ -632,6 +862,7 @@ namespace Tensorflow {
/// Memory capacity of device in bytes.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long MemoryLimit {
get { return memoryLimit_; }
set {
@@ -647,6 +878,7 @@ namespace Tensorflow {
/// for supporting efficient data transfers.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Tensorflow.DeviceLocality Locality {
get { return locality_; }
set {
@@ -662,6 +894,7 @@ namespace Tensorflow {
/// initialized. "incarnation" should never be 0.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public ulong Incarnation {
get { return incarnation_; }
set {
@@ -676,6 +909,7 @@ namespace Tensorflow {
/// String representation of the physical device that this device maps to.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public string PhysicalDeviceDesc {
get { return physicalDeviceDesc_; }
set {
@@ -683,12 +917,31 @@ namespace Tensorflow {
}
}

/// <summary>Field number for the "xla_global_id" field.</summary>
public const int XlaGlobalIdFieldNumber = 8;
private long xlaGlobalId_;
/// <summary>
/// A physical device ID for use in XLA DeviceAssignments, unique across
/// clients in a multi-client setup. Set to -1 if unavailable, non-negative
/// otherwise.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public long XlaGlobalId {
get { return xlaGlobalId_; }
set {
xlaGlobalId_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as DeviceAttributes);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(DeviceAttributes other) {
if (ReferenceEquals(other, null)) {
return false;
@@ -702,10 +955,12 @@ namespace Tensorflow {
if (!object.Equals(Locality, other.Locality)) return false;
if (Incarnation != other.Incarnation) return false;
if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false;
if (XlaGlobalId != other.XlaGlobalId) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (Name.Length != 0) hash ^= Name.GetHashCode();
@@ -714,6 +969,7 @@ namespace Tensorflow {
if (locality_ != null) hash ^= Locality.GetHashCode();
if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode();
if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode();
if (XlaGlobalId != 0L) hash ^= XlaGlobalId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -721,12 +977,17 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
@@ -751,12 +1012,56 @@ namespace Tensorflow {
output.WriteRawTag(58);
output.WriteString(PhysicalDeviceDesc);
}
if (XlaGlobalId != 0L) {
output.WriteRawTag(64);
output.WriteInt64(XlaGlobalId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (Name.Length != 0) {
output.WriteRawTag(10);
output.WriteString(Name);
}
if (DeviceType.Length != 0) {
output.WriteRawTag(18);
output.WriteString(DeviceType);
}
if (MemoryLimit != 0L) {
output.WriteRawTag(32);
output.WriteInt64(MemoryLimit);
}
if (locality_ != null) {
output.WriteRawTag(42);
output.WriteMessage(Locality);
}
if (Incarnation != 0UL) {
output.WriteRawTag(49);
output.WriteFixed64(Incarnation);
}
if (PhysicalDeviceDesc.Length != 0) {
output.WriteRawTag(58);
output.WriteString(PhysicalDeviceDesc);
}
if (XlaGlobalId != 0L) {
output.WriteRawTag(64);
output.WriteInt64(XlaGlobalId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (Name.Length != 0) {
@@ -777,6 +1082,9 @@ namespace Tensorflow {
if (PhysicalDeviceDesc.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc);
}
if (XlaGlobalId != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaGlobalId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -784,6 +1092,7 @@ namespace Tensorflow {
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(DeviceAttributes other) {
if (other == null) {
return;
@@ -809,11 +1118,18 @@ namespace Tensorflow {
if (other.PhysicalDeviceDesc.Length != 0) {
PhysicalDeviceDesc = other.PhysicalDeviceDesc;
}
if (other.XlaGlobalId != 0L) {
XlaGlobalId = other.XlaGlobalId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
@@ -847,9 +1163,60 @@ namespace Tensorflow {
PhysicalDeviceDesc = input.ReadString();
break;
}
case 64: {
XlaGlobalId = input.ReadInt64();
break;
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
Name = input.ReadString();
break;
}
case 18: {
DeviceType = input.ReadString();
break;
}
case 32: {
MemoryLimit = input.ReadInt64();
break;
}
case 42: {
if (locality_ == null) {
Locality = new global::Tensorflow.DeviceLocality();
}
input.ReadMessage(Locality);
break;
}
case 49: {
Incarnation = input.ReadFixed64();
break;
}
case 58: {
PhysicalDeviceDesc = input.ReadString();
break;
}
case 64: {
XlaGlobalId = input.ReadInt64();
break;
}
}
}
}
#endif

}



+ 660
- 9
src/TensorFlowNET.Core/Protobuf/Event.cs
File diff suppressed because it is too large
View File


+ 340
- 0
src/TensorFlowNET.Core/Protobuf/Executable.cs View File

@@ -0,0 +1,340 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/compiler/xla/service/cpu/executable.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021, 8981
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Xla.Cpu {

/// <summary>Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/executable.proto</summary>
public static partial class ExecutableReflection {

#region Descriptor
/// <summary>File descriptor for tensorflow/compiler/xla/service/cpu/executable.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static ExecutableReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjR0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS9leGVjdXRh",
"YmxlLnByb3RvEgd4bGEuY3B1Gjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9z",
"ZXJ2aWNlL2NwdS94bGFfZnJhbWV3b3JrLnByb3RvGil0ZW5zb3JmbG93L2Nv",
"bXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90byLXAQocWGxhUnVudGltZUNw",
"dUV4ZWN1dGFibGVQcm90bxI+ChZ4bGFfcnVudGltZV9leGVjdXRhYmxlGAEg",
"ASgLMh4ueGxhLlhsYVJ1bnRpbWVFeGVjdXRhYmxlUHJvdG8SQAoVeGxhX2Zy",
"YW1ld29ya19tYXBwaW5nGAIgASgLMiEueGxhLmNwdS5YbGFGcmFtZXdvcmtN",
"YXBwaW5nUHJvdG8SNQoRYnVmZmVyX2Fzc2lnbm1lbnQYAyABKAsyGi54bGEu",
"QnVmZmVyQXNzaWdubWVudFByb3Rv"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Xla.Cpu.XlaFrameworkReflection.Descriptor, global::Xla.HloReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaRuntimeCpuExecutableProto), global::Xla.Cpu.XlaRuntimeCpuExecutableProto.Parser, new[]{ "XlaRuntimeExecutable", "XlaFrameworkMapping", "BufferAssignment" }, null, null, null, null)
}));
}
#endregion

}
#region Messages
public sealed partial class XlaRuntimeCpuExecutableProto : pb::IMessage<XlaRuntimeCpuExecutableProto>
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
, pb::IBufferMessage
#endif
{
private static readonly pb::MessageParser<XlaRuntimeCpuExecutableProto> _parser = new pb::MessageParser<XlaRuntimeCpuExecutableProto>(() => new XlaRuntimeCpuExecutableProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pb::MessageParser<XlaRuntimeCpuExecutableProto> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public static pbr::MessageDescriptor Descriptor {
get { return global::Xla.Cpu.ExecutableReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public XlaRuntimeCpuExecutableProto() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public XlaRuntimeCpuExecutableProto(XlaRuntimeCpuExecutableProto other) : this() {
xlaRuntimeExecutable_ = other.xlaRuntimeExecutable_ != null ? other.xlaRuntimeExecutable_.Clone() : null;
xlaFrameworkMapping_ = other.xlaFrameworkMapping_ != null ? other.xlaFrameworkMapping_.Clone() : null;
bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public XlaRuntimeCpuExecutableProto Clone() {
return new XlaRuntimeCpuExecutableProto(this);
}

/// <summary>Field number for the "xla_runtime_executable" field.</summary>
public const int XlaRuntimeExecutableFieldNumber = 1;
private global::Xla.XlaRuntimeExecutableProto xlaRuntimeExecutable_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Xla.XlaRuntimeExecutableProto XlaRuntimeExecutable {
get { return xlaRuntimeExecutable_; }
set {
xlaRuntimeExecutable_ = value;
}
}

/// <summary>Field number for the "xla_framework_mapping" field.</summary>
public const int XlaFrameworkMappingFieldNumber = 2;
private global::Xla.Cpu.XlaFrameworkMappingProto xlaFrameworkMapping_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Xla.Cpu.XlaFrameworkMappingProto XlaFrameworkMapping {
get { return xlaFrameworkMapping_; }
set {
xlaFrameworkMapping_ = value;
}
}

/// <summary>Field number for the "buffer_assignment" field.</summary>
public const int BufferAssignmentFieldNumber = 3;
private global::Xla.BufferAssignmentProto bufferAssignment_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public global::Xla.BufferAssignmentProto BufferAssignment {
get { return bufferAssignment_; }
set {
bufferAssignment_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override bool Equals(object other) {
return Equals(other as XlaRuntimeCpuExecutableProto);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public bool Equals(XlaRuntimeCpuExecutableProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (!object.Equals(XlaRuntimeExecutable, other.XlaRuntimeExecutable)) return false;
if (!object.Equals(XlaFrameworkMapping, other.XlaFrameworkMapping)) return false;
if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override int GetHashCode() {
int hash = 1;
if (xlaRuntimeExecutable_ != null) hash ^= XlaRuntimeExecutable.GetHashCode();
if (xlaFrameworkMapping_ != null) hash ^= XlaFrameworkMapping.GetHashCode();
if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void WriteTo(pb::CodedOutputStream output) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
output.WriteRawMessage(this);
#else
if (xlaRuntimeExecutable_ != null) {
output.WriteRawTag(10);
output.WriteMessage(XlaRuntimeExecutable);
}
if (xlaFrameworkMapping_ != null) {
output.WriteRawTag(18);
output.WriteMessage(XlaFrameworkMapping);
}
if (bufferAssignment_ != null) {
output.WriteRawTag(26);
output.WriteMessage(BufferAssignment);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) {
if (xlaRuntimeExecutable_ != null) {
output.WriteRawTag(10);
output.WriteMessage(XlaRuntimeExecutable);
}
if (xlaFrameworkMapping_ != null) {
output.WriteRawTag(18);
output.WriteMessage(XlaFrameworkMapping);
}
if (bufferAssignment_ != null) {
output.WriteRawTag(26);
output.WriteMessage(BufferAssignment);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(ref output);
}
}
#endif

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public int CalculateSize() {
int size = 0;
if (xlaRuntimeExecutable_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaRuntimeExecutable);
}
if (xlaFrameworkMapping_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaFrameworkMapping);
}
if (bufferAssignment_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(XlaRuntimeCpuExecutableProto other) {
if (other == null) {
return;
}
if (other.xlaRuntimeExecutable_ != null) {
if (xlaRuntimeExecutable_ == null) {
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto();
}
XlaRuntimeExecutable.MergeFrom(other.XlaRuntimeExecutable);
}
if (other.xlaFrameworkMapping_ != null) {
if (xlaFrameworkMapping_ == null) {
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto();
}
XlaFrameworkMapping.MergeFrom(other.XlaFrameworkMapping);
}
if (other.bufferAssignment_ != null) {
if (bufferAssignment_ == null) {
BufferAssignment = new global::Xla.BufferAssignmentProto();
}
BufferAssignment.MergeFrom(other.BufferAssignment);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
public void MergeFrom(pb::CodedInputStream input) {
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
input.ReadRawMessage(this);
#else
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
if (xlaRuntimeExecutable_ == null) {
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto();
}
input.ReadMessage(XlaRuntimeExecutable);
break;
}
case 18: {
if (xlaFrameworkMapping_ == null) {
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto();
}
input.ReadMessage(XlaFrameworkMapping);
break;
}
case 26: {
if (bufferAssignment_ == null) {
BufferAssignment = new global::Xla.BufferAssignmentProto();
}
input.ReadMessage(BufferAssignment);
break;
}
}
}
#endif
}

#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)]
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input);
break;
case 10: {
if (xlaRuntimeExecutable_ == null) {
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto();
}
input.ReadMessage(XlaRuntimeExecutable);
break;
}
case 18: {
if (xlaFrameworkMapping_ == null) {
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto();
}
input.ReadMessage(XlaFrameworkMapping);
break;
}
case 26: {
if (bufferAssignment_ == null) {
BufferAssignment = new global::Xla.BufferAssignmentProto();
}
input.ReadMessage(BufferAssignment);
break;
}
}
}
}
#endif

}

#endregion

}

#endregion Designer generated code

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save