@@ -1,7 +1,7 @@ | |||||
| | ||||
Microsoft Visual Studio Solution File, Format Version 12.00 | Microsoft Visual Studio Solution File, Format Version 12.00 | ||||
# Visual Studio Version 16 | |||||
VisualStudioVersion = 16.0.31624.102 | |||||
# Visual Studio Version 17 | |||||
VisualStudioVersion = 17.4.33213.308 | |||||
MinimumVisualStudioVersion = 10.0.40219.1 | MinimumVisualStudioVersion = 10.0.40219.1 | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | ||||
EndProject | EndProject | ||||
@@ -0,0 +1,17 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class c_api | |||||
{ | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||||
} | |||||
} |
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -45,6 +46,23 @@ namespace Tensorflow | |||||
{ | { | ||||
return as_text(bytes_or_text, encoding); | return as_text(bytes_or_text, encoding); | ||||
} | } | ||||
public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||||
{ | |||||
return bytes; | |||||
} | |||||
public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||||
{ | |||||
return ByteString.CopyFrom(bytes); | |||||
} | |||||
public ByteString as_bytes(string text, Encoding encoding = null) | |||||
{ | |||||
if(encoding is null) | |||||
{ | |||||
encoding = Encoding.UTF8; | |||||
} | |||||
return ByteString.CopyFrom(encoding.GetBytes(text)); | |||||
} | |||||
} | } | ||||
public bool executing_eagerly() | public bool executing_eagerly() | ||||
@@ -54,6 +54,6 @@ namespace Tensorflow | |||||
Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
string[] return_elements = null, | string[] return_elements = null, | ||||
string name = null, | string name = null, | ||||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||||
} | } | ||||
} | } |
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
@@ -79,5 +81,10 @@ namespace Tensorflow | |||||
num_split: num_split, | num_split: num_split, | ||||
axis: axis, | axis: axis, | ||||
name: name); | name: name); | ||||
public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||||
{ | |||||
return gen_ops.ensure_shape(x, shape, name); | |||||
} | |||||
} | } | ||||
} | } |
@@ -61,7 +61,7 @@ namespace Tensorflow | |||||
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Set `num_dims` to -1 to represent "unknown rank". | /// Set `num_dims` to -1 to represent "unknown rank". | ||||
@@ -22,6 +22,7 @@ using System.ComponentModel; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -107,6 +107,12 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public void Release() | |||||
{ | |||||
_handle.Dispose(); | |||||
_handle = null; | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
=> $"0x{_handle.DangerousGetHandle():x16}"; | => $"0x{_handle.DangerousGetHandle():x16}"; | ||||
@@ -25,5 +25,32 @@ namespace Tensorflow | |||||
public IntPtr data; | public IntPtr data; | ||||
public ulong length; | public ulong length; | ||||
public IntPtr data_deallocator; | public IntPtr data_deallocator; | ||||
public unsafe Span<T> AsSpan<T>() where T: unmanaged | |||||
{ | |||||
if(length > int.MaxValue) | |||||
{ | |||||
throw new ValueError($"The length {length} is too large to use in the span."); | |||||
} | |||||
return new Span<T>(data.ToPointer(), (int)length); | |||||
} | |||||
public unsafe byte[] ToByteArray() | |||||
{ | |||||
byte[] res = new byte[length]; | |||||
if(length > int.MaxValue) | |||||
{ | |||||
byte* root = (byte*)data; | |||||
for(ulong i = 0; i < length; i++) | |||||
{ | |||||
res[i] = *(root++); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); | |||||
} | |||||
return res; | |||||
} | |||||
} | } | ||||
} | } |
@@ -161,7 +161,7 @@ public static class CheckPointUtils | |||||
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | ||||
{ | { | ||||
return full_list.TakeWhile(x => | |||||
return full_list.Where(x => | |||||
{ | { | ||||
var saveables = x.gather_saveables_for_checkpoint(); | var saveables = x.gather_saveables_for_checkpoint(); | ||||
return saveables is not null && saveables.Count > 0; | return saveables is not null && saveables.Count > 0; | ||||
@@ -1,10 +1,12 @@ | |||||
using System; | |||||
using OneOf; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using Tensorflow.Common.Extensions; | |||||
using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
@@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||||
); | ); | ||||
public static class SaveUtil | public static class SaveUtil | ||||
{ | { | ||||
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | ||||
{ | { | ||||
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | ||||
@@ -104,7 +106,10 @@ namespace Tensorflow.Checkpoint | |||||
{ | { | ||||
var td = trackable_data[i]; | var td = trackable_data[i]; | ||||
Debug.Assert(td.node_id == i); | Debug.Assert(td.node_id == i); | ||||
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); | |||||
TrackableObjectGraph.Types.TrackableObject trackable_object = new(); | |||||
trackable_object.SlotVariables.AddRange(td.slot_variable_proto); | |||||
trackable_object.Children.AddRange(td.children_proto); | |||||
object_graph_proto.Nodes.Add(trackable_object); | |||||
} | } | ||||
return object_graph_proto; | return object_graph_proto; | ||||
} | } | ||||
@@ -117,16 +122,16 @@ namespace Tensorflow.Checkpoint | |||||
/// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
/// <param name="cache"></param> | /// <param name="cache"></param> | ||||
/// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | ||||
{ | { | ||||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||||
foreach(var td in tensor_trackables) | foreach(var td in tensor_trackables) | ||||
{ | { | ||||
// TODO: deal with cache. | // TODO: deal with cache. | ||||
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | ||||
Trackable trackable = null; | Trackable trackable = null; | ||||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||||
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | ||||
{ | { | ||||
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | ||||
@@ -148,12 +153,12 @@ namespace Tensorflow.Checkpoint | |||||
return serialized_tensors; | return serialized_tensors; | ||||
} | } | ||||
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
{ | { | ||||
var trackable = trackable_data.object_to_save; | var trackable = trackable_data.object_to_save; | ||||
// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | ||||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||||
if (call_with_mapped_captures) | if (call_with_mapped_captures) | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -163,8 +168,7 @@ namespace Tensorflow.Checkpoint | |||||
ret_tensor_dict = trackable.serialize_to_tensors(); | ret_tensor_dict = trackable.serialize_to_tensors(); | ||||
} | } | ||||
// TODO: deal with the type `SaveSpce` (currently it will never be it). | |||||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||||
foreach(var pair in ret_tensor_dict) | foreach(var pair in ret_tensor_dict) | ||||
{ | { | ||||
var local_name = TrackableUtils.escape_local_name(pair.Key); | var local_name = TrackableUtils.escape_local_name(pair.Key); | ||||
@@ -173,10 +177,12 @@ namespace Tensorflow.Checkpoint | |||||
tensor_dict[checkpoint_key] = maybe_tensor; | tensor_dict[checkpoint_key] = maybe_tensor; | ||||
if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||||
foreach(var key in maybe_tensor.Keys) | |||||
{ | { | ||||
throw new NotImplementedException(); | |||||
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||||
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||||
{ | |||||
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||||
} | |||||
} | } | ||||
if(object_graph_proto is not null) | if(object_graph_proto is not null) | ||||
@@ -200,7 +206,7 @@ namespace Tensorflow.Checkpoint | |||||
/// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
/// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | ||||
{ | { | ||||
Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
@@ -8,6 +8,7 @@ using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using OneOf; | |||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
@@ -114,14 +115,10 @@ public static class SaveUtilV1 | |||||
{ | { | ||||
var trackable = trackable_objects[i]; | var trackable = trackable_objects[i]; | ||||
Debug.Assert(node_ids[trackable] == i); | Debug.Assert(node_ids[trackable] == i); | ||||
TrackableObjectGraph.Types.TrackableObject object_proto; | |||||
var object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||||
if (slot_variables.TryGetValue(trackable, out var slots)) | if (slot_variables.TryGetValue(trackable, out var slots)) | ||||
{ | { | ||||
object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); | |||||
} | |||||
else | |||||
{ | |||||
object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||||
object_proto.SlotVariables.AddRange(slots); | |||||
} | } | ||||
object_graph_proto.Nodes.Add(object_proto); | object_graph_proto.Nodes.Add(object_proto); | ||||
foreach (var child in graph_view.list_children(trackable)) | foreach (var child in graph_view.list_children(trackable)) | ||||
@@ -184,13 +181,13 @@ public static class SaveUtilV1 | |||||
// TODO: tensorflow python has a process with callable `saveable_factory`. | // TODO: tensorflow python has a process with callable `saveable_factory`. | ||||
List<MySaveableObject> saveables = new(); | List<MySaveableObject> saveables = new(); | ||||
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||||
if (maybe_saveable.TryPickT1(out var s, out var variable)) | |||||
{ | { | ||||
saveables.Add(s); | saveables.Add(s); | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); | |||||
} | } | ||||
foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
@@ -222,7 +219,7 @@ public static class SaveUtilV1 | |||||
public record class CheckpointFactoryData | public record class CheckpointFactoryData | ||||
( | ( | ||||
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||||
Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||||
string name, | string name, | ||||
string checkpoint_key | string checkpoint_key | ||||
); | ); |
@@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using OneOf; | |||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
@@ -44,12 +45,12 @@ public class TrackableSaver | |||||
_graph_view = graph_view; | _graph_view = graph_view; | ||||
// TODO: cache when not executing eagerly. | // TODO: cache when not executing eagerly. | ||||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||||
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | ||||
} | } | ||||
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
gather_serialized_tensors(Tensor? object_graph_tensor = null) | gather_serialized_tensors(Tensor? object_graph_tensor = null) | ||||
{ | { | ||||
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | ||||
@@ -70,9 +71,10 @@ public class TrackableSaver | |||||
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | ||||
if (!serialized_tensors.ContainsKey(Trackable.None)) | if (!serialized_tensors.ContainsKey(Trackable.None)) | ||||
{ | { | ||||
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||||
serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||||
} | } | ||||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||||
return (serialized_tensors, feed_additions, registered_savers, graph_proto); | return (serialized_tensors, feed_additions, registered_savers, graph_proto); | ||||
} | } | ||||
@@ -392,6 +394,7 @@ public class CheckpointRestoreCoordinator | |||||
/// </summary> | /// </summary> | ||||
public List<Trackable> AllTrackables => _all_trackables; | public List<Trackable> AllTrackables => _all_trackables; | ||||
public HashSet<int> MatchedProtoIds => _matched_proto_ids; | public HashSet<int> MatchedProtoIds => _matched_proto_ids; | ||||
// TODO(Rinne): change to weak ref. | |||||
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | ||||
public int RestoreUid => _restore_uid; | public int RestoreUid => _restore_uid; | ||||
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | ||||
@@ -406,7 +409,7 @@ public class CheckpointRestoreCoordinator | |||||
// skip the callback. | // skip the callback. | ||||
} | } | ||||
public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||||
public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||||
{ | { | ||||
List<Operation> restore_ops = new(); | List<Operation> restore_ops = new(); | ||||
foreach(var position in positions) | foreach(var position in positions) | ||||
@@ -418,7 +421,7 @@ public class CheckpointRestoreCoordinator | |||||
Dictionary<string, BaseResourceVariable> variable_dict = new(); | Dictionary<string, BaseResourceVariable> variable_dict = new(); | ||||
foreach(var item in tensor_saveables) | foreach(var item in tensor_saveables) | ||||
{ | { | ||||
if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||||
if(item.Value.TryPickT0(out var variable, out var _)) | |||||
{ | { | ||||
variable_dict[item.Key] = variable; | variable_dict[item.Key] = variable; | ||||
} | } | ||||
@@ -15,106 +15,14 @@ using Tensorflow.Graphs; | |||||
using System.Xml.Linq; | using System.Xml.Linq; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using RestoreFunc = System.Func<object, object>; | using RestoreFunc = System.Func<object, object>; | ||||
using OneOf; | |||||
namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
{ | { | ||||
public class Maybe<TA, TB> | |||||
{ | |||||
private TA? _valueA = default(TA); | |||||
private TB? _valueB = default(TB); | |||||
private Type _type; | |||||
private bool _assignedTA; | |||||
public Maybe(TA value) | |||||
{ | |||||
_valueA = value; | |||||
_type= typeof(TA); | |||||
_assignedTA = true; | |||||
} | |||||
public Maybe(TB value) | |||||
{ | |||||
_valueB = value; | |||||
_type = typeof(TB); | |||||
_assignedTA = false; | |||||
} | |||||
public Type DataType => _type; | |||||
/// <summary> | |||||
/// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||||
/// It returns | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="res"></param> | |||||
/// <returns></returns> | |||||
public bool TryGet<T>(out T? res) | |||||
{ | |||||
if(_valueA is T && _valueB is not T) | |||||
{ | |||||
res = (T)(object)_valueA; | |||||
return _assignedTA; | |||||
} | |||||
else if(_valueA is not T && _valueB is T) | |||||
{ | |||||
res = (T)(object)_valueB; | |||||
return !_assignedTA; | |||||
} | |||||
res = default(T); | |||||
return false; | |||||
} | |||||
public bool IsTypeOrDeriveFrom<T>() | |||||
{ | |||||
if (_valueA is T && _valueB is not T) | |||||
{ | |||||
return _assignedTA; | |||||
} | |||||
else if (_valueA is not T && _valueB is T) | |||||
{ | |||||
return !_assignedTA; | |||||
} | |||||
else if (_valueA is T && _valueB is T) | |||||
{ | |||||
return true; | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
public T GetValue<T>() | |||||
{ | |||||
if (_valueA is T && _valueB is not T) | |||||
{ | |||||
return (T)(object)_valueA; | |||||
} | |||||
else if (_valueA is not T && _valueB is T) | |||||
{ | |||||
return (T)(object)_valueB; | |||||
} | |||||
else if (_valueA is T && _valueB is T) | |||||
{ | |||||
throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||||
} | |||||
} | |||||
public static implicit operator Maybe<TA, TB>(TA a) | |||||
{ | |||||
return new Maybe<TA, TB>(a); | |||||
} | |||||
public static implicit operator Maybe<TA, TB>(TB b) | |||||
{ | |||||
return new Maybe<TA, TB>(b); | |||||
} | |||||
} | |||||
internal class SingleDeviceSaver | internal class SingleDeviceSaver | ||||
{ | { | ||||
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||||
private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict; | |||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict) | |||||
{ | { | ||||
_tensor_slice_dict = tensor_slice_dict; | _tensor_slice_dict = tensor_slice_dict; | ||||
} | } | ||||
@@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint | |||||
{ | { | ||||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | _tensor_slice_dict = tensor_slice_dict.ToDictionary( | ||||
x => x.Key, x => x.Value.ToDictionary( | x => x.Key, x => x.Value.ToDictionary( | ||||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value)) | |||||
as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||||
} | } | ||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | ||||
{ | { | ||||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | _tensor_slice_dict = tensor_slice_dict.ToDictionary( | ||||
x => x.Key, x => x.Value.ToDictionary( | x => x.Key, x => x.Value.ToDictionary( | ||||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value)) | |||||
as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||||
} | } | ||||
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | ||||
{ | { | ||||
@@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint | |||||
{ | { | ||||
var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||||
{ | { | ||||
var tensor_value = spec.tensor; | var tensor_value = spec.tensor; | ||||
if (tensor_value is not null) | if (tensor_value is not null) | ||||
@@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
tensors.Add(tensor); | tensors.Add(tensor); | ||||
slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
@@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint | |||||
var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
// TODO: deal with other types. Currently only `SaveSpec` is allowed. | // TODO: deal with other types. Currently only `SaveSpec` is allowed. | ||||
if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||||
{ | { | ||||
tensor_dtypes.Add(spec.dtype); | tensor_dtypes.Add(spec.dtype); | ||||
slice_specs.Add(spec.slice_spec); | slice_specs.Add(spec.slice_spec); | ||||
@@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
tensor_dtypes.Add(tensor.dtype); | tensor_dtypes.Add(tensor.dtype); | ||||
slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
@@ -256,12 +162,12 @@ namespace Tensorflow.Checkpoint | |||||
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | ||||
/// <param name="registered_savers"></param> | /// <param name="registered_savers"></param> | ||||
/// <param name="call_with_mapped_capture"></param> | /// <param name="call_with_mapped_capture"></param> | ||||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||||
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | ||||
{ | { | ||||
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | ||||
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | ||||
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||||
Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||||
foreach(var pair in serialized_tensors) | foreach(var pair in serialized_tensors) | ||||
{ | { | ||||
@@ -276,9 +182,9 @@ namespace Tensorflow.Checkpoint | |||||
{ | { | ||||
restore_fn = new RestoreFunc(x => | restore_fn = new RestoreFunc(x => | ||||
{ | { | ||||
if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||||
if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) | |||||
{ | { | ||||
return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||||
return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>); | |||||
} | } | ||||
throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | ||||
}); | }); | ||||
@@ -287,16 +193,7 @@ namespace Tensorflow.Checkpoint | |||||
foreach(var item in tensor_dict) | foreach(var item in tensor_dict) | ||||
{ | { | ||||
var checkpoint_key = item.Key; | var checkpoint_key = item.Key; | ||||
IDictionary<string, Tensor> spec_to_tensor; | |||||
if(item.Value.TryGet<Tensor>(out var t)) | |||||
{ | |||||
spec_to_tensor = new Dictionary<string, Tensor>(); | |||||
spec_to_tensor[""] = t; | |||||
} | |||||
else | |||||
{ | |||||
spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||||
} | |||||
var spec_to_tensor = item.Value; | |||||
foreach(var spec in spec_to_tensor) | foreach(var spec in spec_to_tensor) | ||||
{ | { | ||||
@@ -311,12 +208,20 @@ namespace Tensorflow.Checkpoint | |||||
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | ||||
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | ||||
// skip the process of device name because lack of API. | |||||
var host_device = tensor.Device; | |||||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||||
string host_device; | |||||
if (tensor.IsT0) | |||||
{ | |||||
host_device = tensor.AsT0.Device; | |||||
} | |||||
else | |||||
{ | |||||
host_device = tensor.AsT1.device; | |||||
} | |||||
host_device = saveable_object_util.set_cpu0(host_device); | |||||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||||
if (!internal_dict.ContainsKey(checkpoint_key)) | if (!internal_dict.ContainsKey(checkpoint_key)) | ||||
{ | { | ||||
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||||
internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||||
} | } | ||||
internal_dict[checkpoint_key][slice_spec] = tensor; | internal_dict[checkpoint_key][slice_spec] = tensor; | ||||
} | } | ||||
@@ -412,7 +317,7 @@ namespace Tensorflow.Checkpoint | |||||
IDictionary<string, Operation> restore_func() | IDictionary<string, Operation> restore_func() | ||||
{ | { | ||||
Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | ||||
Dictionary<string, Operation> restore_ops = new(); | Dictionary<string, Operation> restore_ops = new(); | ||||
@@ -433,29 +338,29 @@ namespace Tensorflow.Checkpoint | |||||
var slice_spec = item.Key; | var slice_spec = item.Key; | ||||
var tensor = item.Value; | var tensor = item.Value; | ||||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | ||||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||||
if (!string.IsNullOrEmpty(slice_spec)) | if (!string.IsNullOrEmpty(slice_spec)) | ||||
{ | { | ||||
if (!internal_dict.ContainsKey(checkpoint_key)) | if (!internal_dict.ContainsKey(checkpoint_key)) | ||||
{ | { | ||||
Dictionary<string, Tensor> dict = new(); | Dictionary<string, Tensor> dict = new(); | ||||
dict[slice_spec] = tensor; | dict[slice_spec] = tensor; | ||||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||||
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||||
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||||
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||||
} | } | ||||
restore_fn_input_count[restore_fn]--; | restore_fn_input_count[restore_fn]--; | ||||
if (restore_fn_input_count[restore_fn] == 0) | if (restore_fn_input_count[restore_fn] == 0) | ||||
{ | { | ||||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
foreach (var input in restore_fn_inputs[restore_fn]) | foreach (var input in restore_fn_inputs[restore_fn]) | ||||
{ | { | ||||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | ||||
@@ -538,7 +443,7 @@ namespace Tensorflow.Checkpoint | |||||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | ||||
{ | { | ||||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||||
foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
{ | { | ||||
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | ||||
@@ -1,7 +1,9 @@ | |||||
using System; | |||||
using OneOf; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Security; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
@@ -49,7 +51,7 @@ public class CheckpointPosition | |||||
{ | { | ||||
_checkpoint.AllTrackables.Add(trackable); | _checkpoint.AllTrackables.Add(trackable); | ||||
_checkpoint.MatchedProtoIds.Add(_proto_id); | _checkpoint.MatchedProtoIds.Add(_proto_id); | ||||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||||
{ | { | ||||
// skip the `logging.warning`. | // skip the `logging.warning`. | ||||
return false; | return false; | ||||
@@ -61,13 +63,13 @@ public class CheckpointPosition | |||||
} | } | ||||
} | } | ||||
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||||
public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||||
{ | { | ||||
// skip the registered_saver | // skip the registered_saver | ||||
if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | ||||
{ | { | ||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||||
new List<CheckpointPosition>(), null); | new List<CheckpointPosition>(), null); | ||||
} | } | ||||
@@ -75,7 +77,7 @@ public class CheckpointPosition | |||||
List<Operation> existing_restore_ops; | List<Operation> existing_restore_ops; | ||||
List<CheckpointPosition> positions = new(); | List<CheckpointPosition> positions = new(); | ||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables; | |||||
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | ||||
{ | { | ||||
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | ||||
@@ -109,8 +111,8 @@ public class CheckpointPosition | |||||
/// Creates a saveable using the _serialize_to_tensor method. | /// Creates a saveable using the _serialize_to_tensor method. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="saveable_factories"></param> | /// <param name="saveable_factories"></param> | ||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||||
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
{ | { | ||||
string suffix = SaveableCompat.get_saveable_name(this.Trackable); | string suffix = SaveableCompat.get_saveable_name(this.Trackable); | ||||
suffix = suffix ?? ""; | suffix = suffix ?? ""; | ||||
@@ -124,23 +126,23 @@ public class CheckpointPosition | |||||
var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | ||||
// skip the cache. | // skip the cache. | ||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new(); | |||||
dict[saveable_name] = saveable; | dict[saveable_name] = saveable; | ||||
return (new List<Operation>(), dict); | return (new List<Operation>(), dict); | ||||
} | } | ||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||||
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
{ | { | ||||
// TODO(Rinne): implement it. | // TODO(Rinne): implement it. | ||||
if(ObjectProto.Attributes is null) | if(ObjectProto.Attributes is null) | ||||
{ | { | ||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>()); | |||||
} | } | ||||
List<Operation> existing_restore_ops = new(); | List<Operation> existing_restore_ops = new(); | ||||
HashSet<string> created_compat_names = new(); | HashSet<string> created_compat_names = new(); | ||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||||
foreach (var serialized_tensor in ObjectProto.Attributes) | foreach (var serialized_tensor in ObjectProto.Attributes) | ||||
{ | { | ||||
Operation existing_op; | Operation existing_op; | ||||
@@ -172,12 +174,12 @@ public class CheckpointPosition | |||||
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | ||||
continue; | continue; | ||||
} | } | ||||
named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||||
named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; | |||||
} | } | ||||
return (existing_restore_ops, named_saveables); | return (existing_restore_ops, named_saveables); | ||||
} | } | ||||
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||||
private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||||
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | ||||
{ | { | ||||
var expected_factory_name = serialized_tensor.Name; | var expected_factory_name = serialized_tensor.Name; | ||||
@@ -221,7 +223,7 @@ public class CheckpointPosition | |||||
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | ||||
visit_queue.Enqueue((this, this.Trackable)); | visit_queue.Enqueue((this, this.Trackable)); | ||||
List<Operation> restore_ops = new(); | List<Operation> restore_ops = new(); | ||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||||
List<CheckpointPosition> positions = new(); | List<CheckpointPosition> positions = new(); | ||||
CheckpointPosition current_position = null; | CheckpointPosition current_position = null; | ||||
@@ -306,7 +308,7 @@ public class CheckpointPosition | |||||
} | } | ||||
} | } | ||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||||
{ | { | ||||
var trackable = this.Trackable; | var trackable = this.Trackable; | ||||
trackable._maybe_initialize_trackable(); | trackable._maybe_initialize_trackable(); | ||||
@@ -318,7 +320,7 @@ public class CheckpointPosition | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||||
new List<CheckpointPosition>(), null); | new List<CheckpointPosition>(), null); | ||||
} | } | ||||
} | } | ||||
@@ -14,9 +14,11 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using System; | using System; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Extensions; | |||||
namespace Tensorflow.Contexts | namespace Tensorflow.Contexts | ||||
{ | { | ||||
@@ -25,12 +27,93 @@ namespace Tensorflow.Contexts | |||||
/// </summary> | /// </summary> | ||||
public sealed partial class Context | public sealed partial class Context | ||||
{ | { | ||||
public ConfigProto Config { get; set; } = new ConfigProto | |||||
protected Device.PhysicalDevice[] _physical_devices; | |||||
protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index; | |||||
ConfigProto _config; | |||||
public ConfigProto Config | |||||
{ | { | ||||
GpuOptions = new GPUOptions | |||||
get | |||||
{ | { | ||||
_initialize_physical_devices(); | |||||
var config = new ConfigProto(); | |||||
if(_config is not null) | |||||
{ | |||||
config.MergeFrom(_config); | |||||
} | |||||
config.LogDevicePlacement = _log_device_placement; | |||||
config.DeviceCount["CPU"] = 0; | |||||
config.DeviceCount["GPU"] = 0; | |||||
foreach(var dev in _physical_devices) | |||||
{ | |||||
if (config.DeviceCount.ContainsKey(dev.DeviceType)) | |||||
{ | |||||
config.DeviceCount[dev.DeviceType] += 1; | |||||
} | |||||
else | |||||
{ | |||||
config.DeviceCount[dev.DeviceType] = 1; | |||||
} | |||||
} | |||||
var gpu_options = _compute_gpu_options(); | |||||
config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); | |||||
return config; | |||||
} | |||||
set | |||||
{ | |||||
_config = value; | |||||
} | |||||
} | |||||
protected void _initialize_physical_devices(bool reinitialize = false) | |||||
{ | |||||
if(!reinitialize && _physical_devices is not null) | |||||
{ | |||||
return; | |||||
} | |||||
var devs = list_physical_devices(); | |||||
_physical_devices = devs.Select(d => new Device.PhysicalDevice() | |||||
{ | |||||
DeviceName = d.DeviceName, | |||||
DeviceType = d.DeviceType | |||||
}).ToArray(); | |||||
_physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i)) | |||||
.ToDictionary(x => x.Key, x => x.Value); | |||||
_import_config(); | |||||
} | |||||
protected void _import_config() | |||||
{ | |||||
if(_config is null) | |||||
{ | |||||
return; | |||||
} | |||||
if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) | |||||
{ | |||||
num_cpus = 1; | |||||
} | |||||
if(num_cpus != 1) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | } | ||||
}; | |||||
var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); | |||||
if(gpus.Count() == 0) | |||||
{ | |||||
return; | |||||
} | |||||
if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) | |||||
{ | |||||
gpu_count = 0; | |||||
} | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
ConfigProto MergeConfig() | ConfigProto MergeConfig() | ||||
{ | { | ||||
@@ -111,6 +111,14 @@ namespace Tensorflow.Contexts | |||||
return results.ToArray(); | return results.ToArray(); | ||||
} | } | ||||
public bool is_custom_device(string device_name) | |||||
{ | |||||
return false; | |||||
// TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. | |||||
//ensure_initialized(); | |||||
//return c_api.TFE_IsCustomDevice(_handle, device_name); | |||||
} | |||||
public EagerDeviceContext device(string name) | public EagerDeviceContext device(string name) | ||||
{ | { | ||||
return new EagerDeviceContext(this, name); | return new EagerDeviceContext(this, name); | ||||
@@ -37,7 +37,26 @@ namespace Tensorflow.Contexts | |||||
public string ScopeName { get; set; } = ""; | public string ScopeName { get; set; } = ""; | ||||
bool initialized = false; | bool initialized = false; | ||||
ContextSwitchStack context_switches; | ContextSwitchStack context_switches; | ||||
public FunctionCallOptions FunctionCallOptions { get; } | |||||
protected FunctionCallOptions _function_call_options; | |||||
public FunctionCallOptions FunctionCallOptions | |||||
{ | |||||
get | |||||
{ | |||||
if(_function_call_options is null) | |||||
{ | |||||
var config = Config; | |||||
_function_call_options = new FunctionCallOptions() | |||||
{ | |||||
Config = config | |||||
}; | |||||
} | |||||
return _function_call_options; | |||||
} | |||||
set | |||||
{ | |||||
_function_call_options = value; | |||||
} | |||||
} | |||||
SafeContextHandle _handle; | SafeContextHandle _handle; | ||||
@@ -122,6 +141,11 @@ namespace Tensorflow.Contexts | |||||
name : | name : | ||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | ||||
public string anonymous_name() | |||||
{ | |||||
return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||||
} | |||||
public void graph_mode(bool isFunc = false) | public void graph_mode(bool isFunc = false) | ||||
=> context_switches.Push(false, isFunc); | => context_switches.Push(false, isFunc); | ||||
@@ -158,6 +182,37 @@ namespace Tensorflow.Contexts | |||||
return has_graph_arg; | return has_graph_arg; | ||||
} | } | ||||
public bool has_function(string name) | |||||
{ | |||||
ensure_initialized(); | |||||
return c_api.TFE_ContextHasFunction(_handle, name); | |||||
} | |||||
public void add_function(SafeFuncGraphHandle fn) | |||||
{ | |||||
ensure_initialized(); | |||||
Status status = new(); | |||||
c_api.TFE_ContextAddFunction(_handle, fn, status); | |||||
status.Check(true); | |||||
} | |||||
public void remove_function(string name) | |||||
{ | |||||
ensure_initialized(); | |||||
Status status = new(); | |||||
c_api.TFE_ContextRemoveFunction(_handle, name, status); | |||||
status.Check(true); | |||||
} | |||||
public void add_function_def(FunctionDef fdef) | |||||
{ | |||||
ensure_initialized(); | |||||
var fdef_string = fdef.ToByteArray(); | |||||
Status status = new Status(); | |||||
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); | |||||
status.Check(true); | |||||
} | |||||
public void restore_mode() | public void restore_mode() | ||||
{ | { | ||||
context_switches.Pop(); | context_switches.Pop(); | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using Protobuf.Text; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Contexts | namespace Tensorflow.Contexts | ||||
@@ -9,10 +10,11 @@ namespace Tensorflow.Contexts | |||||
public class FunctionCallOptions | public class FunctionCallOptions | ||||
{ | { | ||||
public ConfigProto Config { get; set; } | public ConfigProto Config { get; set; } | ||||
public string ExecutorType { get; set; } | |||||
public string config_proto_serialized() | |||||
public ByteString config_proto_serialized() | |||||
{ | { | ||||
return Config.ToByteString().ToStringUtf8(); | |||||
return Config.ToByteString(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -12,18 +12,36 @@ namespace Tensorflow.Eager | |||||
return HasGradientTape(); | return HasGradientTape(); | ||||
} | } | ||||
private bool ShouldRecord(Tensor[] inputs) | |||||
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) | |||||
{ | { | ||||
bool should_record = false; | |||||
foreach (var tape in tf.GetTapeSet()) | |||||
var tape_set = tf.GetTapeSet(); | |||||
var input_ids = MakeTensorIDList(tensors); | |||||
var input_dtypes = MakeTensorDtypeList(tensors); | |||||
bool some_tape_watching = false; | |||||
if (tape_set is not null && tape_set.Count > 0) | |||||
{ | { | ||||
if (tape.ShouldRecord(inputs)) | |||||
foreach (var tape in tape_set) | |||||
{ | { | ||||
should_record = true; | |||||
break; | |||||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||||
{ | |||||
if (tape.Persistent || some_tape_watching) | |||||
{ | |||||
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||||
} | |||||
some_tape_watching = true; | |||||
} | |||||
} | } | ||||
} | } | ||||
return should_record; | |||||
// skip the forward_accumulators. | |||||
if (some_tape_watching) | |||||
{ | |||||
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||||
} | |||||
else | |||||
{ | |||||
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -13,7 +13,17 @@ namespace Tensorflow.Eager | |||||
Tensor[] results, | Tensor[] results, | ||||
BackwardFunction backwardFunction = null) | BackwardFunction backwardFunction = null) | ||||
{ | { | ||||
bool should_record = ShouldRecord(inputs); | |||||
var input_ids = MakeTensorIDList(inputs); | |||||
var input_dtypes = MakeTensorDtypeList(inputs); | |||||
bool should_record = false; | |||||
foreach (var tape in tf.GetTapeSet()) | |||||
{ | |||||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||||
{ | |||||
should_record = true; | |||||
break; | |||||
} | |||||
} | |||||
if (!should_record) | if (!should_record) | ||||
{ | { | ||||
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager | |||||
op_inputs = inputs;*/ | op_inputs = inputs;*/ | ||||
backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | ||||
TapeSetRecordOperation(op_name, inputs, results, backwardFunction); | |||||
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
return HasGradientTape(); | return HasGradientTape(); | ||||
} | } | ||||
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||||
{ | |||||
return tensors.Select(x => x.dtype).ToArray(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Functions; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
@@ -358,7 +358,7 @@ namespace Tensorflow.Eager | |||||
break; | break; | ||||
case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
if (value is ConcreteFunction func) | if (value is ConcreteFunction func) | ||||
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||||
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||||
else | else | ||||
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | ||||
break; | break; | ||||
@@ -1,6 +1,8 @@ | |||||
using System; | |||||
using OneOf.Types; | |||||
using System; | |||||
using Tensorflow.Gradients; | using Tensorflow.Gradients; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager | |||||
/// </summary> | /// </summary> | ||||
public partial class EagerRunner | public partial class EagerRunner | ||||
{ | { | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="tape"></param> | |||||
/// <param name="target"></param> | |||||
/// <param name="sources"></param> | |||||
/// <param name="output_gradients"></param> | |||||
/// <param name="unconnected_gradients">determines the value returned if the target and | |||||
/// sources are unconnected.When 'none' the value returned is None wheras when | |||||
/// 'zero' a zero tensor in the same shape as the sources is returned.</param> | |||||
/// <returns></returns> | |||||
/// <exception cref="RuntimeError"></exception> | |||||
public Tensor[] TFE_TapeGradient(ITape tape, | public Tensor[] TFE_TapeGradient(ITape tape, | ||||
Tensor[] target, | Tensor[] target, | ||||
Tensor[] sources, | Tensor[] sources, | ||||
Tensor[] output_gradients) | |||||
List<Tensor> output_gradients, | |||||
Tensor[] sources_raw, | |||||
string unconnected_gradients) | |||||
{ | { | ||||
var target_vec = target; | |||||
var sources_vec = sources; | |||||
var sources_set = sources_vec; | |||||
if (!tape.Persistent) | |||||
{ | |||||
var tape_set = tf.GetTapeSet(); | |||||
if (tape_set.Contains(tape)) | |||||
{ | |||||
throw new RuntimeError("gradient() cannot be invoked within the " + | |||||
"GradientTape context (i.e., while operations are being " + | |||||
"recorded). Either move the call to gradient() to be " + | |||||
"outside the 'with tf.GradientTape' block, or " + | |||||
"use a persistent tape: " + | |||||
"'with tf.GradientTape(persistent=true)'"); | |||||
} | |||||
} | |||||
var target_vec = MakeTensorIDList(target); | |||||
var sources_vec = MakeTensorIDList(sources); | |||||
HashSet<long> sources_set = new HashSet<long>(sources_vec); | |||||
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); | |||||
int len = target.Length; | |||||
for(int i = 0; i < len; i++) | |||||
{ | |||||
var target_id = target_vec[i]; | |||||
if (sources_set.Contains(target_id)) | |||||
{ | |||||
var tensor = target[i]; | |||||
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); | |||||
} | |||||
} | |||||
List<Tensor> outgrad_vec = new(); | |||||
if(output_gradients is not null) | |||||
{ | |||||
outgrad_vec = output_gradients.ToList(); | |||||
} | |||||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||||
var seq_array = target; | |||||
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>(); | |||||
for (int i = 0; i < target.Length; ++i) | |||||
bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||||
Tensor[] sources_obj = null; | |||||
if (unconnected_gradients_zero) | |||||
{ | { | ||||
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i])); | |||||
sources_obj = MakeTensorList(sources_raw); | |||||
} | } | ||||
if (output_gradients != null) | |||||
if (result.Length > 0) | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
for(int i = 0; i < result.Length; i++) | |||||
{ | |||||
if (result[i] is null && unconnected_gradients_zero) | |||||
{ | |||||
var dtype = sources_obj[i].dtype; | |||||
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); | |||||
} | |||||
} | |||||
} | } | ||||
else | |||||
return result; | |||||
} | |||||
Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) | |||||
{ | |||||
return tensors.ToArray(); | |||||
} | |||||
long[] MakeTensorIDList(Tensor[] tensors) | |||||
{ | |||||
int len = tensors.Length; | |||||
long[] ids = new long[len]; | |||||
for(int i = 0; i < len; i++) | |||||
{ | |||||
var tensor = tensors[i]; | |||||
ids[i] = tensor.Id; | |||||
} | |||||
return ids; | |||||
} | |||||
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||||
{ | |||||
int len = tensors.Length; | |||||
TF_DataType[] dtypes = new TF_DataType[len]; | |||||
for (int i = 0; i < len; i++) | |||||
{ | { | ||||
output_gradients = new Tensor[0]; | |||||
var tensor = tensors[i]; | |||||
dtypes[i] = tensor.dtype; | |||||
} | } | ||||
return dtypes; | |||||
} | |||||
var outgrad_vec = MakeTensorList(output_gradients); | |||||
TapeTensor TapeTensorFromTensor(Tensor tensor) | |||||
{ | |||||
long id = tensor.Id; | |||||
var dtype = tensor.dtype; | |||||
if (tensor is EagerTensor) | |||||
{ | |||||
var handle = tensor.EagerTensorHandle; | |||||
if (DTypeNeedsHandleData(dtype)) | |||||
{ | |||||
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); | |||||
} | |||||
Status status = new(); | |||||
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); | |||||
long[] dims = new long[num_dims]; | |||||
for(int i = 0; i < num_dims; i++) | |||||
{ | |||||
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||||
} | |||||
Shape tensor_shape = new(dims); | |||||
if(status.Code != TF_Code.TF_OK) | |||||
{ | |||||
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); | |||||
} | |||||
else | |||||
{ | |||||
return new TapeTensor(id, dtype, tensor_shape); | |||||
} | |||||
} | |||||
var shape_tuple = tensor.shape.dims; | |||||
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) | |||||
{ | |||||
return new TapeTensor(id, dtype, tensor); | |||||
} | |||||
long[] l = new long[shape_tuple.Length]; | |||||
for(int i = 0; i < shape_tuple.Length; i++) | |||||
{ | |||||
if (shape_tuple[i] < 0) | |||||
{ | |||||
l[i] = 0; | |||||
} | |||||
else | |||||
{ | |||||
l[i] = shape_tuple[i]; | |||||
} | |||||
} | |||||
return new TapeTensor(id, dtype, new Shape(l)); | |||||
} | |||||
return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec); | |||||
bool DTypeNeedsHandleData(TF_DataType dtype) | |||||
{ | |||||
return dtype == dtypes.variant || dtype == dtypes.resource; | |||||
} | } | ||||
Tensor[] MakeTensorList(Tensor[] tensors) | |||||
bool ListContainNone(long[] list) | |||||
{ | { | ||||
return tensors; | |||||
int len = list.Length; | |||||
if(len == 0) | |||||
{ | |||||
return true; | |||||
} | |||||
for(int i = 0; i < len; i++) | |||||
{ | |||||
if (list[i] == -1) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager | |||||
public partial class EagerRunner | public partial class EagerRunner | ||||
{ | { | ||||
void TapeSetRecordBackprop(string op_type, | void TapeSetRecordBackprop(string op_type, | ||||
Tensor[] input_tensors, | |||||
TapeTensor[] output_tensors, | |||||
TapeTensor[] output_info, | |||||
long[] input_ids, | |||||
TF_DataType[] input_detyps, | |||||
BackwardFunction backward_function) | BackwardFunction backward_function) | ||||
{ | { | ||||
if (!CouldBackprop()) | if (!CouldBackprop()) | ||||
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||||
foreach (var tape in tf.GetTapeSet()) | foreach (var tape in tf.GetTapeSet()) | ||||
{ | { | ||||
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function); | |||||
tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -10,18 +10,28 @@ namespace Tensorflow.Eager | |||||
public bool TapeSetRecordOperation(string op_type, | public bool TapeSetRecordOperation(string op_type, | ||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
Tensor[] output_tensors, | Tensor[] output_tensors, | ||||
long[] input_ids, | |||||
TF_DataType[] input_dtypes, | |||||
BackwardFunction backward_function) | BackwardFunction backward_function) | ||||
{ | { | ||||
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray(); | |||||
var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); | |||||
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | ||||
backward_function)) | backward_function)) | ||||
return false; | return false; | ||||
TapeSetRecordBackprop(op_type, input_tensors, output_info, | |||||
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, | |||||
backward_function); | backward_function); | ||||
return true; | return true; | ||||
} | } | ||||
public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||||
Tensor[] input_tensors, BackwardFunction backward_function) | |||||
{ | |||||
var input_ids = MakeTensorIDList(input_tensors); | |||||
var input_dtypes = MakeTensorDtypeList(input_tensors); | |||||
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||||
backward_function); | |||||
} | |||||
} | } | ||||
} | } |
@@ -29,7 +29,14 @@ namespace Tensorflow.Eager | |||||
Tensor[] TFE_TapeGradient(ITape tape, | Tensor[] TFE_TapeGradient(ITape tape, | ||||
Tensor[] target, | Tensor[] target, | ||||
Tensor[] sources, | Tensor[] sources, | ||||
Tensor[] output_gradients); | |||||
List<Tensor> output_gradients, | |||||
Tensor[] sources_raw, | |||||
string unconnected_gradients); | |||||
void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||||
Tensor[] input_tensors, BackwardFunction backward_function); | |||||
int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); | |||||
bool RecordGradient(string op_name, | bool RecordGradient(string op_name, | ||||
Tensor[] inputs, | Tensor[] inputs, | ||||
@@ -0,0 +1,53 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Operations; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
internal static class backprop_util | |||||
{ | |||||
// TODO: add quantized_dtypes (after being supported). | |||||
private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[] | |||||
{ | |||||
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, | |||||
dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 | |||||
}); | |||||
public static bool IsTrainable(Tensor tensor) | |||||
{ | |||||
var dtype = _DTypeFromTensor(tensor); | |||||
return _trainable_dtypes.Contains(dtype); | |||||
} | |||||
public static bool IsTrainable(TF_DataType dtype) | |||||
{ | |||||
return _trainable_dtypes.Contains(dtype); | |||||
} | |||||
private static TF_DataType _DTypeFromTensor(Tensor tensor) | |||||
{ | |||||
var dtype = tensor.dtype; | |||||
if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) | |||||
{ | |||||
CppShapeInferenceResult.Types.HandleData handle_data; | |||||
if (tensor is EagerTensor) | |||||
{ | |||||
handle_data = tensor.HandleData; | |||||
} | |||||
else | |||||
{ | |||||
handle_data = handle_data_util.get_resource_handle_data(tensor); | |||||
} | |||||
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && | |||||
handle_data.ShapeAndType.Count > 0) | |||||
{ | |||||
var first_type = handle_data.ShapeAndType[0].Dtype; | |||||
if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) | |||||
{ | |||||
return first_type.as_tf_dtype(); | |||||
} | |||||
} | |||||
} | |||||
return dtype; | |||||
} | |||||
} | |||||
} |
@@ -30,6 +30,9 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | ||||
@@ -277,7 +280,7 @@ namespace Tensorflow | |||||
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||||
public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -480,5 +483,8 @@ namespace Tensorflow | |||||
IntPtr[] target, int target_size, | IntPtr[] target, int target_size, | ||||
IntPtr[] sources, int source_size, | IntPtr[] sources, int source_size, | ||||
IntPtr[] outputs, int output_size); | IntPtr[] outputs, int output_size); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); | |||||
} | } | ||||
} | } |
@@ -18,6 +18,10 @@ namespace Tensorflow.Eager | |||||
var types = v.Select(t => t.dtype.as_datatype_enum()); | var types = v.Select(t => t.dtype.as_datatype_enum()); | ||||
return (types.ToArray(), v.ToArray()); | return (types.ToArray(), v.ToArray()); | ||||
} | } | ||||
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||||
{ | |||||
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); | |||||
} | |||||
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | ||||
{ | { | ||||
string device_name = ctx.DeviceName; | string device_name = ctx.DeviceName; | ||||
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public class TangentInfo | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
public object Indices { get; set; } | |||||
public object Tangents { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,31 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.CompilerServices; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | |||||
public static class DictionaryExtension | |||||
{ | |||||
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second) | |||||
{ | |||||
first = pair.Key; | |||||
second = pair.Value; | |||||
} | |||||
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other) | |||||
{ | |||||
foreach(var (key, value) in other) | |||||
{ | |||||
dic[key] = value; | |||||
} | |||||
} | |||||
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue) | |||||
{ | |||||
if (dic.ContainsKey(key)) | |||||
{ | |||||
return dic[key]; | |||||
} | |||||
return defaultValue; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.CompilerServices; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public class NamedTuple | |||||
{ | |||||
public string Name { get; set; } | |||||
public Dictionary<string, object> ValueDict { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using OneOf; | |||||
using System; | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | |||||
public static class OneofExtension | |||||
{ | |||||
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src) | |||||
{ | |||||
return src.Value is T; | |||||
} | |||||
} | |||||
} |
@@ -6,8 +6,11 @@ | |||||
public class DenseSpec : TypeSpec | public class DenseSpec : TypeSpec | ||||
{ | { | ||||
protected Shape _shape; | protected Shape _shape; | ||||
public Shape shape => _shape; | |||||
public Shape shape | |||||
{ | |||||
get { return _shape; } | |||||
set { _shape = value; } | |||||
} | |||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
@@ -1,6 +0,0 @@ | |||||
namespace Tensorflow.Framework.Models | |||||
{ | |||||
class ScopedTFFunction | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,22 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
internal class ScopedTFFunction | |||||
{ | |||||
SafeFuncGraphHandle _handle; | |||||
string _name; | |||||
public ScopedTFFunction(SafeFuncGraphHandle func, string name) | |||||
{ | |||||
_handle = func; | |||||
_name = name; | |||||
} | |||||
public SafeFuncGraphHandle Get() | |||||
{ | |||||
return _handle; | |||||
} | |||||
} | |||||
} |
@@ -111,7 +111,17 @@ namespace Tensorflow | |||||
public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | ||||
public static Buffer tf_buffer(byte[] data) => new Buffer(data); | |||||
public static Buffer tf_buffer(byte[] data = null) | |||||
{ | |||||
if(data is not null) | |||||
{ | |||||
return new Buffer(data); ; | |||||
} | |||||
else | |||||
{ | |||||
return new Buffer(); | |||||
} | |||||
} | |||||
public static IEnumerable<Operation> new_tf_operations(Graph graph) | public static IEnumerable<Operation> new_tf_operations(Graph graph) | ||||
{ | { | ||||
@@ -0,0 +1,297 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Security.Cryptography; | |||||
using System.Text; | |||||
using Tensorflow.Graphs; | |||||
using Tensorflow.Common.Extensions; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.CppShapeInferenceResult.Types; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
public class function_def_lib | |||||
{ | |||||
// TODO(Rinne): process signatures and structured outputs. | |||||
public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||||
object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||||
{ | |||||
var func_graph = new FuncGraph(fdef.Signature.Name); | |||||
if(input_shapes is null) | |||||
{ | |||||
if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||||
{ | |||||
var raw_input_shapes = input_shapes_attr.List.Shape; | |||||
input_shapes = new List<TensorShapeProto>(); | |||||
foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||||
{ | |||||
if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||||
{ | |||||
input_shapes.Add(null); | |||||
} | |||||
else | |||||
{ | |||||
input_shapes.Add(input_shape); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||||
func_graph.as_default(); | |||||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
// TODO(Rinne): func_graph.ControlOutputs | |||||
_set_handle_data(func_graph, fdef); | |||||
foreach(var node in graph_def.Node) | |||||
{ | |||||
if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||||
{ | |||||
var op = func_graph.get_operation_by_name(node.Name); | |||||
foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||||
{ | |||||
op.outputs[output_index].shape = new Shape(shape); | |||||
} | |||||
} | |||||
} | |||||
Dictionary<long, string> output_names = new(); | |||||
foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||||
{ | |||||
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||||
} | |||||
func_graph._output_names = output_names; | |||||
func_graph.Exit(); | |||||
return func_graph; | |||||
} | |||||
public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||||
{ | |||||
var graph_def = new GraphDef() | |||||
{ | |||||
Versions = new VersionDef() | |||||
{ | |||||
Producer = versions.GRAPH_DEF_VERSION, | |||||
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||||
} | |||||
}; | |||||
var default_graph = ops.get_default_graph(); | |||||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||||
{ | |||||
throw new ValueError($"Length of `input_shapes` must match the number " + | |||||
$"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||||
$"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||||
} | |||||
foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||||
{ | |||||
NodeDef node_def = new(); | |||||
node_def.Name = arg_def.Name; | |||||
node_def.Op = "Placeholder"; | |||||
node_def.Attr["dtype"] = new AttrValue() | |||||
{ | |||||
Type = arg_def.Type | |||||
}; | |||||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||||
{ | |||||
var input_shape = input_shapes[i]; | |||||
// skip the condition that input_shape is not `TensorShapeProto`. | |||||
AttrValue shape = new AttrValue() | |||||
{ | |||||
Shape = new TensorShapeProto() | |||||
}; | |||||
shape.Shape = new TensorShapeProto(input_shape); | |||||
node_def.Attr["shape"] = shape; | |||||
} | |||||
if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||||
{ | |||||
fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||||
} | |||||
var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||||
foreach(var k in arg_attrs.Keys) | |||||
{ | |||||
if(k == "_output_shapes") | |||||
{ | |||||
if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||||
} | |||||
else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
{ | |||||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||||
} | |||||
} | |||||
else if (k.StartsWith("_")) | |||||
{ | |||||
if (!node_def.Attr.ContainsKey(k)) | |||||
{ | |||||
node_def.Attr[k] = new AttrValue(); | |||||
} | |||||
node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||||
} | |||||
} | |||||
graph_def.Node.Add(node_def); | |||||
} | |||||
graph_def.Node.AddRange(fdef.NodeDef); | |||||
Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||||
foreach(var arg_def in fdef.Signature.InputArg) | |||||
{ | |||||
nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||||
string control_name = "^" + arg_def.Name; | |||||
nested_to_flat_tensor_name[control_name] = control_name; | |||||
} | |||||
foreach(var node_def in fdef.NodeDef) | |||||
{ | |||||
var graph = default_graph; | |||||
while (true) | |||||
{ | |||||
if(graph is null) | |||||
{ | |||||
break; | |||||
} | |||||
var f = graph.Functions.GetOrDefault(node_def.Op, null); | |||||
if(f is not null && graph.OuterGraph is null) | |||||
{ | |||||
break; | |||||
} | |||||
graph = graph.OuterGraph; | |||||
} | |||||
var op_def = default_graph.GetOpDef(node_def.Op); | |||||
foreach(var attr in op_def.Attr) | |||||
{ | |||||
if(attr.Type == "func") | |||||
{ | |||||
var fname = node_def.Attr[attr.Name].Func.Name; | |||||
if (!is_function(fname)) | |||||
{ | |||||
throw new ValueError($"Function {fname} was not found. Please make sure " + | |||||
$"the FunctionDef `fdef` is correct."); | |||||
} | |||||
} | |||||
else if(attr.Type == "list(func)") | |||||
{ | |||||
foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||||
{ | |||||
var fname = fn.Name; | |||||
if (!is_function(fname)) | |||||
{ | |||||
throw new ValueError($"Function {fname} was not found. Please make " + | |||||
$"sure the FunctionDef `fdef` is correct."); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
int flattened_index = 0; | |||||
foreach(var arg_def in op_def.OutputArg) | |||||
{ | |||||
var num_args = _get_num_args(arg_def, node_def); | |||||
for(int i = 0; i < num_args; i++) | |||||
{ | |||||
var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||||
var flat_name = $"{node_def.Name}:{flattened_index}"; | |||||
nested_to_flat_tensor_name[nested_name] = flat_name; | |||||
flattened_index++; | |||||
} | |||||
} | |||||
string control_name = "^" + node_def.Name; | |||||
nested_to_flat_tensor_name[control_name] = control_name; | |||||
} | |||||
foreach(var node_def in graph_def.Node) | |||||
{ | |||||
for(int i = 0; i < node_def.Input.Count; i++) | |||||
{ | |||||
node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||||
} | |||||
} | |||||
return (graph_def, nested_to_flat_tensor_name); | |||||
} | |||||
private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||||
{ | |||||
foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||||
{ | |||||
if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||||
{ | |||||
tensor.shape = Shape.Scalar; | |||||
var shape_and_type = arg_def.HandleData[0]; | |||||
var handle_data = new HandleData(); | |||||
handle_data.IsSet = true; | |||||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||||
{ | |||||
Shape = shape_and_type.Shape, | |||||
Dtype = shape_and_type.Dtype | |||||
}); | |||||
resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||||
} | |||||
} | |||||
} | |||||
private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||||
{ | |||||
if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||||
{ | |||||
return node_def.Attr[arg_def.NumberAttr].I; | |||||
} | |||||
else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||||
{ | |||||
return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||||
} | |||||
else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||||
{ | |||||
return 1; | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||||
$"FunctionDef `fdef` is correct."); | |||||
} | |||||
} | |||||
public static bool is_function(string fname) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
return tf.Context.has_function(fname); | |||||
} | |||||
else | |||||
{ | |||||
var graph = ops.get_default_graph(); | |||||
while(graph is not null) | |||||
{ | |||||
if (graph.IsFunction(fname)) | |||||
{ | |||||
return true; | |||||
} | |||||
if(graph.OuterGraph is not null) | |||||
{ | |||||
graph = graph.OuterGraph; | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
} | |||||
throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + | |||||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
} | |||||
} | |||||
} |
@@ -17,6 +17,7 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
@@ -25,9 +26,14 @@ namespace Tensorflow | |||||
{ | { | ||||
public class importer | public class importer | ||||
{ | { | ||||
public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||||
{ | |||||
return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||||
} | |||||
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | ||||
Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
string[] return_elements = null, | string[] return_elements = null, | ||||
bool validate_colocation_constraints = true, | |||||
string name = null, | string name = null, | ||||
OpList producer_op_list = null) | OpList producer_op_list = null) | ||||
{ | { | ||||
@@ -60,7 +66,7 @@ namespace Tensorflow | |||||
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | ||||
var status = new Status(); | var status = new Status(); | ||||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||||
// need to create a class ImportGraphDefWithResults with IDisposal | // need to create a class ImportGraphDefWithResults with IDisposal | ||||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | ||||
status.Check(true); | status.Check(true); | ||||
@@ -107,21 +113,36 @@ namespace Tensorflow | |||||
foreach (var new_op in graph._add_new_tf_operations()) | foreach (var new_op in graph._add_new_tf_operations()) | ||||
{ | { | ||||
var original_device = new_op.Device; | var original_device = new_op.Device; | ||||
new_op._set_device(original_device); | |||||
} | } | ||||
} | } | ||||
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | ||||
string prefix, | string prefix, | ||||
Dictionary<string, Tensor> input_map, | Dictionary<string, Tensor> input_map, | ||||
string[] return_elements) | |||||
string[] return_elements, | |||||
bool validate_colocation_constraints) | |||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | ||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||||
foreach (var input in input_map) | foreach (var input in input_map) | ||||
{ | { | ||||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||||
var input_src = tf.compat.as_str(input.Key); | |||||
var input_dst = input.Value; | |||||
if (input_src.StartsWith("^")) | |||||
{ | |||||
var src_name = tf.compat.as_str(input_src.Substring(1)); | |||||
var dst_op = input_dst._as_tf_output().oper; | |||||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||||
} | |||||
else | |||||
{ | |||||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||||
src_name = tf.compat.as_str(src_name); | |||||
var dst_output = input_dst._as_tf_output(); | |||||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||||
} | |||||
} | } | ||||
if (return_elements == null) | if (return_elements == null) | ||||
@@ -132,15 +153,16 @@ namespace Tensorflow | |||||
if (name.Contains(":")) | if (name.Contains(":")) | ||||
{ | { | ||||
var (op_name, index) = _ParseTensorName(name); | var (op_name, index) = _ParseTensorName(name); | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||||
op_name = tf.compat.as_str(op_name); | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||||
} | } | ||||
} | } | ||||
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||||
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||||
} | } | ||||
private static (string, int) _ParseTensorName(string tensor_name) | private static (string, int) _ParseTensorName(string tensor_name) | ||||
@@ -173,6 +195,14 @@ namespace Tensorflow | |||||
return graph_def; | return graph_def; | ||||
} | } | ||||
private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||||
{ | |||||
var old_graph_def = graph_def; | |||||
graph_def = new GraphDef(old_graph_def); | |||||
return graph_def; | |||||
} | |||||
private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | ||||
{ | { | ||||
foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
@@ -240,6 +270,35 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||||
{ | |||||
var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||||
foreach (var node in graph_def.Node) | |||||
{ | |||||
// Remove any default attr values that aren't in op_def. | |||||
if (producer_op_dict.ContainsKey(node.Op)) | |||||
{ | |||||
var op_def = op_def_registry.GetOpDef(node.Op); | |||||
if(op_def is null) | |||||
{ | |||||
continue; | |||||
} | |||||
var producer_op_def = producer_op_dict[node.Op]; | |||||
foreach (var key in node.Attr.Keys) | |||||
{ | |||||
if (_FindAttrInOpDef(key, op_def) is null) | |||||
{ | |||||
var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||||
if (attr_def != null && attr_def.DefaultValue != null && | |||||
node.Attr[key] == attr_def.DefaultValue) | |||||
node.Attr[key].ClearValue(); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | ||||
{ | { | ||||
return op_def.Attr.FirstOrDefault(x => x.Name == name); | return op_def.Attr.FirstOrDefault(x => x.Name == name); | ||||
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Framework | |||||
{ | |||||
public class versions | |||||
{ | |||||
public static int GRAPH_DEF_VERSION = 1286; | |||||
public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||||
} | |||||
} |
@@ -1,9 +1,13 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | |||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Gradients; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
@@ -13,29 +17,46 @@ namespace Tensorflow.Functions | |||||
/// </summary> | /// </summary> | ||||
public class ConcreteFunction: Trackable | public class ConcreteFunction: Trackable | ||||
{ | { | ||||
protected IEnumerable<Tensor> _captured_inputs; | |||||
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||||
protected Dictionary<string, AttrValue> _attrs; | |||||
protected FunctionSpec _function_spec; | |||||
protected FunctionSpec _pre_initialized_function_spec = null; | |||||
protected EagerDefinedFunction _inference_function; | |||||
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); | |||||
internal FuncGraph func_graph; | internal FuncGraph func_graph; | ||||
internal ForwardBackwardCall forward_backward; | internal ForwardBackwardCall forward_backward; | ||||
public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
public string Name => func_graph?.FuncName; | |||||
public string Name => _delayed_rewrite_functions.Forward().Name; | |||||
public Tensor[] Outputs; | |||||
public Tensor[] Outputs => func_graph.Outputs; | |||||
public Type ReturnType; | public Type ReturnType; | ||||
public TensorSpec[] OutputStructure; | public TensorSpec[] OutputStructure; | ||||
public IEnumerable<string> ArgKeywords { get; set; } | public IEnumerable<string> ArgKeywords { get; set; } | ||||
public long NumPositionArgs { get; set; } | public long NumPositionArgs { get; set; } | ||||
public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; | |||||
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||||
public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||||
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||||
public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
{ | { | ||||
func_graph = new FuncGraph(name); | func_graph = new FuncGraph(name); | ||||
_captured_inputs = func_graph.external_captures; | |||||
_attrs= new Dictionary<string, AttrValue>(); | |||||
_set_infer_function(); | |||||
} | } | ||||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||||
public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null) | |||||
{ | { | ||||
func_graph = graph; | func_graph = graph; | ||||
_captured_inputs = func_graph.external_captures; | |||||
ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||||
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||||
_attrs = attrs; | |||||
_set_infer_function(); | |||||
} | } | ||||
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
@@ -53,6 +74,9 @@ namespace Tensorflow.Functions | |||||
new[] { output }, | new[] { output }, | ||||
null); | null); | ||||
func_graph.Exit(); | func_graph.Exit(); | ||||
_captured_inputs = func_graph.external_captures; | |||||
_attrs = new Dictionary<string, AttrValue>(); | |||||
_set_infer_function(); | |||||
} | } | ||||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
@@ -73,6 +97,9 @@ namespace Tensorflow.Functions | |||||
new[] { output.variant_tensor }, | new[] { output.variant_tensor }, | ||||
null); | null); | ||||
func_graph.Exit(); | func_graph.Exit(); | ||||
_captured_inputs = func_graph.external_captures; | |||||
_attrs = new Dictionary<string, AttrValue>(); | |||||
_set_infer_function(); | |||||
} | } | ||||
/*public ConcreteFunction(Func<Tensors, Tensors> func, | /*public ConcreteFunction(Func<Tensors, Tensors> func, | ||||
@@ -130,39 +157,56 @@ namespace Tensorflow.Functions | |||||
{ | { | ||||
var executing_eagerly = tf.Context.executing_eagerly(); | var executing_eagerly = tf.Context.executing_eagerly(); | ||||
var default_graph = ops.get_default_graph(); | var default_graph = ops.get_default_graph(); | ||||
// TODO(Rinne): deal with `default_graph.building_function` | |||||
var tempvv = func_graph.Variables; | |||||
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) | |||||
{ | |||||
foreach(var v in this.func_graph.Variables) | |||||
{ | |||||
resource_variable_ops.variable_accessed(v); | |||||
} | |||||
} | |||||
var tensor_inputs = new Tensors(); | var tensor_inputs = new Tensors(); | ||||
foreach (var (i, arg) in enumerate(args)) | foreach (var (i, arg) in enumerate(args)) | ||||
{ | { | ||||
tensor_inputs.Add(arg); | tensor_inputs.Add(arg); | ||||
// If we're graph building, shape inference is on. | // If we're graph building, shape inference is on. | ||||
if (!executing_eagerly) | |||||
{ | |||||
} | |||||
} | } | ||||
tensor_inputs.AddRange(captured_inputs); | |||||
if (!executing_eagerly) | |||||
{ | |||||
// TODO(Rinne): add the check | |||||
} | |||||
tensor_inputs.AddRange(captured_inputs); | |||||
args = tensor_inputs.ToArray(); | args = tensor_inputs.ToArray(); | ||||
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||||
var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); | |||||
// No tape is watching; skip to running the function. | // No tape is watching; skip to running the function. | ||||
if (possible_gradient_type == 0 && executing_eagerly) | |||||
if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) | |||||
{ | { | ||||
var attrs = new object[] | |||||
{ | |||||
"executor_type", "", | |||||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
}; | |||||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||||
return _build_call_outputs(_inference_function.Call(args)); | |||||
} | } | ||||
if (forward_backward == null) | |||||
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||||
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | var (forward_function, args_with_tangents) = forward_backward.Forward(); | ||||
Tensors flat_outputs = null; | Tensors flat_outputs = null; | ||||
if (executing_eagerly) | if (executing_eagerly) | ||||
{ | |||||
flat_outputs = forward_function.Call(args_with_tangents); | flat_outputs = forward_function.Call(args_with_tangents); | ||||
} | |||||
else | |||||
{ | |||||
tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){ | |||||
{ "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } | |||||
}), _ => | |||||
{ | |||||
flat_outputs = forward_function.Call(args_with_tangents); | |||||
}); | |||||
} | |||||
forward_backward.Record(flat_outputs); | forward_backward.Record(flat_outputs); | ||||
return flat_outputs; | |||||
return _build_call_outputs(flat_outputs); | |||||
} | } | ||||
public void AddTograph(Graph? g = null) | public void AddTograph(Graph? g = null) | ||||
@@ -171,13 +215,99 @@ namespace Tensorflow.Functions | |||||
{ | { | ||||
g = ops.get_default_graph(); | g = ops.get_default_graph(); | ||||
} | } | ||||
// TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||||
_delayed_rewrite_functions.Forward().AddToGraph(g); | |||||
} | |||||
public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||||
{ | |||||
_captured_inputs = captures; | |||||
} | } | ||||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
{ | { | ||||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
TangentInfo input_tangents; | |||||
if (executing_eagerly) | |||||
{ | |||||
// TODO(Rinne): check if it needs to be implemented. | |||||
input_tangents = new TangentInfo(); | |||||
} | |||||
else | |||||
{ | |||||
input_tangents = new TangentInfo(); | |||||
} | |||||
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||||
{ | |||||
if(input_tangents.Indices is not null || executing_eagerly) | |||||
{ | |||||
string cache_key = "first_order"; | |||||
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) | |||||
{ | |||||
functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
_tape_functions_cache[cache_key] = functions; | |||||
} | |||||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
} | |||||
else | |||||
{ | |||||
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); | |||||
} | |||||
} | |||||
else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||||
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||||
} | |||||
internal void set_variables(IEnumerable<IVariableV1> variables) | |||||
{ | |||||
func_graph.Variables = variables; | |||||
} | |||||
internal void _set_infer_function() | |||||
{ | |||||
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||||
_inference_function = _delayed_rewrite_functions.Forward(); | |||||
} | |||||
internal void _set_function_spec(FunctionSpec spec) | |||||
{ | |||||
_function_spec = null; | |||||
_pre_initialized_function_spec = spec; | |||||
_initialize_function_spec(); | |||||
} | |||||
internal void _initialize_function_spec() | |||||
{ | |||||
if(_pre_initialized_function_spec is null) | |||||
{ | |||||
return; | |||||
} | |||||
Debug.Assert(_function_spec is null, "already initialized"); | |||||
var spec = _pre_initialized_function_spec; | |||||
//var args = spec.Fullargspec.DictValue.Fields["args"]; | |||||
// TODO(Rinne): self.structured_input_signature | |||||
_function_spec = new FunctionSpec() | |||||
{ | |||||
Fullargspec = spec.Fullargspec, | |||||
IsMethod = spec.IsMethod, | |||||
InputSignature = spec.InputSignature | |||||
}; | |||||
} | |||||
internal Func<Operation, object[], Tensor[]> _get_gradient_function() | |||||
{ | |||||
return _delayed_rewrite_functions._rewrite_forward_and_call_backward; | |||||
} | |||||
private Tensors _build_call_outputs(Tensors result) | |||||
{ | |||||
// TODO(Rinne): deal with `func_graph.structured_outputs` | |||||
return result; | |||||
} | } | ||||
public override string ToString() | public override string ToString() | ||||
@@ -1,50 +1,232 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.Operations; | |||||
using Tensorflow.Util; | |||||
using Tensorflow.Common.Extensions; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Framework; | |||||
using System.Buffers; | |||||
using Tensorflow.Gradients; | |||||
namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
{ | { | ||||
public class EagerDefinedFunction | |||||
public class EagerDefinedFunction: IDisposable | |||||
{ | { | ||||
public int _num_outputs; | public int _num_outputs; | ||||
public string Name => _func_graph.FuncName; | |||||
FuncGraph _graph; | |||||
FunctionDef _definition; | |||||
OpDef _signature; | |||||
string _name; | |||||
internal ScopedTFFunction _c_func; | |||||
internal Tensor[] _func_graph_outputs; | |||||
internal string _grad_func_name; | |||||
internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func; | |||||
internal EagerDefinedFunction _grad_func; | |||||
internal bool _registered_on_context = false; | |||||
public string Name => _name; | |||||
public DataType[] OutputTypes { get; protected set; } | |||||
public Shape[] OutputShapes { get; protected set; } | |||||
public FunctionDef Definition | |||||
{ | |||||
get | |||||
{ | |||||
if(_definition is null) | |||||
{ | |||||
_definition = _get_definition(); | |||||
} | |||||
return _definition; | |||||
} | |||||
} | |||||
FuncGraph _func_graph; | |||||
public EagerDefinedFunction(string name, FuncGraph graph, | |||||
public OpDef Signature | |||||
{ | |||||
get | |||||
{ | |||||
if( _signature is null) | |||||
{ | |||||
_signature = Definition.Signature; | |||||
} | |||||
return _signature; | |||||
} | |||||
} | |||||
public unsafe EagerDefinedFunction(string name, FuncGraph graph, | |||||
Tensors inputs, Tensors outputs, | Tensors inputs, Tensors outputs, | ||||
Dictionary<string, string> attrs) | |||||
Dictionary<string, AttrValue> attrs) | |||||
{ | { | ||||
_num_outputs = outputs.Length; | |||||
var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | ||||
.Select(x => x as Operation).ToArray(); | .Select(x => x as Operation).ToArray(); | ||||
var output_names = new string[0]; | |||||
var graph_output_names = graph._output_names; | |||||
string[] output_names; | |||||
if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) | |||||
{ | |||||
output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); | |||||
if(output_names.Distinct().Count() != output_names.Length) | |||||
{ | |||||
output_names = new string[0]; | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
output_names = new string[0]; | |||||
} | |||||
_func_graph = new FuncGraph(graph, name, attrs); | |||||
_func_graph.ToGraph(operations, inputs, outputs, output_names); | |||||
Status status = new Status(); | |||||
var fn = c_api.TF_GraphToFunction(graph.c_graph, | |||||
name, | |||||
false, | |||||
operations.Length, | |||||
operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), | |||||
inputs.Length, | |||||
inputs.Select(t => t._as_tf_output()).ToArray(), | |||||
outputs.Length, | |||||
outputs.Select(t => t._as_tf_output()).ToArray(), | |||||
output_names.Length != outputs.Length ? null : output_names, | |||||
IntPtr.Zero, // warning: the control output hasbben totally ignored. | |||||
null, | |||||
status); | |||||
status.Check(true); | |||||
_c_func = new ScopedTFFunction(fn, name); | |||||
foreach(var (attr_name, attr_value) in attrs) | |||||
{ | |||||
var serialized = attr_value.ToByteArray(); | |||||
c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); | |||||
status.Check(true); | |||||
} | |||||
var signature = _get_definition().Signature; | |||||
_name = signature.Name; | |||||
tf_with(ops.init_scope(), s => | |||||
{ | |||||
tf.Context.add_function(fn); | |||||
_registered_on_context = true; | |||||
}); | |||||
_num_outputs = signature.OutputArg.Count; | |||||
OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); | |||||
OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||||
_func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||||
csharp_grad_func = null; | |||||
_graph = graph; | |||||
} | } | ||||
public Tensors Call(Tensors args) | |||||
public unsafe Tensors Call(Tensors args) | |||||
{ | { | ||||
// TODO(Rinne): Add arg `CancellationManager`. | |||||
// TODO(Rinne): Check the arg length. | |||||
var function_call_options = tf.Context.FunctionCallOptions; | |||||
string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons. | |||||
//if (function_call_options.config_proto_serialized().Length == 0) | |||||
//{ | |||||
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||||
//} | |||||
//else | |||||
//{ | |||||
// config = function_call_options.config_proto_serialized().ToStringUtf8(); | |||||
//} | |||||
string executor_type = function_call_options.ExecutorType ?? ""; | |||||
var executing_eagerly = tf.Context.executing_eagerly(); | |||||
var attrs = new object[] | var attrs = new object[] | ||||
{ | { | ||||
"executor_type", "", | |||||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
"executor_type", executor_type, | |||||
"config_proto", config | |||||
}; | }; | ||||
var results = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | |||||
_func_graph.FuncName, | |||||
args, | |||||
attrs, | |||||
_num_outputs); | |||||
Tensor[] outputs; | |||||
if (executing_eagerly) | |||||
{ | |||||
outputs = execute.executes( | |||||
Signature.Name, | |||||
_num_outputs, | |||||
args, | |||||
attrs, | |||||
tf.Context); | |||||
} | |||||
else | |||||
{ | |||||
if(tf.GetTapeSet().Count == 0) | |||||
{ | |||||
outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||||
executing_eagerly, config, ""); | |||||
} | |||||
else | |||||
{ | |||||
var tape = tf.GetTapeSet().Peek(); | |||||
tape.StopRecord(); | |||||
outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||||
executing_eagerly, config, ""); | |||||
tape.StartRecord(); | |||||
} | |||||
} | |||||
foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||||
{ | |||||
handle_data_util.copy_handle_data(func_graph_output, outputs[i]); | |||||
} | |||||
if (executing_eagerly) | |||||
{ | |||||
return outputs; | |||||
} | |||||
else | |||||
{ | |||||
foreach(var (i, shape) in enumerate(OutputShapes)) | |||||
{ | |||||
outputs[i].shape = shape; | |||||
} | |||||
return outputs; | |||||
} | |||||
} | |||||
public void AddToGraph(Graph g = null) | |||||
{ | |||||
if(g is null && tf.Context.executing_eagerly()) | |||||
{ | |||||
var ctx = tf.Context; | |||||
if (!ctx.has_function(this.Name)) | |||||
{ | |||||
ctx.add_function_def(Definition); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
if (!g.IsFunction(Name)) | |||||
{ | |||||
g.AddFunction(this); | |||||
} | |||||
foreach(var f in _graph.Functions.Values) | |||||
{ | |||||
if (!g.IsFunction(f.Name)) | |||||
{ | |||||
g.AddFunction(f); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return results; | |||||
private FunctionDef _get_definition() | |||||
{ | |||||
var buffer = c_api_util.tf_buffer(); | |||||
Status status = new(); | |||||
c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); | |||||
status.Check(true); | |||||
var proto_data = c_api.TF_GetBuffer(buffer); | |||||
return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>()); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
tf.Context.remove_function(Name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -14,12 +14,11 @@ namespace Tensorflow.Functions | |||||
} | } | ||||
public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||||
ForwardAndBackwardFunctions(Tensors inference_args) | |||||
{ | { | ||||
var outputs = _func_graph.Outputs; | |||||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||||
= BuildFunctionsForOutputs(outputs, inference_args); | |||||
return _forward; | |||||
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); | |||||
return BuildFunctionsForOutputs(outputs, inference_args); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,23 +1,84 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class Function: Trackable | |||||
public class Function: Trackable, IGenericFunction | |||||
{ | { | ||||
#pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
#pragma warning restore CS0169 // The field 'Function._handle' is never used | #pragma warning restore CS0169 // The field 'Function._handle' is never used | ||||
protected Func<Tensor[], Tensor[]> _csharp_function; | |||||
protected ConcreteFunction _concrete_variable_creation_fn; | |||||
protected bool _autograph; | |||||
protected TracingCompiler _variable_creation_fn; | |||||
public string Name { get; set; } | public string Name { get; set; } | ||||
public Function() | |||||
public Function(Func<Tensor[], Tensor[]> csharp_function, | |||||
string name, bool auto_graph = true) | |||||
{ | |||||
_csharp_function = csharp_function; | |||||
Name = name; | |||||
_autograph = auto_graph; | |||||
} | |||||
public virtual Tensors Apply(Tensors inputs) | |||||
{ | { | ||||
if (_run_functions_eagerly()) | |||||
{ | |||||
return _csharp_function(inputs); | |||||
} | |||||
var result = _call(inputs); | |||||
return result; | |||||
} | |||||
public ConcreteFunction get_concrete_function(params Tensor[] args) | |||||
{ | |||||
return _get_concrete_function_garbage_collected(args); | |||||
} | } | ||||
public Function(string name) | |||||
protected virtual Tensors _call(Tensors inputs) | |||||
{ | { | ||||
Name = name; | |||||
if(_variable_creation_fn is not null) | |||||
{ | |||||
return _variable_creation_fn.Apply(inputs); | |||||
} | |||||
_initialize(inputs); | |||||
return _concrete_variable_creation_fn.CallFlat(inputs, | |||||
_concrete_variable_creation_fn.CapturedInputs); | |||||
} | |||||
protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn) | |||||
{ | |||||
var name = nameof(fn); | |||||
return new TracingCompiler(fn, name, autograph: _autograph); | |||||
} | |||||
protected virtual bool _run_functions_eagerly() | |||||
{ | |||||
return false; | |||||
} | |||||
protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) | |||||
{ | |||||
if(_variable_creation_fn is null) | |||||
{ | |||||
_initialize(args); | |||||
// TODO(Rinne): _initialize_uninitialized_variables | |||||
} | |||||
var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||||
return concrete; | |||||
} | |||||
private void _initialize(Tensor[] args) | |||||
{ | |||||
_variable_creation_fn = _compiler(_csharp_function); | |||||
_variable_creation_fn._name = this.Name; | |||||
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
public interface IGenericFunction | |||||
{ | |||||
Tensors Apply(Tensors args); | |||||
ConcreteFunction get_concrete_function(params Tensor[] args); | |||||
} | |||||
} |
@@ -3,8 +3,10 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Gradients; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Operations; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.tensorflow; | using static Tensorflow.tensorflow; | ||||
@@ -15,17 +17,21 @@ namespace Tensorflow.Functions | |||||
/// </summary> | /// </summary> | ||||
public abstract class TapeGradientFunctions | public abstract class TapeGradientFunctions | ||||
{ | { | ||||
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
string _FORWARD_PREFIX = "__forward_"; | |||||
string _BACKWARD_PREFIX = "__backward_"; | |||||
string _INFERENCE_PREFIX = "__inference_"; | |||||
protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
protected string _FORWARD_PREFIX = "__forward_"; | |||||
protected string _BACKWARD_PREFIX = "__backward_"; | |||||
protected string _INFERENCE_PREFIX = "__inference_"; | |||||
protected FuncGraph _func_graph; | protected FuncGraph _func_graph; | ||||
protected EagerDefinedFunction _forward; | protected EagerDefinedFunction _forward; | ||||
protected FuncGraph _forward_graph; | protected FuncGraph _forward_graph; | ||||
protected List<int> _forwardprop_input_indices; | |||||
protected List<int> _forwardprop_output_indices; | protected List<int> _forwardprop_output_indices; | ||||
protected int _num_forwardprop_outputs; | protected int _num_forwardprop_outputs; | ||||
protected int _num_inference_outputs; | |||||
protected int _num_outputs; | |||||
protected int _num_trainable_inference_outputs; | |||||
protected ConcreteFunction _backward; | protected ConcreteFunction _backward; | ||||
BackwardFunction _backward_function_wrapper; | BackwardFunction _backward_function_wrapper; | ||||
@@ -33,11 +39,25 @@ namespace Tensorflow.Functions | |||||
bool need_gradients_for_jvps) | bool need_gradients_for_jvps) | ||||
{ | { | ||||
_func_graph = func_graph; | _func_graph = func_graph; | ||||
_forward_graph = null; | |||||
_forward = null; | |||||
_backward = null; | |||||
_num_outputs = func_graph.Outputs.Length; | |||||
_forwardprop_output_indices = null; | |||||
_num_forwardprop_outputs = 0; | |||||
_num_inference_outputs = func_graph.Outputs.Length; | |||||
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); | |||||
} | } | ||||
public EagerDefinedFunction Forward(Tensors inference_args) | |||||
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||||
{ | { | ||||
return ForwardAndBackwardFunctions(inference_args); | |||||
// TODO(Rinne): add input_tangents arg. | |||||
if(_forward is null) | |||||
{ | |||||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||||
= ForwardAndBackwardFunctions(inference_args); | |||||
} | |||||
return _forward; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -45,11 +65,16 @@ namespace Tensorflow.Functions | |||||
/// </summary> | /// </summary> | ||||
/// <param name="flat_outputs"></param> | /// <param name="flat_outputs"></param> | ||||
/// <param name="inference_args"></param> | /// <param name="inference_args"></param> | ||||
public void Record(Tensors flat_outputs, Tensors inference_args) | |||||
public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||||
{ | { | ||||
// TODO(Rinne): add arg `input_tagents`. | |||||
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | ||||
tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | |||||
getBackwardFunction: backward_function); | |||||
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
throw new NotImplementedException(); | |||||
} | |||||
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -61,66 +86,95 @@ namespace Tensorflow.Functions | |||||
/// <returns></returns> | /// <returns></returns> | ||||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | ||||
{ | { | ||||
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) | |||||
.ToDictionary(x => x.Item1, x => x.Item2); | |||||
var captured_inputs = backward.CapturedInputs; | |||||
var remapped_captures = captured_inputs.Select(c => | |||||
{ | |||||
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) | |||||
{ | |||||
return value; | |||||
} | |||||
else | |||||
{ | |||||
return c; | |||||
} | |||||
}).ToArray(); | |||||
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) | |||||
{ | |||||
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); | |||||
throw new RuntimeError($"Failed to map all backward graph captures to " + | |||||
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); | |||||
} | |||||
Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>(); | |||||
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | ||||
var recorded_outputs = new Tensors(); | var recorded_outputs = new Tensors(); | ||||
var trainable_recorded_outputs = 0; | |||||
foreach (var (output_index, output) in enumerate(outputs)) | |||||
int trainable_recorded_outputs = 0; | |||||
var skip_positions = new HashSet<int>(); | |||||
var relevant_outputs = outputs; | |||||
foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||||
{ | { | ||||
if (trainable_recorded_outputs < backward_function_inputs) | if (trainable_recorded_outputs < backward_function_inputs) | ||||
recorded_outputs.Add(output); | recorded_outputs.Add(output); | ||||
if (gradients_util.IsTrainable(output)) | |||||
trainable_recorded_outputs += 1; | |||||
if (backprop_util.IsTrainable(output)) | |||||
trainable_recorded_outputs++; | |||||
else | |||||
skip_positions.Add(output_index); | |||||
if (output.dtype == dtypes.variant) | |||||
variant_zeros_like[output_index] = default_gradient.zeros_like(output); | |||||
} | } | ||||
if(_backward_function_wrapper == null) | |||||
_backward_function_wrapper = (args, unneeded_gradients) => | |||||
{ | { | ||||
var capture_mapping = new Dictionary<long, Tensor>(); | |||||
foreach (var (i, output) in enumerate(outputs)) | |||||
capture_mapping[forward_graph.Outputs[i].Id] = output; | |||||
var remapped_captures = new Tensors(); | |||||
foreach (var capture in backward.CapturedInputs) | |||||
{ | |||||
if (capture_mapping.ContainsKey(capture.Id)) | |||||
remapped_captures.Add(capture_mapping[capture.Id]); | |||||
} | |||||
var skip_positions = new List<int>(); | |||||
foreach (var (output_index, output) in enumerate(outputs)) | |||||
if(backward.Outputs is null || backward.Outputs.Length == 0) | |||||
{ | { | ||||
if (!gradients_util.IsTrainable(output)) | |||||
skip_positions.Add(output_index); | |||||
return backward.FlatStructuredOutputs; | |||||
} | } | ||||
_backward_function_wrapper = (args, unneeded_gradients) => | |||||
var processed_args = new Tensors(); | |||||
int input_index = 0; | |||||
foreach (var (output_index, arg) in enumerate(args)) | |||||
{ | { | ||||
var processed_args = new Tensors(); | |||||
var input_index = 0; | |||||
foreach (var (output_index, arg) in enumerate(args)) | |||||
if (skip_positions.Contains(output_index)) | |||||
continue; | |||||
if (arg is null) | |||||
{ | |||||
var input_placeholder = backward.Inputs[input_index]; | |||||
Tensor variant_arg; | |||||
if (input_placeholder.dtype == dtypes.variant) | |||||
{ | |||||
variant_arg = variant_zeros_like[output_index]; | |||||
} | |||||
else | |||||
{ | |||||
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); | |||||
variant_arg = array_ops.zeros(shape, type); | |||||
} | |||||
processed_args.Add(variant_arg); | |||||
} | |||||
else | |||||
{ | { | ||||
if (skip_positions.Contains(output_index)) | |||||
continue; | |||||
if (arg == null) | |||||
throw new NotImplementedException(""); | |||||
processed_args.Add(arg); | processed_args.Add(arg); | ||||
input_index += 1; | |||||
if (input_index >= backward_function_inputs) | |||||
break; | |||||
} | } | ||||
input_index++; | |||||
if (input_index >= backward_function_inputs) | |||||
break; | |||||
} | |||||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||||
var gradients = backward.CallFlat(processed_args, remapped_captures); | |||||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||||
var gradients = backward.CallFlat(processed_args, remapped_captures); | |||||
foreach (var unneeded_gradient_index in unneeded_gradients) | |||||
{ | |||||
var index = Convert.ToInt32(unneeded_gradient_index); | |||||
if (gradients.Length <= index) | |||||
gradients.Insert(index, null); | |||||
} | |||||
foreach (var unneeded_gradient_index in unneeded_gradients) | |||||
{ | |||||
var index = Convert.ToInt32(unneeded_gradient_index); | |||||
if (gradients.Length <= index) | |||||
gradients.Insert(index, null); | |||||
} | |||||
return gradients; | |||||
}; | |||||
} | |||||
return gradients; | |||||
}; | |||||
return (_backward_function_wrapper, recorded_outputs); | return (_backward_function_wrapper, recorded_outputs); | ||||
} | } | ||||
@@ -132,51 +186,66 @@ namespace Tensorflow.Functions | |||||
var trainable_indices = new List<int>(); | var trainable_indices = new List<int>(); | ||||
foreach(var (index, output) in enumerate(outputs)) | foreach(var (index, output) in enumerate(outputs)) | ||||
{ | { | ||||
if (gradients_util.IsTrainable(output)) | |||||
if (backprop_util.IsTrainable(output)) | |||||
{ | { | ||||
trainable_outputs.Add(output); | trainable_outputs.Add(output); | ||||
trainable_indices.Add(index); | trainable_indices.Add(index); | ||||
} | } | ||||
} | } | ||||
var gradients_wrt_outputs = new List<Tensor>(); | |||||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||||
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||||
backwards_graph.as_default(); | backwards_graph.as_default(); | ||||
var gradients_wrt_outputs = new List<Tensor>(); | |||||
foreach (var output in trainable_outputs) | foreach (var output in trainable_outputs) | ||||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||||
{ | |||||
var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); | |||||
var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); | |||||
gradients_wrt_outputs.Add(gradient_placeholder); | |||||
handle_data_util.copy_handle_data(output, gradient_placeholder); | |||||
} | |||||
// TODO(Rinne): with ops.device(None) | |||||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | ||||
_func_graph.Inputs, | |||||
grad_ys: gradients_wrt_outputs.ToArray(), | |||||
src_graph: _func_graph); | |||||
_func_graph.Inputs, | |||||
grad_ys: gradients_wrt_outputs.ToArray(), | |||||
src_graph: _func_graph); | |||||
var captures_from_forward = backwards_graph.external_captures | var captures_from_forward = backwards_graph.external_captures | ||||
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | ||||
.ToArray(); | .ToArray(); | ||||
HashSet<Tensor> existing_outputs = new(_func_graph.Outputs); | |||||
foreach(var capture in captures_from_forward) | foreach(var capture in captures_from_forward) | ||||
{ | { | ||||
if (!_func_graph.Outputs.Contains(capture)) | |||||
if (!existing_outputs.Contains(capture)) | |||||
{ | |||||
existing_outputs.Add(capture); | |||||
_func_graph.Outputs.Add(capture); | _func_graph.Outputs.Add(capture); | ||||
} | |||||
} | } | ||||
backwards_graph.Exit(); | backwards_graph.Exit(); | ||||
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||||
var backward_function_attr = new Dictionary<string, string>(); | |||||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||||
gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||||
backwards_graph.Inputs = gradients_wrt_outputs; | |||||
backwards_graph.Outputs = gradients_wrt_inputs; | |||||
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||||
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||||
var (wrapped_forward_function, wrapped_backward_function) = | |||||
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||||
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||||
//var backward_function_attr = new Dictionary<string, string>(); | |||||
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||||
//var backward_function = new ConcreteFunction(backwards_graph, | |||||
// monomorphic_function_utils._parse_func_attrs(backward_function_attr)); | |||||
var forward_function_attr = new Dictionary<string, string>(); | |||||
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||||
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||||
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||||
//var forward_function_attr = new Dictionary<string, string>(); | |||||
//forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||||
//var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||||
// _func_graph.Inputs, _func_graph.Outputs, | |||||
// monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||||
return (forward_function, _func_graph, backward_function, null, 0); | |||||
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); | |||||
} | } | ||||
public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||||
ForwardAndBackwardFunctions(Tensors inference_args) | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -0,0 +1,84 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Security.Cryptography.X509Certificates; | |||||
using System.Text; | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
public class TracingCompiler | |||||
{ | |||||
Func<Tensor[], Tensor[]> _csharp_function; | |||||
//FunctionSpec _function_spec; | |||||
internal string _name; | |||||
bool _autograph; | |||||
Dictionary<string, ConcreteFunction> _function_cache; | |||||
Dictionary<string, AttrValue> _function_attributes; | |||||
int _tracing_count; | |||||
public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null, | |||||
Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null, | |||||
bool reduce_retracing = false, bool capture_by_value = false) | |||||
{ | |||||
_csharp_function = csharp_function; | |||||
bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); | |||||
_name = name; | |||||
_autograph = autograph; | |||||
_function_attributes = attributes ?? new Dictionary<string, AttrValue>(); | |||||
_function_cache = new Dictionary<string, ConcreteFunction>(); | |||||
_tracing_count = 0; | |||||
} | |||||
public Tensor[] Apply(Tensor[] inputs) | |||||
{ | |||||
// TODO(Rinne): add lock here. | |||||
var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); | |||||
return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); | |||||
} | |||||
internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) | |||||
{ | |||||
var (concrete_function, _) = _maybe_define_concrete_function(args); | |||||
return concrete_function; | |||||
} | |||||
private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) | |||||
{ | |||||
return _maybe_define_function(args); | |||||
} | |||||
private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | |||||
{ | |||||
var lookup_func_key = make_cache_key(args); | |||||
if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | |||||
{ | |||||
return (concrete_function, args); | |||||
} | |||||
concrete_function = _create_concrete_function(args); | |||||
_function_cache[lookup_func_key] = concrete_function; | |||||
return (concrete_function, args); | |||||
} | |||||
private ConcreteFunction _create_concrete_function(Tensor[] args) | |||||
{ | |||||
_tracing_count++; | |||||
int arglen = args.Length; | |||||
var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( | |||||
_name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), | |||||
args, new Dictionary<string, object>(), autograph: _autograph | |||||
), _function_attributes); | |||||
return concrete_function; | |||||
} | |||||
private static string make_cache_key(Tensor[] inputs) | |||||
{ | |||||
//string res = ""; | |||||
//foreach (var input in inputs) | |||||
//{ | |||||
// res += $"{input.name}_{input.Id}"; | |||||
//} | |||||
return inputs.Length.ToString(); | |||||
} | |||||
} | |||||
} |
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using Tensorflow.Functions; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -54,6 +55,9 @@ namespace Tensorflow | |||||
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,50 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Framework; | |||||
using Tensorflow.Framework.Models; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
internal static class composite_tensor_utils | |||||
{ | |||||
public static List<object> flatten_with_variables(object inputs) | |||||
{ | |||||
List<object> flat_inputs = new(); | |||||
foreach(var value in nest.flatten(inputs)) | |||||
{ | |||||
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||||
{ | |||||
throw new NotImplementedException("The composite tensor has not been fully supported."); | |||||
} | |||||
else | |||||
{ | |||||
flat_inputs.Add(value); | |||||
} | |||||
} | |||||
return flat_inputs; | |||||
} | |||||
public static List<object> flatten_with_variables_or_variable_specs(object arg) | |||||
{ | |||||
List<object> flat_inputs = new(); | |||||
foreach(var value in nest.flatten(arg)) | |||||
{ | |||||
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||||
{ | |||||
throw new NotImplementedException("The composite tensor has not been fully supported."); | |||||
} | |||||
// TODO(Rinne): deal with `VariableSpec`. | |||||
else if(value is TypeSpec type_spec && value is not TensorSpec) | |||||
{ | |||||
throw new NotImplementedException("The TypeSpec has not been fully supported."); | |||||
} | |||||
else | |||||
{ | |||||
flat_inputs.Add(value); | |||||
} | |||||
} | |||||
return flat_inputs; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,94 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Variables; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
public static class function_saved_model_utils | |||||
{ | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="concrete_function"></param> | |||||
/// <param name="inputs">a list tensors or other objects (such as variables) which | |||||
/// contain tensors that were originally captured by the function</param> | |||||
public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||||
{ | |||||
var bound_inputs = inputs?.Select(obj => | |||||
{ | |||||
if(obj is Tensor tensor) | |||||
{ | |||||
return get_tensor_from_node(tensor); | |||||
} | |||||
else if(obj is IVariableV1 variable) | |||||
{ | |||||
return get_tensor_from_node(variable); | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("Encountered an type error, please submit an issue to " + | |||||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
} | |||||
}); | |||||
var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); | |||||
List<Tensor> captured_inputs_list = new(); | |||||
concrete_function.set_variables(bound_variables); | |||||
if (bound_inputs is not null) | |||||
{ | |||||
foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||||
{ | |||||
if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
captured_inputs_list.Add(bound_input); | |||||
concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||||
if(internal_capture.dtype == dtypes.resource) | |||||
{ | |||||
if (resource_variable_ops.is_resource_variable(bound_input)) | |||||
{ | |||||
handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); | |||||
} | |||||
else | |||||
{ | |||||
handle_data_util.copy_handle_data(bound_input, internal_capture); | |||||
} | |||||
} | |||||
concrete_function.func_graph.capture(bound_input); | |||||
} | |||||
} | |||||
} | |||||
if(captured_inputs_list.Any(inp => inp is null)) | |||||
{ | |||||
// TODO(Rinne): add warnings. | |||||
} | |||||
concrete_function.SetExternalCaptures(captured_inputs_list); | |||||
} | |||||
public static Tensor get_tensor_from_node(Tensor node) | |||||
{ | |||||
return node; | |||||
} | |||||
public static Tensor get_tensor_from_node(IVariableV1 node) | |||||
{ | |||||
if (resource_variable_ops.is_resource_variable(node)) | |||||
{ | |||||
return node.Handle; | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("Encountered an type error, please submit an issue to " + | |||||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,282 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Framework.Models; | |||||
using Tensorflow.Gradients; | |||||
using Tensorflow.Graphs; | |||||
using Tensorflow.Common.Extensions; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Framework; | |||||
using static Tensorflow.Binding; | |||||
using System.Diagnostics; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
internal static class monomorphic_function_utils | |||||
{ | |||||
internal static string _FORWARD_PREFIX = "__forward_"; | |||||
internal static string _BACKWARD_PREFIX = "__backward_"; | |||||
internal static string _INFERENCE_PREFIX = "__inference_"; | |||||
internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; | |||||
internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
public static string _inference_name(string name) | |||||
{ | |||||
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||||
} | |||||
public static string _forward_name(string name) | |||||
{ | |||||
return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; | |||||
} | |||||
public static string _backward_name(string name) | |||||
{ | |||||
return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; | |||||
} | |||||
public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs, | |||||
FuncGraph forward_graph, FuncGraph backwards_graph) | |||||
{ | |||||
string forward_function_name = _forward_name(forward_graph.Name); | |||||
Dictionary<string, AttrValue> common_attributes; | |||||
if(attrs is null) | |||||
{ | |||||
common_attributes = new Dictionary<string, AttrValue>(); | |||||
} | |||||
else | |||||
{ | |||||
common_attributes = new Dictionary<string, AttrValue>(attrs); | |||||
} | |||||
if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) | |||||
{ | |||||
common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); | |||||
} | |||||
var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||||
{ | |||||
{FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } | |||||
}); | |||||
backward_function_attr.Update(common_attributes); | |||||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||||
var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||||
{ | |||||
{BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } | |||||
}); | |||||
forward_function_attr.Update(common_attributes); | |||||
var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, | |||||
forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); | |||||
return (forward_function, backward_function); | |||||
} | |||||
public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes) | |||||
{ | |||||
Dictionary<string, AttrValue> attrs = new(); | |||||
foreach(var item in attributes) | |||||
{ | |||||
var key = item.Key; | |||||
var value = item.Value; | |||||
if (value is AttrValue attr_value) | |||||
{ | |||||
attrs[key] = attr_value; | |||||
} | |||||
else if (value is bool b) | |||||
{ | |||||
attrs[key] = new AttrValue() { B = b }; | |||||
} | |||||
else if (value is int i) | |||||
{ | |||||
attrs[key] = new AttrValue() { I = i }; | |||||
} | |||||
else if (value is float f) | |||||
{ | |||||
attrs[key] = new AttrValue() { F = f }; | |||||
} | |||||
else if(value is string s) | |||||
{ | |||||
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; | |||||
} | |||||
else if (value is byte[] bytes) | |||||
{ | |||||
attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + | |||||
$"AttrValue. Got {value.GetType()}."); | |||||
} | |||||
} | |||||
return attrs; | |||||
} | |||||
public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes) | |||||
{ | |||||
Dictionary<string, AttrValue> attrs = new(); | |||||
foreach (var item in attributes) | |||||
{ | |||||
var key = item.Key; | |||||
var value = item.Value; | |||||
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; | |||||
} | |||||
return attrs; | |||||
} | |||||
} | |||||
public class DelayedRewriteGradientFunctions : TapeGradientFunctions | |||||
{ | |||||
EagerDefinedFunction _inference_function; | |||||
Dictionary<string, AttrValue> _attrs; | |||||
int _num_inference_outputs; | |||||
Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new(); | |||||
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs) | |||||
: base(func_graph, false) | |||||
{ | |||||
_func_graph = func_graph; | |||||
_inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), | |||||
_func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||||
_attrs = attrs; | |||||
_num_inference_outputs = _func_graph.Outputs.Length; | |||||
} | |||||
public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||||
{ | |||||
if (input_tangents is not null) | |||||
{ | |||||
throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||||
$"a class that does not support forwardprop."); | |||||
} | |||||
return _inference_function; | |||||
} | |||||
public override void Record(Tensors flat_outputs, Tensors inference_args) | |||||
{ | |||||
var (backward_function, to_record) = _backward(flat_outputs); | |||||
foreach(var tape in tf.GetTapeSet()) | |||||
{ | |||||
tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||||
inference_args, backward_function); | |||||
} | |||||
} | |||||
public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) | |||||
{ | |||||
if(num_doutputs == -2) | |||||
{ | |||||
num_doutputs = _num_inference_outputs; | |||||
} | |||||
if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) | |||||
{ | |||||
return target; | |||||
} | |||||
var (forward, backward) = _construct_forward_backward(num_doutputs); | |||||
_cached_function_pairs[num_doutputs] = (forward, backward); | |||||
return (forward, backward); | |||||
} | |||||
private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||||
{ | |||||
Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) | |||||
{ | |||||
var call_op = outputs[0].op; | |||||
return _rewrite_forward_and_call_backward(call_op, args); | |||||
} | |||||
return (backward_function, outputs); | |||||
} | |||||
internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) | |||||
{ | |||||
var (forward_function, backward_function) = forward_backward(doutputs.Length); | |||||
if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) | |||||
{ | |||||
return backward_function.FlatStructuredOutputs; | |||||
} | |||||
forward_function.AddToGraph(op.graph); | |||||
op._set_func_attr("f", forward_function.Name); | |||||
op._set_type_list_attr("Tout", forward_function.OutputTypes); | |||||
op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). | |||||
Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() | |||||
); | |||||
for(int i = 0; i < op.outputs.Length; i++) | |||||
{ | |||||
var func_graph_output = forward_function._func_graph_outputs[i]; | |||||
handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); | |||||
} | |||||
var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). | |||||
ToDictionary(x => x.Item1, x => x.Item2); | |||||
var remapped_captures = backward_function.CapturedInputs.Select( | |||||
x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) | |||||
); | |||||
List<Tensor> cleaned_doutputs = new(); | |||||
foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) | |||||
{ | |||||
if (backprop_util.IsTrainable(placeholder)) | |||||
{ | |||||
if(doutput is IndexedSlices) | |||||
{ | |||||
cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); | |||||
} | |||||
else if(doutput is null) | |||||
{ | |||||
cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); | |||||
} | |||||
else if(doutput is Tensor tensor) | |||||
{ | |||||
cleaned_doutputs.Add(tensor); | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); | |||||
} | |||||
} | |||||
} | |||||
return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); | |||||
} | |||||
private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) | |||||
{ | |||||
var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); | |||||
List<TensorSpec> signature = new(); | |||||
foreach(var t in trainable_outputs) | |||||
{ | |||||
var (shape, dtype) = default_gradient.shape_and_dtype(t); | |||||
signature.Add(new TensorSpec(shape, dtype)); | |||||
} | |||||
Tensor[] _backprop_function(Tensor[] grad_ys) | |||||
{ | |||||
return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, | |||||
grad_ys, src_graph: _func_graph); | |||||
} | |||||
_func_graph.as_default(); | |||||
FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||||
FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => | |||||
{ | |||||
Debug.Assert(y is Tensor); | |||||
return (Tensor)y; | |||||
}).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph); | |||||
var backwards_graph_captures = backwards_graph.external_captures; | |||||
var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); | |||||
HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs); | |||||
foreach(var capture in captures_from_forward) | |||||
{ | |||||
if (!existing_outputs.Contains(capture)) | |||||
{ | |||||
existing_outputs.Add(capture); | |||||
_func_graph.Outputs.Add(capture); | |||||
} | |||||
} | |||||
var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( | |||||
_attrs, _func_graph, backwards_graph); | |||||
_func_graph.Exit(); | |||||
return (forward_function, backward_function); | |||||
} | |||||
} | |||||
} |
@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients | |||||
/// Map from tensor to how many references still exist for this tensor in | /// Map from tensor to how many references still exist for this tensor in | ||||
/// the tape. | /// the tape. | ||||
/// </summary> | /// </summary> | ||||
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; } | |||||
public UnorderedMap<long, long> tensor_usage_counts { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Maps from op ID to how many output tensors of this op still need to have | /// Maps from op ID to how many output tensors of this op still need to have | ||||
/// their gradients computed. | /// their gradients computed. | ||||
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients | |||||
public BackpropInitialState() | public BackpropInitialState() | ||||
{ | { | ||||
op_tape = new OpTape(); | op_tape = new OpTape(); | ||||
tensor_usage_counts = new UnorderedMap<Tensor, long>(); | |||||
tensor_usage_counts = new UnorderedMap<long, long>(); | |||||
op_missing_tensor = new UnorderedMap<long, long>(); | op_missing_tensor = new UnorderedMap<long, long>(); | ||||
} | } | ||||
} | } | ||||
@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients | |||||
/// <param name="target"></param> | /// <param name="target"></param> | ||||
/// <param name="source"></param> | /// <param name="source"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor gradient(Tensor target, Tensor source) | |||||
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null, | |||||
string unconnected_gradients = null) | |||||
{ | { | ||||
if(_tape is null) | |||||
{ | |||||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||||
"compute one set of gradients (or jacobians)."); | |||||
} | |||||
ITape tape = stop_recording(); | ITape tape = stop_recording(); | ||||
var results = tf.Runner.TFE_TapeGradient(tape, | var results = tf.Runner.TFE_TapeGradient(tape, | ||||
new[] { target }, | new[] { target }, | ||||
new[] { source }, | new[] { source }, | ||||
null); | |||||
output_gradients, | |||||
new[] { source }, | |||||
unconnected_gradients); | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
public Tensor gradient(Tensor target, ResourceVariable source) | |||||
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null, | |||||
string unconnected_gradients = null) | |||||
{ | { | ||||
var results = gradient(target, new List<IVariableV1> { source }); | |||||
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients); | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null, | |||||
string unconnected_gradients = null) | |||||
{ | { | ||||
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }); | |||||
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); | |||||
return (results[0], results[1]); | return (results[0], results[1]); | ||||
} | } | ||||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null, | |||||
string unconnected_gradients = null) | |||||
{ | { | ||||
if (_tape is null) | |||||
{ | |||||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||||
"compute one set of gradients (or jacobians)."); | |||||
} | |||||
var tape = stop_recording(); | var tape = stop_recording(); | ||||
var results = tf.Runner.TFE_TapeGradient(tape, | var results = tf.Runner.TFE_TapeGradient(tape, | ||||
new[] { target }, | new[] { target }, | ||||
sources.Select(x => x.Handle).ToArray(), | sources.Select(x => x.Handle).ToArray(), | ||||
null); | |||||
output_gradients, | |||||
sources.Select(x => x.Handle).ToArray(), | |||||
unconnected_gradients); | |||||
if (!tape.Persistent) | if (!tape.Persistent) | ||||
{ | { | ||||
@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients | |||||
public interface ITape | public interface ITape | ||||
{ | { | ||||
void SetTapeId(int id); | void SetTapeId(int id); | ||||
bool ShouldRecord(Tensor[] tensors); | |||||
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); | |||||
void StartRecord(); | void StartRecord(); | ||||
void StopRecord(); | void StopRecord(); | ||||
bool Persistent { get; } | bool Persistent { get; } | ||||
void RecordOperation(string op_type, | void RecordOperation(string op_type, | ||||
Tensor[] input_tensors, | |||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_tensor_id, | |||||
TF_DataType[] input_dtypes, | |||||
BackwardFunction backward_function); | BackwardFunction backward_function); | ||||
void VariableAccessed(ResourceVariable variable); | |||||
void RecordOperation(string op_type, | |||||
Tensor[] outputs, | |||||
Tensor[] inputs, | |||||
BackwardFunction backward_function); | |||||
void VariableAccessed(IVariableV1 variable); | |||||
void Watch(Tensor x); | void Watch(Tensor x); | ||||
ResourceVariable[] WatchedVariables(); | |||||
IVariableV1[] WatchedVariables(); | |||||
Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||||
Tensor[] source_tensor_ids, | |||||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||||
Tensor[] output_gradients); | |||||
Tensor[] ComputeGradient(long[] target_tensor_ids, | |||||
long[] source_tensor_ids, | |||||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
List<Tensor> output_gradients, | |||||
bool build_default_zeros_grads); | |||||
} | } | ||||
} | } |
@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients | |||||
{ | { | ||||
public string op_type { get; set; } | public string op_type { get; set; } | ||||
public TapeTensor[] output_tensor_info { get; set; } | public TapeTensor[] output_tensor_info { get; set; } | ||||
public Tensor[] input_tensor_id { get; set; } | |||||
public long[] input_tensor_id { get; set; } | |||||
public BackwardFunction backward_function { get; set; } | public BackwardFunction backward_function { get; set; } | ||||
public override string ToString() | public override string ToString() | ||||
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}"; | |||||
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; | |||||
} | } | ||||
} | } |
@@ -2,235 +2,246 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
{ | { | ||||
public partial class Tape | public partial class Tape | ||||
{ | { | ||||
// int kMinAggregateCount = 4; | |||||
// int kMinAggregateBytes = 128 * 1024 * 1024; | |||||
static readonly int kMinAggregateCount = 4; | |||||
static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||||
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||||
public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||||
Tensor[] source_tensor_ids, | |||||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||||
Tensor[] output_gradients) | |||||
static Tape() | |||||
{ | { | ||||
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids); | |||||
// var gradients_size = new UnorderedMap<Tensor, long>(); | |||||
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); | |||||
var state = PrepareBackprop( | |||||
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); | |||||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||||
output_gradients, | |||||
tensor_tape_, | |||||
state.op_tape); | |||||
_functionsAcceptingNoneForIndicesMap = new(); | |||||
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||||
} | |||||
while (!op_stack.empty()) | |||||
public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||||
long[] source_tensor_ids, | |||||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
List<Tensor> output_gradients, | |||||
bool build_default_zeros_grads) | |||||
{ | |||||
UnorderedSet<long> sources_set = new(source_tensor_ids); | |||||
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||||
UnorderedMap<long, long> gradients_size = new(); | |||||
while(op_stack.Count > 0) | |||||
{ | { | ||||
var op = op_stack.Dequeue(); | |||||
if (!state.op_tape.find(op, out var trace)) | |||||
long op = op_stack.Dequeue(); | |||||
if(!state.op_tape.TryGetValue(op, out var op_it)) | |||||
{ | |||||
continue; | continue; | ||||
// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||||
} | |||||
var trace = op_it; | |||||
state.op_tape.erase(op); | state.op_tape.erase(op); | ||||
var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||||
var unneeded_gradients = new List<long>(); | |||||
for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||||
List<Tensor> out_gradients = new(); | |||||
List<long> unneeded_gradients = new(); | |||||
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||||
{ | { | ||||
var in_tensor_id = trace.input_tensor_id[i]; | |||||
if (!tensor_tape_.find(in_tensor_id) && | |||||
!sources_set.find(in_tensor_id)) | |||||
long in_tensor_id = trace.input_tensor_id[i]; | |||||
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||||
{ | |||||
unneeded_gradients.Add(i); | unneeded_gradients.Add(i); | ||||
} | |||||
} | } | ||||
bool any_gradient_nonzero = false; | bool any_gradient_nonzero = false; | ||||
var zero_indices = new List<int>(); | |||||
for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||||
List<int> zero_indices = new(); | |||||
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||||
{ | { | ||||
var id = trace.output_tensor_info[i].GetTensor(); | |||||
if (!gradients.find(id, out var grad_it)) | |||||
long id = trace.output_tensor_info[i].GetID(); | |||||
if(!gradients.TryGetValue(id, out var grad_it)) | |||||
{ | { | ||||
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && | |||||
func_name_it.find(i)) | |||||
out_gradients.Add(null); | |||||
if (build_default_zeros_grads) | |||||
{ | { | ||||
out_gradients.Add(null); | |||||
} | |||||
else | |||||
{ | |||||
out_gradients.Add(null); | |||||
zero_indices.Add(i); | |||||
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||||
!func_name_it.find(i)) | |||||
{ | |||||
zero_indices.Add(i); | |||||
} | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
any_gradient_nonzero = true; | any_gradient_nonzero = true; | ||||
var new_gradients = grad_it.Count == 1 ? | |||||
grad_it[0] : | |||||
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||||
Tensor new_gradients; | |||||
if (grad_it.Count == 1) | |||||
{ | |||||
new_gradients = grad_it[0]; | |||||
} | |||||
else | |||||
{ | |||||
new_gradients = AggregateGradients(grad_it); | |||||
} | |||||
if (!sources_set.find(id)) | if (!sources_set.find(id)) | ||||
{ | |||||
gradients.Remove(id); | gradients.Remove(id); | ||||
} | |||||
else | else | ||||
{ | { | ||||
// grad_it.Clear(); | |||||
// grad_it.Add(new_gradients); | |||||
// vspace.MarkAsResult(new_gradients); | |||||
grad_it.Clear(); | |||||
grad_it.Add(new_gradients); | |||||
// MarkAsResult | |||||
} | } | ||||
out_gradients.Add(new_gradients); | out_gradients.Add(new_gradients); | ||||
} | } | ||||
} | } | ||||
Tensor[] in_gradients; | |||||
Tensor[] in_gradients = new Tensor[0]; | |||||
if (any_gradient_nonzero) | if (any_gradient_nonzero) | ||||
{ | { | ||||
// foreach (var i in zero_indices) | |||||
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||||
in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); | |||||
if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) | |||||
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||||
if (!_persistent) | |||||
foreach(var i in zero_indices) | |||||
{ | { | ||||
// trace.backward_function_deleter(trace.backward_function); | |||||
trace.backward_function = null; | |||||
out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||||
} | } | ||||
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||||
out_gradients.Clear(); | |||||
} | } | ||||
bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; | |||||
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) | |||||
for(int i = 0, end = in_gradients.Length; i < end; i++) | |||||
{ | { | ||||
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; | |||||
var id = trace.input_tensor_id[k]; | |||||
if (in_gradients[i] != null) | |||||
long id = trace.input_tensor_id[i]; | |||||
if (in_gradients[i] is not null) | |||||
{ | { | ||||
var unaggregated_grads = gradients[id]; | |||||
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||||
unaggregated_grads.Add(in_gradients[i]); | unaggregated_grads.Add(in_gradients[i]); | ||||
/*if (unaggregated_grads.Count > kMinAggregateCount) | |||||
if(unaggregated_grads.Count > kMinAggregateCount) | |||||
{ | { | ||||
if (!gradients_size.find(id, out var size)) | |||||
if(!gradients_size.TryGetValue(id, out var size)) | |||||
{ | { | ||||
size = (long)unaggregated_grads[0].size; | |||||
size = NumElements(unaggregated_grads[0]); | |||||
gradients_size.emplace(id, size); | gradients_size.emplace(id, size); | ||||
} | } | ||||
if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||||
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
Tensor grad = AggregateGradients(unaggregated_grads); | |||||
unaggregated_grads.Clear(); | |||||
unaggregated_grads.Add(grad); | |||||
} | } | ||||
}*/ | |||||
} | |||||
} | } | ||||
if (!state.tensor_usage_counts.find(id)) | |||||
if(!state.tensor_usage_counts.find(id)) | |||||
{ | |||||
continue; | continue; | ||||
} | |||||
state.tensor_usage_counts[id]--; | state.tensor_usage_counts[id]--; | ||||
if (state.tensor_usage_counts[id] > 0) | |||||
if(state.tensor_usage_counts[id] > 0) | |||||
{ | |||||
continue; | continue; | ||||
if (!tensor_tape_.find(id, out var tape_it)) | |||||
} | |||||
if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||||
{ | { | ||||
if (gradients.find(id, out var grad_it)) | |||||
if (gradients.find(id)) | |||||
{ | { | ||||
// foreach (var g in grad_it) | |||||
// DeleteGradient(g); | |||||
gradients.erase(id); | gradients.erase(id); | ||||
} | } | ||||
continue; | continue; | ||||
} | } | ||||
var op_id = tape_it; | |||||
if (op_id == -1) | |||||
long op_id = tape_it; | |||||
if(op_id == -1) | |||||
{ | |||||
continue; | continue; | ||||
if (state.op_missing_tensor.find(op_id, out var missing_it)) | |||||
} | |||||
if(state.op_missing_tensor.find(op_id)) | |||||
{ | { | ||||
state.op_missing_tensor[op_id]--; | state.op_missing_tensor[op_id]--; | ||||
if (state.op_missing_tensor[op_id] == 0) | |||||
if(state.op_missing_tensor[op_id] == 0) | |||||
{ | |||||
op_stack.Enqueue(op_id); | op_stack.Enqueue(op_id); | ||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (state.op_tape.Count > 0) | |||||
if(state.op_tape.Count > 0) | |||||
{ | |||||
throw new RuntimeError("Invalid tape state."); | throw new RuntimeError("Invalid tape state."); | ||||
var result = new Tensor[source_tensor_ids.Length]; | |||||
var j = 0; | |||||
foreach (var id in source_tensor_ids) | |||||
} | |||||
Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||||
for(int i = 0; i < source_tensor_ids.Length; i++) | |||||
{ | { | ||||
if (gradients.find(id, out var grad_it)) | |||||
long tensor_id = source_tensor_ids[i]; | |||||
if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||||
{ | { | ||||
if (grad_it.Count > 1) | |||||
result[j] = gen_math_ops.add_n(grad_it.ToArray()); | |||||
else | |||||
result[j] = grad_it[0]; | |||||
result[i] = null; | |||||
} | |||||
else | |||||
{ | |||||
if(grad_it.Count > 1) | |||||
{ | |||||
Tensor grad = AggregateGradients(grad_it); | |||||
grad_it.Clear(); | |||||
grad_it.Add(grad); | |||||
} | |||||
result[i] = grad_it[0]; | |||||
} | } | ||||
j++; | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | ||||
{ | { | ||||
var m = new UnorderedMap<string, UnorderedSet<int>>(); | |||||
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||||
return m; | |||||
return _functionsAcceptingNoneForIndicesMap; | |||||
} | } | ||||
UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids, | |||||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||||
Tensor[] output_gradients, | |||||
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
List<Tensor> output_gradients, | |||||
TensorTape tensor_tape, | TensorTape tensor_tape, | ||||
OpTape op_tape) | OpTape op_tape) | ||||
{ | { | ||||
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>(); | |||||
for (int i = 0; i < target_tensor_ids.Length; ++i) | |||||
var result = new UnorderedMap<long, List<Tensor>>(); | |||||
for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||||
{ | { | ||||
var id = target_tensor_ids[i]; | |||||
if (output_gradients.Length == 0 || output_gradients[i] == null) | |||||
long id = target_tensor_ids[i]; | |||||
if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) | |||||
{ | { | ||||
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) | |||||
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||||
{ | { | ||||
if (!op_tape.find(tensor_tape[id], out var op_it)) | |||||
if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||||
{ | |||||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | throw new RuntimeError("Internal state of the gradient tape is invalid: " + | ||||
"failed to find operation producing a tensor"); | |||||
"failed to find operation producing a tensor."); | |||||
} | |||||
bool found = false; | bool found = false; | ||||
for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||||
for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||||
{ | { | ||||
if (op_it.output_tensor_info[j].GetTensor() == id) | |||||
if (op_it.output_tensor_info[j].GetID() == id) | |||||
{ | { | ||||
found = true; | found = true; | ||||
var ones = op_it.output_tensor_info[j].OnesLike(); | |||||
result[id].Add(ones); | |||||
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
if (!found) | if (!found) | ||||
{ | { | ||||
throw new ValueError("Internal state of the gradient tape is invalid: " + | |||||
"none of operations outputs match expected tensor"); | |||||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||||
"none of operations outputs match expected tensor."); | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
if (sources_that_are_targets.find(id, out var source_tensor)) | |||||
result[id].Add(source_tensor.OnesLike()); | |||||
if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||||
{ | |||||
Tensor ones_like = BuildOnesLike(source_tensor); | |||||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||||
} | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
result[id].Add(output_gradients[i]); | |||||
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||||
} | } | ||||
} | } | ||||
@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
Tensor BuildOnesLike(TapeTensor t) | |||||
{ | |||||
return t.OnesLike(); | |||||
} | |||||
Tensor AggregateGradients(List<Tensor> gradient_tensors) | |||||
{ | |||||
if(gradient_tensors.Count == 0) | |||||
{ | |||||
return gradient_tensors[0]; | |||||
} | |||||
return tf.add_n(gradient_tensors.ToArray()); | |||||
} | |||||
void DeleteGradient(Tensor gradient) | |||||
{ | |||||
// Do not do anything here. Because GC will collect it when it has no reference. | |||||
} | |||||
long NumElements(Tensor tensor) => 1; | |||||
} | } | ||||
} | } |
@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients | |||||
{ | { | ||||
public partial class Tape | public partial class Tape | ||||
{ | { | ||||
public BackpropInitialState PrepareBackprop(Tensor[] target, | |||||
public BackpropInitialState PrepareBackprop(long[] target, | |||||
TensorTape tensor_tape, | TensorTape tensor_tape, | ||||
OpTape op_tape, | OpTape op_tape, | ||||
UnorderedSet<Tensor> sources_set, | |||||
UnorderedSet<long> sources_set, | |||||
bool persistent_tape) | bool persistent_tape) | ||||
{ | { | ||||
Stack<long> tensor_stack = new Stack<long>(); | |||||
foreach(var t in target) | |||||
{ | |||||
tensor_stack.Push(t); | |||||
} | |||||
BackpropInitialState result = new BackpropInitialState(); | BackpropInitialState result = new BackpropInitialState(); | ||||
var tensor_stack = new Queue<Tensor>(target); | |||||
while (tensor_stack.Count > 0) | |||||
while(tensor_stack.Count > 0) | |||||
{ | { | ||||
var tensor_id = tensor_stack.Dequeue(); | |||||
if (!tensor_tape.find(tensor_id, out var op_id)) | |||||
long tensor_id = tensor_stack.Pop(); | |||||
if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) | |||||
{ | |||||
continue; | continue; | ||||
if (op_id == -1 || | |||||
!op_tape.find(op_id, out var op_it) || | |||||
result.op_tape.find(op_id, out var result_op_it)) | |||||
} | |||||
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) | |||||
|| result.op_tape.find(op_id)) | |||||
{ | |||||
continue; | continue; | ||||
} | |||||
result.op_tape.emplace(op_id, op_it); | result.op_tape.emplace(op_id, op_it); | ||||
foreach (var it in op_it.input_tensor_id) | |||||
foreach(var it in op_it.input_tensor_id) | |||||
{ | { | ||||
if (result.tensor_usage_counts.find(it)) | |||||
if(result.tensor_usage_counts.find(it)) | |||||
{ | |||||
result.tensor_usage_counts[it]++; | result.tensor_usage_counts[it]++; | ||||
} | |||||
else | else | ||||
{ | { | ||||
result.tensor_usage_counts[it] = 1; | result.tensor_usage_counts[it] = 1; | ||||
if (tensor_tape.find(it)) | if (tensor_tape.find(it)) | ||||
tensor_stack.Enqueue(it); | |||||
{ | |||||
tensor_stack.Push(it); | |||||
} | |||||
} | } | ||||
} | } | ||||
if (!persistent_tape) | if (!persistent_tape) | ||||
op_tape.Remove(op_id); | |||||
{ | |||||
op_tape.erase(op_id); | |||||
} | |||||
} | } | ||||
foreach (var pair in result.tensor_usage_counts) | |||||
foreach(var pair in result.tensor_usage_counts) | |||||
{ | { | ||||
if (tensor_tape.find(pair.Key, out var it) && it != -1) | |||||
result.op_missing_tensor[it] += 1; | |||||
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) | |||||
{ | |||||
result.op_missing_tensor[it]++; | |||||
} | |||||
} | } | ||||
if (!persistent_tape) | if (!persistent_tape) | ||||
{ | { | ||||
// Call destructors for all unneeded gradient functions and | |||||
// clear the op_tape. We can clear the tape because ownership of | |||||
// backward functions that will be used for gradient computation | |||||
// has been transferred to `result`. | |||||
/*for (const auto&op_pair : *op_tape) { | |||||
op_pair.second.backward_function_deleter( | |||||
op_pair.second.backward_function); | |||||
}*/ | |||||
op_tape.Clear(); | op_tape.Clear(); | ||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
} | } | ||||
@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients | |||||
public partial class Tape | public partial class Tape | ||||
{ | { | ||||
long next_op_id_ = 0; | long next_op_id_ = 0; | ||||
UnorderedMap<Tensor, long> tensor_usage_; | |||||
UnorderedMap<long, long> tensor_usage_; | |||||
public void RecordOperation(string op_type, | public void RecordOperation(string op_type, | ||||
Tensor[] input_tensors, | |||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_tensor_id, | |||||
TF_DataType[] input_dtypes, | |||||
BackwardFunction backward_function) | BackwardFunction backward_function) | ||||
{ | { | ||||
if (!ShouldRecord(input_tensors)) | |||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||||
return; | return; | ||||
var op_id = next_op_id_++; | |||||
foreach (var i in input_tensors) | |||||
foreach (var i in input_tensor_id) | |||||
{ | |||||
tensor_usage_[i]++; | tensor_usage_[i]++; | ||||
} | |||||
long op_id = next_op_id_++; | |||||
foreach (var o in output_tensors) | foreach (var o in output_tensors) | ||||
{ | { | ||||
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | ||||
tensor_tape_[o.GetTensor()] = op_id; | |||||
tensor_usage_[o.GetTensor()] = 1; | |||||
tensor_tape_[o.GetID()] = op_id; | |||||
tensor_usage_[o.GetID()] = 1; | |||||
} | } | ||||
op_tape_[op_id] = new OpTapeEntry | op_tape_[op_id] = new OpTapeEntry | ||||
{ | { | ||||
op_type = op_type, | op_type = op_type, | ||||
output_tensor_info = output_tensors, | |||||
input_tensor_id = input_tensors, | |||||
output_tensor_info = output_tensors.ToArray(), | |||||
input_tensor_id = input_tensor_id.ToArray(), | |||||
backward_function = backward_function | backward_function = backward_function | ||||
}; | }; | ||||
} | } | ||||
public void RecordOperation(string op_type, | |||||
Tensor[] outputs, | |||||
Tensor[] inputs, | |||||
BackwardFunction backward_function) | |||||
{ | |||||
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients | |||||
_created_eagerly = tf.Context.executing_eagerly(); | _created_eagerly = tf.Context.executing_eagerly(); | ||||
tensor_tape_ = new TensorTape(); | tensor_tape_ = new TensorTape(); | ||||
op_tape_ = new OpTape(); | op_tape_ = new OpTape(); | ||||
tensor_usage_ = new UnorderedMap<Tensor, long>(); | |||||
tensor_usage_ = new UnorderedMap<long, long>(); | |||||
if(_created_eagerly) | if(_created_eagerly) | ||||
tf.Context.start_step(); | tf.Context.start_step(); | ||||
// nesting_id = ++tape_nesting_id_counter; | // nesting_id = ++tape_nesting_id_counter; | ||||
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients | |||||
public void Watch(Tensor x) | public void Watch(Tensor x) | ||||
{ | { | ||||
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | ||||
tensor_tape_.emplace(x, -1); | |||||
tensor_tape_.emplace(x.Id, -1); | |||||
} | } | ||||
public bool ShouldRecord(Tensor[] tensors) | |||||
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) | |||||
{ | { | ||||
var dtypes = tensors.Select(x => x.dtype).ToArray(); | |||||
for (int i = 0; i < tensors.Length; ++i) | |||||
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); | |||||
for (int i = 0; i < tensor_ids.Length; ++i) | |||||
{ | { | ||||
if (tensor_tape_.find(tensors[i])) | |||||
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) | |||||
{ | { | ||||
if (IsDtypeTrainable(dtypes[i])) | |||||
return true; | |||||
return true; | |||||
} | } | ||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
public void VariableAccessed(ResourceVariable variable) | |||||
public void VariableAccessed(IVariableV1 variable) | |||||
{ | { | ||||
Watch(variable.Handle); | Watch(variable.Handle); | ||||
} | } | ||||
public ResourceVariable[] WatchedVariables() | |||||
public IVariableV1[] WatchedVariables() | |||||
{ | { | ||||
return null; | return null; | ||||
} | } | ||||
@@ -1,27 +1,63 @@ | |||||
using static Tensorflow.Binding; | |||||
using OneOf; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
{ | { | ||||
public class TapeTensor | public class TapeTensor | ||||
{ | { | ||||
Tensor tensor; | |||||
long id => tensor.Id; | |||||
TF_DataType dtype => tensor.dtype; | |||||
Shape shape => tensor.shape; | |||||
internal Tensor tensor; | |||||
internal long id; | |||||
internal TF_DataType dtype; | |||||
internal OneOf<Shape, Tensor> shape; | |||||
public TapeTensor(long id, TF_DataType dtype, Shape shape) | |||||
{ | |||||
this.id = id; | |||||
this.dtype = dtype; | |||||
this.shape = shape; | |||||
} | |||||
public TapeTensor(long id, TF_DataType dtype, Tensor shape) | |||||
{ | |||||
this.id = id; | |||||
this.dtype = dtype; | |||||
this.shape = shape; | |||||
} | |||||
public TapeTensor(Tensor tensor) | public TapeTensor(Tensor tensor) | ||||
{ | { | ||||
this.id = tensor.Id; | |||||
this.dtype = tensor.dtype; | |||||
this.shape = tensor.shape; | |||||
this.tensor = tensor; | this.tensor = tensor; | ||||
} | } | ||||
public long GetID() => tensor.Id; | |||||
public Tensor GetTensor() => tensor; | |||||
public long GetID() => id; | |||||
public Tensor ZerosLike() | public Tensor ZerosLike() | ||||
=> tf.zeros(shape: shape, dtype: dtype); | |||||
{ | |||||
if(dtype == dtypes.resource) | |||||
{ | |||||
return null; | |||||
} | |||||
if(shape.Index == 1) | |||||
{ | |||||
return tf.zeros_like(shape.AsT1); | |||||
} | |||||
return tf.zeros(shape.AsT0, dtype); | |||||
} | |||||
public Tensor OnesLike() | public Tensor OnesLike() | ||||
=> tf.ones(shape: shape, dtype: dtype); | |||||
{ | |||||
if (shape.Index == 1) | |||||
{ | |||||
return tf.ones_like(shape.AsT1); | |||||
} | |||||
return tf.ones(shape.AsT0, dtype); | |||||
} | |||||
//public Tensor OnesLike() | |||||
// => tf.ones(shape: shape, dtype: dtype); | |||||
public override string ToString() | public override string ToString() | ||||
=> $"{id}, {shape}, {dtype.as_numpy_name()}"; | => $"{id}, {shape}, {dtype.as_numpy_name()}"; | ||||
@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients | |||||
/// produced this tensor. A value of -1 means that the tensor was directly | /// produced this tensor. A value of -1 means that the tensor was directly | ||||
/// watched and not the result of any operation in the tape. | /// watched and not the result of any operation in the tape. | ||||
/// </summary> | /// </summary> | ||||
public class TensorTape : UnorderedMap<Tensor, long> | |||||
public class TensorTape : UnorderedMap<long, long> | |||||
{ | { | ||||
} | } | ||||
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Gradients | |||||
{ | |||||
public class custom_gradient | |||||
{ | |||||
public static string generate_name() | |||||
{ | |||||
return $"CustomGradient-{ops.uid()}"; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,52 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Gradients | |||||
{ | |||||
internal static class default_gradient | |||||
{ | |||||
public static (Shape, TF_DataType) shape_and_dtype(Tensor t) | |||||
{ | |||||
if(t.dtype == dtypes.resource) | |||||
{ | |||||
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||||
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||||
{ | |||||
throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||||
$"of a variable without handle data:\n{t}"); | |||||
} | |||||
return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); | |||||
} | |||||
return (t.shape, t.dtype); | |||||
} | |||||
public static Tensor zeros_like(Tensor t) | |||||
{ | |||||
if(t.dtype == dtypes.resource) | |||||
{ | |||||
var (shape, dtype) = shape_and_dtype(t); | |||||
return array_ops.zeros(shape, dtype); | |||||
} | |||||
else | |||||
{ | |||||
return array_ops.zeros_like(t); | |||||
} | |||||
} | |||||
public static TF_DataType get_zeros_dtype(Tensor t) | |||||
{ | |||||
if(t.dtype == dtypes.resource) | |||||
{ | |||||
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||||
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||||
{ | |||||
throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||||
$"of a variable without handle data:\n{t}"); | |||||
} | |||||
return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); | |||||
} | |||||
return t.dtype; | |||||
} | |||||
} | |||||
} |
@@ -14,10 +14,15 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Gradients; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.Operations; | |||||
using Tensorflow.Operations.ControlFlows; | using Tensorflow.Operations.ControlFlows; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -25,6 +30,11 @@ namespace Tensorflow | |||||
{ | { | ||||
public class gradients_util | public class gradients_util | ||||
{ | { | ||||
// Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are | |||||
// unfortunately too slow to use here. | |||||
public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; | |||||
public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; | |||||
public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; | |||||
public static Tensor[] _GradientsHelper(Tensor[] ys, | public static Tensor[] _GradientsHelper(Tensor[] ys, | ||||
Tensor[] xs, | Tensor[] xs, | ||||
Tensor[] grad_ys = null, | Tensor[] grad_ys = null, | ||||
@@ -143,7 +153,7 @@ namespace Tensorflow | |||||
Tensor[] in_grads = null; | Tensor[] in_grads = null; | ||||
Func<Operation, Tensor[], Tensor[]> grad_fn = null; | Func<Operation, Tensor[], Tensor[]> grad_fn = null; | ||||
var is_partitioned_call = _IsPartitionedCall(op); | var is_partitioned_call = _IsPartitionedCall(op); | ||||
var is_func_call = false; | |||||
var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; | |||||
var has_out_grads = out_grads.Exists(x => x != null); | var has_out_grads = out_grads.Exists(x => x != null); | ||||
if (has_out_grads && !stop_ops.Contains(op)) | if (has_out_grads && !stop_ops.Contains(op)) | ||||
{ | { | ||||
@@ -157,14 +167,41 @@ namespace Tensorflow | |||||
{ | { | ||||
if (is_func_call) | if (is_func_call) | ||||
{ | { | ||||
EagerDefinedFunction func_call = null; | |||||
if (is_partitioned_call) | if (is_partitioned_call) | ||||
{ | { | ||||
var func_attr = op.get_attr("f"); | |||||
Debug.Assert(func_attr is NameAttrList); | |||||
var func_name = ((NameAttrList)func_attr).Name; | |||||
func_call = src_graph._get_function(func_name); | |||||
if(func_call is null && src_graph.OuterGraph is not null) | |||||
{ | |||||
var graph = src_graph.OuterGraph; | |||||
while(graph is not null) | |||||
{ | |||||
func_call = graph._get_function(func_name); | |||||
if(func_call is not null) | |||||
{ | |||||
break; | |||||
} | |||||
if(graph.OuterGraph is not null) | |||||
{ | |||||
graph = graph.OuterGraph; | |||||
} | |||||
else | |||||
{ | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
func_call = src_graph._get_function(op.type); | |||||
} | } | ||||
// skip the following codes: | |||||
// `func_call = getattr(op, "__defun", func_call)` | |||||
grad_fn = func_call.csharp_grad_func; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -208,6 +245,8 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), | |||||
null, (x, y) => _SymGrad(x, y)); | |||||
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); | throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); | ||||
} | } | ||||
_VerifyGeneratedGradients(in_grads, op); | _VerifyGeneratedGradients(in_grads, op); | ||||
@@ -663,6 +702,11 @@ namespace Tensorflow | |||||
dtypes.resource, dtypes.variant}.Contains(dtype); | dtypes.resource, dtypes.variant}.Contains(dtype); | ||||
} | } | ||||
public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||||
{ | |||||
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Return true if op has real gradient. | /// Return true if op has real gradient. | ||||
/// </summary> | /// </summary> | ||||
@@ -683,7 +727,7 @@ namespace Tensorflow | |||||
private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | ||||
{ | { | ||||
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | |||||
// scope = scope.TrimEnd('/').Replace('/', '_'); | |||||
return grad_fn(op, out_grads); | return grad_fn(op, out_grads); | ||||
} | } | ||||
@@ -696,5 +740,28 @@ namespace Tensorflow | |||||
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | ||||
$"inputs {op.inputs._inputs.Count()}"); | $"inputs {op.inputs._inputs.Count()}"); | ||||
} | } | ||||
private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) | |||||
{ | |||||
var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); | |||||
var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); | |||||
NameAttrList f = new(); | |||||
if (_IsPartitionedCall(op)) | |||||
{ | |||||
var func_attr = op.get_attr("f"); | |||||
Debug.Assert(func_attr is NameAttrList); | |||||
f.Name = ((NameAttrList)func_attr).Name; | |||||
} | |||||
else | |||||
{ | |||||
f.Name = op.type; | |||||
} | |||||
foreach(var k in op.node_def.Attr.Keys) | |||||
{ | |||||
f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); | |||||
} | |||||
var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); | |||||
return in_grads; | |||||
} | |||||
} | } | ||||
} | } |
@@ -98,12 +98,23 @@ namespace Tensorflow | |||||
{ | { | ||||
if (op.inputs == null) return null; | if (op.inputs == null) return null; | ||||
RegisterFromAssembly(); | |||||
var gradient_function = op._gradient_function; | |||||
if(gradient_function is null) | |||||
{ | |||||
RegisterFromAssembly(); | |||||
if (!gradientFunctions.ContainsKey(op.type)) | |||||
throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); | |||||
if (!gradientFunctions.ContainsKey(op.type)) | |||||
throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); | |||||
return gradientFunctions[op.type]; | |||||
} | |||||
return gradientFunctions[op.type]; | |||||
Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) | |||||
{ | |||||
return gradient_function(operation, args); | |||||
} | |||||
// TODO(Rinne): check if this needs to be registered. | |||||
return wrapped_gradient_function; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using MethodBoundaryAspect.Fody.Attributes; | using MethodBoundaryAspect.Fody.Attributes; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
@@ -22,7 +23,7 @@ namespace Tensorflow.Graphs | |||||
public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
{ | { | ||||
// TODO: func_name can be cache in FullName + Args | // TODO: func_name can be cache in FullName + Args | ||||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||||
if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
{ | { | ||||
@@ -91,6 +92,7 @@ namespace Tensorflow.Graphs | |||||
// cache function. | // cache function. | ||||
function.ReturnType = args.ReturnValue.GetType(); | function.ReturnType = args.ReturnValue.GetType(); | ||||
function._set_infer_function(); | |||||
functions[func_name] = function; | functions[func_name] = function; | ||||
// run function | // run function | ||||
@@ -1,6 +1,15 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | |||||
using System.Buffers; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Framework.Models; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Graphs; | namespace Tensorflow.Graphs; | ||||
@@ -10,12 +19,66 @@ namespace Tensorflow.Graphs; | |||||
/// </summary> | /// </summary> | ||||
public class FuncGraph : Graph, IDisposable | public class FuncGraph : Graph, IDisposable | ||||
{ | { | ||||
SafeFuncGraphHandle _func_graph_handle; | |||||
internal SafeFuncGraphHandle _func_graph_handle; | |||||
internal HashSet<Tensor> _resource_tensor_inputs; | |||||
internal HashSet<WeakReference<IVariableV1>> _watched_variables; | |||||
internal IEnumerable<WeakReference<IVariableV1>> _weak_variables; | |||||
internal object[] _structured_outputs; | |||||
internal Dictionary<long, string> _output_names; | |||||
public string FuncName => _graph_key; | public string FuncName => _graph_key; | ||||
public Tensors Inputs { get; set; } = new Tensors(); | public Tensors Inputs { get; set; } = new Tensors(); | ||||
public Tensors Outputs { get; set; } = new Tensors(); | public Tensors Outputs { get; set; } = new Tensors(); | ||||
public Dictionary<string, string> Attrs { get; set; } | |||||
public Tensors FlatStructuredOutputs | |||||
{ | |||||
get | |||||
{ | |||||
List<Tensor> res = new(); | |||||
foreach(var obj in _structured_outputs) | |||||
{ | |||||
if(obj is Tensor tensor) | |||||
{ | |||||
res.Add(tensor); | |||||
} | |||||
else if(obj is IEnumerable<Tensor> tensors) | |||||
{ | |||||
res.AddRange(tensors); | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("The structured outputs member should be tensor or tensors."); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
} | |||||
public string Name { get; set; } | |||||
public IEnumerable<IVariableV1> Variables | |||||
{ | |||||
get | |||||
{ | |||||
return _weak_variables.Select(v => | |||||
{ | |||||
if (v.TryGetTarget(out var target)) | |||||
{ | |||||
return target; | |||||
} | |||||
else | |||||
{ | |||||
throw new AssertionError("Called a function referencing variables which have been deleted. " + | |||||
"This likely means that function-local variables were created and " + | |||||
"not referenced elsewhere in the program. This is generally a " + | |||||
"mistake; consider storing variables in an object attribute on first call."); | |||||
} | |||||
}); | |||||
} | |||||
internal set | |||||
{ | |||||
_weak_variables = value.Select(x => new WeakReference<IVariableV1>(x)); | |||||
} | |||||
} | |||||
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||||
public Dictionary<string, AttrValue> Attrs { get; set; } | |||||
Dictionary<long, (Tensor, Tensor)> _captures | Dictionary<long, (Tensor, Tensor)> _captures | ||||
= new Dictionary<long, (Tensor, Tensor)>(); | = new Dictionary<long, (Tensor, Tensor)>(); | ||||
@@ -39,31 +102,42 @@ public class FuncGraph : Graph, IDisposable | |||||
outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
while (outer_graph.building_function) | while (outer_graph.building_function) | ||||
outer_graph = outer_graph.OuterGraph; | outer_graph = outer_graph.OuterGraph; | ||||
_graph_key = name; | |||||
_graph_key = Name = name; | |||||
building_function = true; | building_function = true; | ||||
_weak_variables = new List<WeakReference<IVariableV1>>(); | |||||
_resource_tensor_inputs = new HashSet<Tensor>(); | |||||
_watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||||
} | } | ||||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base() | |||||
{ | { | ||||
outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
while (outer_graph.building_function) | while (outer_graph.building_function) | ||||
outer_graph = outer_graph.OuterGraph; | outer_graph = outer_graph.OuterGraph; | ||||
_graph_key = name; | |||||
_graph_key = Name = name; | |||||
building_function = true; | building_function = true; | ||||
Attrs = attrs; | Attrs = attrs; | ||||
// Will to test if FuncGraph has memory leak | // Will to test if FuncGraph has memory leak | ||||
// c_api.TF_DeleteGraph(_handle); | // c_api.TF_DeleteGraph(_handle); | ||||
_handle = handle; | _handle = handle; | ||||
_weak_variables = new List<WeakReference<IVariableV1>>(); | |||||
_resource_tensor_inputs = new HashSet<Tensor>(); | |||||
_watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||||
} | } | ||||
public void ToGraph(Operation[] opers, | |||||
public void replace_capture(Tensor tensor, Tensor placeholder) | |||||
{ | |||||
_captures[tensor.Id] = (tensor, placeholder); | |||||
} | |||||
public unsafe void ToGraph(Operation[] opers, | |||||
Tensor[] inputs, Tensor[] outputs, | Tensor[] inputs, Tensor[] outputs, | ||||
string[] output_names) | string[] output_names) | ||||
{ | { | ||||
var status = new Status(); | var status = new Status(); | ||||
if (output_names != null && output_names.Length == 0) | |||||
if (output_names is null) | |||||
{ | { | ||||
output_names = null; | |||||
output_names = new string[0]; | |||||
}; | }; | ||||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | _func_graph_handle = c_api.TF_GraphToFunction(_handle, | ||||
@@ -75,7 +149,7 @@ public class FuncGraph : Graph, IDisposable | |||||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | ||||
outputs.Length, | outputs.Length, | ||||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | ||||
output_names, | |||||
output_names.Length != outputs.Length ? null : output_names, | |||||
IntPtr.Zero, | IntPtr.Zero, | ||||
null, | null, | ||||
status); | status); | ||||
@@ -141,6 +215,16 @@ public class FuncGraph : Graph, IDisposable | |||||
return tensor; | return tensor; | ||||
} | } | ||||
public void watch_variable(IVariableV1 v) | |||||
{ | |||||
if (_resource_tensor_inputs.Contains(v.Handle)) | |||||
{ | |||||
return; | |||||
} | |||||
_watched_variables.Add(new WeakReference<IVariableV1>(v)); | |||||
//this = this.outer_graph; | |||||
} | |||||
Tensor capture_eager_tensor(Tensor tensor, string name) | Tensor capture_eager_tensor(Tensor tensor, string name) | ||||
{ | { | ||||
Tensor graph_const = null; | Tensor graph_const = null; | ||||
@@ -205,6 +289,19 @@ public class FuncGraph : Graph, IDisposable | |||||
Inputs.Add(placeholder); | Inputs.Add(placeholder); | ||||
} | } | ||||
Tensor pop_capture(Tensor tensor) | |||||
{ | |||||
if(_captures.TryGetValue(tensor.Id, out var capture)) | |||||
{ | |||||
_captures.Remove(tensor.Id); | |||||
return capture.Item2; | |||||
} | |||||
else | |||||
{ | |||||
return null; | |||||
} | |||||
} | |||||
Tensor _create_substitute_placeholder(Tensor value, | Tensor _create_substitute_placeholder(Tensor value, | ||||
string name = null, | string name = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
@@ -228,10 +325,7 @@ public class FuncGraph : Graph, IDisposable | |||||
foreach (var (_name, attr_value) in enumerate(Attrs)) | foreach (var (_name, attr_value) in enumerate(Attrs)) | ||||
{ | { | ||||
var serialized = new AttrValue | |||||
{ | |||||
S = ByteString.CopyFromUtf8(attr_value) | |||||
}.ToByteArray(); | |||||
var serialized = attr_value.ToByteArray(); | |||||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | ||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
@@ -254,4 +348,261 @@ public class FuncGraph : Graph, IDisposable | |||||
{ | { | ||||
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | ||||
} | } | ||||
public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func, | |||||
object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null, | |||||
FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, | |||||
bool add_control_dependencies = true, string[] arg_names = null, | |||||
Tensor op_return_value = null, bool capture_by_value = false, | |||||
bool acd_record_initial_resource_uses = false) | |||||
{ | |||||
if(func_graph is null) | |||||
{ | |||||
func_graph = new FuncGraph(name); | |||||
} | |||||
// TODO(Rinne): deal with control dependencies. | |||||
func_graph.as_default(); | |||||
var current_scope = variable_scope.get_variable_scope(); | |||||
var default_use_resource = current_scope.use_resource; | |||||
current_scope.use_resource = true; | |||||
if(signature is not null) | |||||
{ | |||||
args = signature; | |||||
kwargs = new Dictionary<string, object>(); | |||||
} | |||||
var func_args = _get_defun_inputs_from_args(args, arg_names); | |||||
var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); | |||||
if(func_kwargs is not null && func_kwargs.Count > 0) | |||||
{ | |||||
throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); | |||||
} | |||||
foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs })) | |||||
{ | |||||
if(arg is Tensor tensor && tensor.dtype == dtypes.resource) | |||||
{ | |||||
func_graph._resource_tensor_inputs.Add(tensor); | |||||
} | |||||
else if (arg is ResourceVariable variable) | |||||
{ | |||||
func_graph._resource_tensor_inputs.Add(variable.Handle); | |||||
} | |||||
} | |||||
// skip the assignment of `func_graph.structured_input_signature`. | |||||
var flat_func_args = nest.flatten(func_args as object); | |||||
var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||||
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||||
.Where(x => x is Tensor).Select(x => (Tensor)x)); | |||||
//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||||
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||||
Tensor convert(object x) | |||||
{ | |||||
if (x is null) return null; | |||||
Tensor res = null; | |||||
if(op_return_value is not null && x is Operation) | |||||
{ | |||||
tf_with(ops.control_dependencies(new object[] { x }), _ => | |||||
{ | |||||
res = array_ops.identity(op_return_value); | |||||
}); | |||||
} | |||||
else if(x is not TensorArray) | |||||
{ | |||||
Debug.Assert(x is Tensor); | |||||
res = ops.convert_to_tensor_or_composite(x as Tensor); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException($"The `TensorArray` is not supported here currently."); | |||||
} | |||||
if (add_control_dependencies) | |||||
{ | |||||
// TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. | |||||
} | |||||
return res; | |||||
} | |||||
if (autograph) | |||||
{ | |||||
throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); | |||||
} | |||||
var func_outputs = func(func_args); | |||||
func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); | |||||
func_outputs = func_outputs.Select(x => convert(x)).ToArray(); | |||||
// TODO(Rinne): `check_func_mutation`. | |||||
current_scope.use_resource = default_use_resource; | |||||
var graph_variables = func_graph._watched_variables.ToList(); | |||||
HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>(); | |||||
List<Tensor> inputs = new(); | |||||
foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) | |||||
{ | |||||
if(arg is BaseResourceVariable variable) | |||||
{ | |||||
var resource_placeholder = func_graph.pop_capture(variable.Handle); | |||||
if(resource_placeholder is null) | |||||
{ | |||||
continue; | |||||
} | |||||
Debug.Assert(variable is IVariableV1); | |||||
arg_variables.Add(variable as IVariableV1); | |||||
inputs.Add(resource_placeholder); | |||||
} | |||||
else if(arg is Tensor tensor) | |||||
{ | |||||
inputs.Add(tensor); | |||||
} | |||||
} | |||||
var variables = graph_variables.Select(v => | |||||
{ | |||||
if (v.TryGetTarget(out var target)) | |||||
{ | |||||
return target; | |||||
} | |||||
else | |||||
{ | |||||
return null; | |||||
} | |||||
}).Where(v => v is not null && !arg_variables.Contains(v)); | |||||
func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); | |||||
func_graph._structured_outputs = func_outputs; | |||||
func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) | |||||
.Select(x => func_graph.capture(x))); | |||||
func_graph.Variables = variables; | |||||
func_graph.Exit(); | |||||
if (add_control_dependencies) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
return func_graph; | |||||
} | |||||
private static object[] _get_defun_inputs_from_args(object[] args, string[] names) | |||||
{ | |||||
return _get_defun_inputs(args, names, args) as object[]; | |||||
} | |||||
private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
Debug.Assert(kwargs is null || kwargs.Count == 0); | |||||
return kwargs; | |||||
//string[] names; | |||||
//object[] args; | |||||
//if(kwargs is not null && kwargs.Count > 0) | |||||
//{ | |||||
// var sorted_kwargs = kwargs.OrderBy(x => x.Key); | |||||
// names = sorted_kwargs.Select(x => x.Key).ToArray(); | |||||
// args = sorted_kwargs.Select(x => x.Value).ToArray(); | |||||
//} | |||||
//else | |||||
//{ | |||||
// names = new string[0]; | |||||
// args = new object[0]; | |||||
//} | |||||
//return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>; | |||||
} | |||||
private static object _get_defun_inputs(object[] args, string[] names, object structured_args) | |||||
{ | |||||
List<object> function_inputs = new(); | |||||
if(names is null) | |||||
{ | |||||
names = new string[args.Length]; | |||||
} | |||||
foreach(var (arg_value, name) in zip(args, names)) | |||||
{ | |||||
foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) | |||||
{ | |||||
function_inputs.Add(_get_defun_input(val, name)); | |||||
} | |||||
} | |||||
return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true); | |||||
} | |||||
private static object _get_defun_input(object arg, string name) | |||||
{ | |||||
var func_graph = ops.get_default_graph() as FuncGraph; | |||||
Debug.Assert(func_graph is not null); | |||||
if (arg is Tensor tensor) | |||||
{ | |||||
Tensor placeholder; | |||||
try | |||||
{ | |||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||||
} | |||||
catch (ValueError) | |||||
{ | |||||
// TODO(Rinne): Add warning here. | |||||
placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||||
} | |||||
handle_data_util.copy_handle_data(tensor, placeholder); | |||||
if (name is not null) | |||||
{ | |||||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||||
{ | |||||
S = tf.compat.as_bytes(name) | |||||
}); | |||||
} | |||||
return placeholder; | |||||
} | |||||
else if (arg is TensorSpec spec) | |||||
{ | |||||
string requested_name; | |||||
if (!string.IsNullOrEmpty(spec.name)) | |||||
{ | |||||
requested_name = spec.name; | |||||
} | |||||
else | |||||
{ | |||||
requested_name = name; | |||||
} | |||||
Tensor placeholder; | |||||
try | |||||
{ | |||||
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||||
} | |||||
catch (ValueError) | |||||
{ | |||||
// TODO(Rinne): Add warning here. | |||||
placeholder = tf.placeholder(spec.dtype, spec.shape); | |||||
} | |||||
if (name is not null) | |||||
{ | |||||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||||
{ | |||||
S = tf.compat.as_bytes(requested_name) | |||||
}); | |||||
} | |||||
return placeholder; | |||||
} | |||||
else if (arg is BaseResourceVariable variable) | |||||
{ | |||||
var placeholder = func_graph.capture(variable.Handle, name); | |||||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||||
{ | |||||
S = tf.compat.as_bytes(name) | |||||
}); | |||||
return arg; | |||||
} | |||||
// TODO(Rinne): deal with `VariableSpec`. | |||||
else | |||||
{ | |||||
return arg; | |||||
} | |||||
} | |||||
} | } |
@@ -1,4 +1,6 @@ | |||||
namespace Tensorflow | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow | |||||
{ | { | ||||
public partial class Graph | public partial class Graph | ||||
{ | { | ||||
@@ -6,5 +8,10 @@ | |||||
{ | { | ||||
} | } | ||||
internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map) | |||||
{ | |||||
return new GraphOverrideGradientContext(this, gradient_function_map); | |||||
} | |||||
} | } | ||||
} | } |
@@ -118,7 +118,7 @@ namespace Tensorflow | |||||
/// <param name="compute_device">(Optional.) If True, device functions will be executed | /// <param name="compute_device">(Optional.) If True, device functions will be executed | ||||
/// to compute the device property of the Operation.</param> | /// to compute the device property of the Operation.</param> | ||||
/// <returns>An `Operation` object.</returns> | /// <returns>An `Operation` object.</returns> | ||||
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | |||||
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) | |||||
{ | { | ||||
var ret = new Operation(c_op, this); | var ret = new Operation(c_op, this); | ||||
_add_op(ret); | _add_op(ret); | ||||
@@ -19,6 +19,9 @@ using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Collections.Specialized; | using System.Collections.Specialized; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Common.Extensions; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -86,6 +89,13 @@ namespace Tensorflow | |||||
private int _next_id_counter; | private int _next_id_counter; | ||||
private List<Operation> _unfetchable_ops = new List<Operation>(); | private List<Operation> _unfetchable_ops = new List<Operation>(); | ||||
private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | ||||
private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||||
internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new(); | |||||
private VersionDef _graph_def_versions = new VersionDef() | |||||
{ | |||||
Producer = versions.GRAPH_DEF_VERSION, | |||||
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||||
}; | |||||
public string _name_stack = ""; | public string _name_stack = ""; | ||||
protected string _graph_key; | protected string _graph_key; | ||||
@@ -121,6 +131,8 @@ namespace Tensorflow | |||||
protected Graph outer_graph; | protected Graph outer_graph; | ||||
public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||||
public SafeGraphHandle c_graph => _handle; | |||||
public Graph() | public Graph() | ||||
{ | { | ||||
@@ -147,6 +159,44 @@ namespace Tensorflow | |||||
return ops.set_default_graph(this); | return ops.set_default_graph(this); | ||||
} | } | ||||
public bool IsFunction(string name) | |||||
{ | |||||
return _functions.ContainsKey(tf.compat.as_str(name)); | |||||
} | |||||
internal void AddFunction(EagerDefinedFunction function) | |||||
{ | |||||
_check_not_finalized(); | |||||
var name = function.Name; | |||||
if(function._grad_func_name is not null && function.csharp_grad_func is not null) | |||||
{ | |||||
throw new ValueError($"Gradient defined twice for function {name}"); | |||||
} | |||||
var c_graph = this.c_graph; | |||||
var func = function._c_func.Get(); | |||||
Status status = new(); | |||||
if (function._grad_func is not null) | |||||
{ | |||||
var gradient = function._grad_func._c_func.Get(); | |||||
c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); | |||||
status.Check(true); | |||||
} | |||||
else | |||||
{ | |||||
c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); | |||||
status.Check(true); | |||||
} | |||||
_functions[tf.compat.as_str(name)] = function; | |||||
if(_graph_def_versions.MinConsumer < 12) | |||||
{ | |||||
_graph_def_versions.MinConsumer = 12; | |||||
} | |||||
} | |||||
private Tensor _as_graph_element(object obj) | private Tensor _as_graph_element(object obj) | ||||
{ | { | ||||
if (obj is RefVariable var) | if (obj is RefVariable var) | ||||
@@ -308,6 +358,9 @@ namespace Tensorflow | |||||
private void _create_op_helper(Operation op, bool compute_device = true) | private void _create_op_helper(Operation op, bool compute_device = true) | ||||
{ | { | ||||
// high priority | |||||
// TODO(Rinne): complete the implementation. | |||||
op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); | |||||
_record_op_seen_by_control_dependencies(op); | _record_op_seen_by_control_dependencies(op); | ||||
} | } | ||||
@@ -524,6 +577,11 @@ namespace Tensorflow | |||||
ops.pop_graph(); | ops.pop_graph(); | ||||
} | } | ||||
internal EagerDefinedFunction _get_function(string name) | |||||
{ | |||||
return _functions.GetOrDefault(name, null); | |||||
} | |||||
string debugString = string.Empty; | string debugString = string.Empty; | ||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
@@ -0,0 +1,37 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Text; | |||||
namespace Tensorflow.Graphs | |||||
{ | |||||
internal class GraphOverrideGradientContext: ITensorFlowObject | |||||
{ | |||||
Graph _graph; | |||||
Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map; | |||||
public GraphOverrideGradientContext(Graph graph, | |||||
Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map) | |||||
{ | |||||
_graph = graph; | |||||
_new_gradient_function_map = new_gradient_function_map; | |||||
} | |||||
[DebuggerStepThrough] | |||||
public void __enter__() | |||||
{ | |||||
Debug.Assert(_graph._gradient_function_map.Count == 0); | |||||
_graph._gradient_function_map = _new_gradient_function_map; | |||||
} | |||||
[DebuggerStepThrough] | |||||
public void __exit__() | |||||
{ | |||||
_graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>(); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||||
_handle = c_api.TF_NewImportGraphDefOptions(); | _handle = c_api.TF_NewImportGraphDefOptions(); | ||||
} | } | ||||
public SafeImportGraphDefOptionsHandle Options => _handle; | |||||
public void AddReturnOutput(string name, int index) | public void AddReturnOutput(string name, int index) | ||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | ||||
@@ -185,6 +185,9 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||||
/// <summary> | /// <summary> | ||||
/// Add an output in `graph_def` to be returned via the `return_outputs` output | /// Add an output in `graph_def` to be returned via the `return_outputs` output | ||||
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | ||||
@@ -246,7 +249,7 @@ namespace Tensorflow | |||||
/// <param name="ops">TF_ImportGraphDefOptions*</param> | /// <param name="ops">TF_ImportGraphDefOptions*</param> | ||||
/// <param name="uniquify_prefix">unsigned char</param> | /// <param name="uniquify_prefix">unsigned char</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||||
/// <summary> | /// <summary> | ||||
/// Fetches the return operations requested via | /// Fetches the return operations requested via | ||||
@@ -308,7 +311,7 @@ namespace Tensorflow | |||||
/// <param name="types">const TF_DataType*</param> | /// <param name="types">const TF_DataType*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||||
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | ||||
SafeStatusHandle status); | SafeStatusHandle status); | ||||
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
/// This class has nothing but the attributes different from `LayerArgs`. | /// This class has nothing but the attributes different from `LayerArgs`. | ||||
/// It's used to serialize the model to `tf` format. | /// It's used to serialize the model to `tf` format. | ||||
/// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, | /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, | ||||
/// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`. | |||||
/// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. | |||||
/// </summary> | /// </summary> | ||||
public class AutoSerializeLayerArgs: LayerArgs | public class AutoSerializeLayerArgs: LayerArgs | ||||
{ | { | ||||
@@ -7,6 +7,11 @@ using System.Text; | |||||
namespace Tensorflow.Keras.Common | namespace Tensorflow.Keras.Common | ||||
{ | { | ||||
class ShapeInfoFromPython | |||||
{ | |||||
public string class_name { get; set; } | |||||
public long?[] items { get; set; } | |||||
} | |||||
public class CustomizedShapeJsonConverter: JsonConverter | public class CustomizedShapeJsonConverter: JsonConverter | ||||
{ | { | ||||
public override bool CanConvert(Type objectType) | public override bool CanConvert(Type objectType) | ||||
@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common | |||||
dims[i] = shape.dims[i]; | dims[i] = shape.dims[i]; | ||||
} | } | ||||
} | } | ||||
var token = JToken.FromObject(dims); | |||||
var token = JToken.FromObject(new ShapeInfoFromPython() | |||||
{ | |||||
class_name = "__tuple__", | |||||
items = dims | |||||
}); | |||||
token.WriteTo(writer); | token.WriteTo(writer); | ||||
} | } | ||||
} | } | ||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
{ | { | ||||
long?[] dims; | |||||
try | |||||
{ | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
catch (JsonSerializationException ex) | |||||
{ | |||||
if (reader.Value.Equals("class_name")) | |||||
{ | |||||
reader.Read(); | |||||
reader.Read(); | |||||
reader.Read(); | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
else | |||||
{ | |||||
throw ex; | |||||
} | |||||
} | |||||
if (dims is null) | |||||
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | |||||
if (shape_info_from_python is null) | |||||
{ | { | ||||
return null; | return null; | ||||
} | } | ||||
long ?[]dims = shape_info_from_python.items; | |||||
long[] convertedDims = new long[dims.Length]; | long[] convertedDims = new long[dims.Length]; | ||||
for(int i = 0; i < dims.Length; i++) | for(int i = 0; i < dims.Length; i++) | ||||
{ | { | ||||
@@ -4,10 +4,10 @@ public interface IOptimizer | |||||
{ | { | ||||
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | ||||
Tensor[] clip_gradients(Tensor[] grads); | Tensor[] clip_gradients(Tensor[] grads); | ||||
void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||||
void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||||
string name = null, | string name = null, | ||||
bool experimental_aggregate_gradients = true); | bool experimental_aggregate_gradients = true); | ||||
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||||
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||||
string name = null, | string name = null, | ||||
bool experimental_aggregate_gradients = true); | bool experimental_aggregate_gradients = true); | ||||
} | } |
@@ -20,6 +20,9 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Google.Protobuf; | |||||
using Google.Protobuf.WellKnownTypes; | |||||
using System.Diagnostics; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -47,6 +50,8 @@ namespace Tensorflow | |||||
private readonly Graph _graph; | private readonly Graph _graph; | ||||
internal Func<Operation, object[], Tensor[]> _gradient_function; | |||||
public string type => OpType; | public string type => OpType; | ||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
@@ -61,7 +66,7 @@ namespace Tensorflow | |||||
public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
// OperationDescription _opDesc; | |||||
//private OperationDescription _op_desc; | |||||
public NodeDef node_def => GetNodeDef(); | public NodeDef node_def => GetNodeDef(); | ||||
@@ -216,21 +221,19 @@ namespace Tensorflow | |||||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | ||||
string oneof_value = x.ValueCase.ToString(); | |||||
if (string.IsNullOrEmpty(oneof_value)) | |||||
return null; | |||||
var oneof_value = x.ValueCase; | |||||
if (oneof_value == AttrValue.ValueOneofCase.None) | |||||
return new object[0]; | |||||
switch (oneof_value.ToLower()) | |||||
if(oneof_value == AttrValue.ValueOneofCase.List) | |||||
{ | { | ||||
case "list": | |||||
throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||||
case "type": | |||||
return x.Type; | |||||
case "s": | |||||
return x.S.ToStringUtf8(); | |||||
default: | |||||
return x.GetType().GetProperty(oneof_value).GetValue(x); | |||||
throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||||
} | } | ||||
if(oneof_value == AttrValue.ValueOneofCase.Type) | |||||
{ | |||||
return dtypes.as_tf_dtype(x.Type); | |||||
} | |||||
return ProtoUtils.GetSingleAttrValue(x, oneof_value); | |||||
} | } | ||||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | ||||
@@ -238,6 +241,19 @@ namespace Tensorflow | |||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | ||||
} | } | ||||
[Obsolete("The implementation is not complete.")] | |||||
internal void _set_device_from_string(string device_str) | |||||
{ | |||||
// TODO(Rinne): complete it with new C API `SetRequestedDevice`. | |||||
//c_api.TF_SetDevice(_handle, device_str); | |||||
} | |||||
[Obsolete("The implementation is not complete.")] | |||||
internal void _set_device(string device) | |||||
{ | |||||
_set_device_from_string(device); | |||||
} | |||||
private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
{ | { | ||||
var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
@@ -296,5 +312,60 @@ namespace Tensorflow | |||||
} | } | ||||
public NDArray numpy() => throw new NotImplementedException(""); | public NDArray numpy() => throw new NotImplementedException(""); | ||||
internal void _add_outputs(TF_DataType[] types, Shape[] shapes) | |||||
{ | |||||
Debug.Assert(types.Length == shapes.Length); | |||||
int orig_num_outputs = this.outputs.Length; | |||||
var new_outputs = new List<Tensor>(_outputs); | |||||
// Since the `_outputs` is defined as `Array`, when we add new output, we | |||||
// have to create a new array, which brings some performance concerns. | |||||
// In the future maybe the type of `outputs` should be reconsidered. | |||||
for(int i = 0; i < types.Length; i++) | |||||
{ | |||||
var t = new Tensor(this, orig_num_outputs + i, types[i]); | |||||
t.shape = shapes[i]; | |||||
new_outputs.Add(t); | |||||
} | |||||
_outputs = new_outputs.ToArray(); | |||||
} | |||||
internal void _set_func_attr(string attr_name, string func_name) | |||||
{ | |||||
var func = new NameAttrList() { Name = func_name }; | |||||
_set_attr(attr_name, new AttrValue() { Func = func }); | |||||
} | |||||
internal void _set_type_list_attr(string attr_name, DataType[] types) | |||||
{ | |||||
if(types is null || types.Length == 0) | |||||
{ | |||||
return; | |||||
} | |||||
var type_list = new AttrValue.Types.ListValue(); | |||||
type_list.Type.AddRange(types); | |||||
_set_attr(attr_name, new AttrValue() { List = type_list }); | |||||
} | |||||
internal void _set_attr(string attr_name, AttrValue attr_value) | |||||
{ | |||||
var buffer = new Buffer(attr_value.ToByteArray()); | |||||
try | |||||
{ | |||||
_set_attr_with_buf(attr_name, buffer); | |||||
} | |||||
finally | |||||
{ | |||||
buffer.Release(); | |||||
} | |||||
} | |||||
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) | |||||
{ | |||||
Status status = new(); | |||||
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status); | |||||
status.Check(true); | |||||
} | |||||
} | } | ||||
} | } |
@@ -14,10 +14,14 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using Google.Protobuf.WellKnownTypes; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -25,6 +29,74 @@ namespace Tensorflow | |||||
{ | { | ||||
public class functional_ops | public class functional_ops | ||||
{ | { | ||||
public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, | |||||
bool executing_eagerly, string config, string executor_type) | |||||
{ | |||||
if (tout is null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
if (config is null) | |||||
{ | |||||
config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||||
} | |||||
if (executor_type is null) | |||||
{ | |||||
executor_type = ""; | |||||
} | |||||
if (executing_eagerly) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
throw new NotImplementedException(); | |||||
} | |||||
var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); | |||||
AttrValue tin_attr = new() | |||||
{ | |||||
List = new AttrValue.Types.ListValue() | |||||
}; | |||||
tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); | |||||
AttrValue tout_attr = new() | |||||
{ | |||||
List = new AttrValue.Types.ListValue() | |||||
}; | |||||
tout_attr.List.Type.AddRange(tout); | |||||
AttrValue func_attr = new() | |||||
{ | |||||
Func = new NameAttrList() | |||||
}; | |||||
func_attr.Func.Name = f.Name; | |||||
AttrValue executor_type_attr = new AttrValue() | |||||
{ | |||||
S = tf.compat.as_bytes(executor_type) | |||||
}; | |||||
AttrValue config_proto = new AttrValue() | |||||
{ | |||||
S = ByteString.CopyFromUtf8(executor_type) | |||||
}; | |||||
var graph = ops.get_default_graph(); | |||||
f.AddToGraph(graph); | |||||
// TODO(Rinne): complete it with `f.stateful` | |||||
var op_name = "PartitionedCall"; | |||||
string xla_compile_attr = "_XlaMustCompile"; | |||||
Dictionary<string, AttrValue> op_attrs = new(); | |||||
op_attrs["Tin"] = tin_attr; | |||||
op_attrs["Tout"] = tout_attr; | |||||
op_attrs["f"] = func_attr; | |||||
op_attrs["config_proto"] = config_proto; | |||||
op_attrs["executor_type"] = executor_type_attr; | |||||
// TODO(Rinne): deal with `f.definition`. | |||||
var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), | |||||
name: op_name, attrs: op_attrs); | |||||
var outputs = op.outputs; | |||||
// TODO(Rinne): deal with `f.graph`. | |||||
return outputs; | |||||
} | |||||
public static Tensor scan( | public static Tensor scan( | ||||
Func<Tensor, Tensor, Tensor> fn, | Func<Tensor, Tensor, Tensor> fn, | ||||
Tensor elems, | Tensor elems, | ||||
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -210,7 +211,51 @@ namespace Tensorflow | |||||
/// <param name="name">A name for the operation (optional).</param> | /// <param name="name">A name for the operation (optional).</param> | ||||
/// <returns>A `Tensor`. Has the same type as `value`.</returns> | /// <returns>A `Tensor`. Has the same type as `value`.</returns> | ||||
public static Tensor fill<T>(Tensor dims, T value, string name = null) | public static Tensor fill<T>(Tensor dims, T value, string name = null) | ||||
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||||
{ | |||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value)); | |||||
return _result[0]; | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
try | |||||
{ | |||||
return fill_eager_fallback(dims, value as Tensor, name, ctx); | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
} | |||||
Dictionary<string, object> attrs = new Dictionary<string, object>(); | |||||
attrs["dims"] = dims; | |||||
attrs["value"] = value; | |||||
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs); | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return result.output; | |||||
} | |||||
public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx) | |||||
{ | |||||
object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() }; | |||||
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name); | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return _result[0]; | |||||
} | |||||
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||||
/// <summary> | /// <summary> | ||||
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast. | /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. | ||||
@@ -0,0 +1,128 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System.Xml.Linq; | |||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Functions; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class gen_functional_ops | |||||
{ | |||||
public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||||
string config = "", string config_proto = "", string executor_type = "", string name = null) | |||||
{ | |||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||||
args, tout, f, config, config_proto, executor_type)); | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
} | |||||
if (config is null) | |||||
{ | |||||
config = ""; | |||||
} | |||||
if (config_proto is null) | |||||
{ | |||||
config_proto = ""; | |||||
} | |||||
if (executor_type is null) | |||||
{ | |||||
executor_type = ""; | |||||
} | |||||
Dictionary<string, object> kwargs = new(); | |||||
kwargs["args"] = args; | |||||
kwargs["Tout"] = tout; | |||||
kwargs["f"] = f; | |||||
kwargs["config"] = config; | |||||
kwargs["config_proto"] = config_proto; | |||||
kwargs["executor_type"] = executor_type; | |||||
var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||||
name, kwargs); | |||||
var result = output.outputs; | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return result; | |||||
} | |||||
public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||||
string config, string config_proto, string executor_type, string name, Context ctx) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
throw new NotImplementedException(); | |||||
if(config is null) | |||||
{ | |||||
config = ""; | |||||
} | |||||
if(config_proto is null) | |||||
{ | |||||
config_proto = ""; | |||||
} | |||||
if(executor_type is null) | |||||
{ | |||||
executor_type = ""; | |||||
} | |||||
object[] attrs = new object[] | |||||
{ | |||||
}; | |||||
} | |||||
public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) | |||||
{ | |||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||||
"SymbolicGradient", name, input, Tout, f)); | |||||
return _result; | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
try | |||||
{ | |||||
return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
} | |||||
var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); | |||||
var result = op.outputs; | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return result; | |||||
} | |||||
public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) | |||||
{ | |||||
object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; | |||||
var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return result; | |||||
} | |||||
} | |||||
} |
@@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") | public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") | ||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); | |||||
return _result[0]; | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
try | |||||
{ | |||||
return ensure_shape_eager_fallback(input, shape, name, ctx); | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["input"] = input; | dict["input"] = input; | ||||
dict["shape"] = shape; | dict["shape"] = shape; | ||||
var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | ||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return op.output; | return op.output; | ||||
} | } | ||||
public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) | |||||
{ | |||||
object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; | |||||
var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, | |||||
attrs, ctx, name); | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return _result[0]; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Creates or finds a child frame, and makes <c>data</c> available to the child frame. | /// Creates or finds a child frame, and makes <c>data</c> available to the child frame. | ||||
/// </summary> | /// </summary> | ||||
@@ -0,0 +1,60 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
using static Tensorflow.CppShapeInferenceResult.Types; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public static class handle_data_util | |||||
{ | |||||
public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||||
{ | |||||
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||||
{ | |||||
HandleData handle_data; | |||||
if(source_t is EagerTensor) | |||||
{ | |||||
handle_data = source_t.HandleData; | |||||
} | |||||
else | |||||
{ | |||||
handle_data = ops.get_resource_handle_data(source_t); | |||||
} | |||||
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null | |||||
&& handle_data.ShapeAndType.Count > 0) | |||||
{ | |||||
set_handle_data(target_t, handle_data); | |||||
} | |||||
} | |||||
} | |||||
public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||||
{ | |||||
HandleData handle_data = new(); | |||||
handle_data.IsSet = true; | |||||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||||
{ | |||||
Shape = shape.as_proto(), | |||||
Dtype = dtype.as_datatype_enum() | |||||
}); | |||||
return handle_data; | |||||
} | |||||
public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||||
{ | |||||
if(target_t is EagerTensor) | |||||
{ | |||||
target_t.HandleData = handle_data; | |||||
return; | |||||
} | |||||
Status status = new(); | |||||
var proto = handle_data.ToByteArray(); | |||||
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status); | |||||
status.Check(true); | |||||
} | |||||
public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); | |||||
} | |||||
} |
@@ -21,6 +21,11 @@ using Tensorflow.Train; | |||||
using Tensorflow.Training.Saving.SavedModel; | using Tensorflow.Training.Saving.SavedModel; | ||||
using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Operations; | |||||
using System.Buffers; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -31,18 +36,14 @@ namespace Tensorflow | |||||
{ | { | ||||
public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | ||||
{ | { | ||||
// TODO(Rinne): deal with `_handle_graph`. | |||||
var value_tensor = ops.convert_to_tensor(value); | var value_tensor = ops.convert_to_tensor(value); | ||||
return gen_resource_variable_ops.assign_variable_op(handle, | return gen_resource_variable_ops.assign_variable_op(handle, | ||||
value_tensor, | value_tensor, | ||||
name: name); | name: name); | ||||
} | } | ||||
public static bool is_resource_variable(IVariableV1 var) | |||||
{ | |||||
return var is ResourceVariable; | |||||
} | |||||
public static bool is_resource_variable(Trackable var) | |||||
public static bool is_resource_variable(object var) | |||||
{ | { | ||||
return var is BaseResourceVariable; | return var is BaseResourceVariable; | ||||
} | } | ||||
@@ -78,6 +79,18 @@ namespace Tensorflow | |||||
string shared_name, string name, bool graph_mode, Tensor initial_value = null) | string shared_name, string name, bool graph_mode, Tensor initial_value = null) | ||||
{ | { | ||||
var container = ops.get_default_graph().Container; | var container = ops.get_default_graph().Container; | ||||
if(container is null) | |||||
{ | |||||
container = ""; | |||||
} | |||||
if (!graph_mode) | |||||
{ | |||||
if(shared_name is not null) | |||||
{ | |||||
throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||||
} | |||||
shared_name = tf.Context.anonymous_name(); | |||||
} | |||||
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | ||||
dtype: dtype, | dtype: dtype, | ||||
shared_name: shared_name, | shared_name: shared_name, | ||||
@@ -95,26 +108,20 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
// We do not want two distinct ResourceVariable objects for the same | |||||
// underlying resource in the runtime. | |||||
// When in eager mode, explicitly ensure so here. When in graph mode, it's | |||||
// ensured by always generating different variable names. | |||||
var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||||
// We create an assert Op instead of checking right away in order to be | |||||
// compatible with ASYNC execution mode. Further, since not all devices | |||||
// support string tensors, we encode the assertion string in the Op name | |||||
/*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||||
new[] { exists }, | |||||
name: "EagerVariableNameReuse");*/ | |||||
var handle_data = new HandleData(); | |||||
handle_data.IsSet = true; | |||||
handle_data.ShapeAndType.Add(new HandleShapeAndType | |||||
var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||||
if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||||
{ | { | ||||
Dtype = dtype.as_datatype_enum(), | |||||
Shape = shape.as_proto() | |||||
}); | |||||
var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||||
if (extra_handle_data is not null && extra_handle_data.IsSet) | |||||
{ | |||||
if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||||
{ | |||||
throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||||
$"but saw: '{handle_data}'"); | |||||
} | |||||
handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||||
} | |||||
} | |||||
_set_handle_shapes_and_types(handle, handle_data, graph_mode); | _set_handle_shapes_and_types(handle, handle_data, graph_mode); | ||||
return handle; | return handle; | ||||
} | } | ||||
@@ -126,7 +133,7 @@ namespace Tensorflow | |||||
/// <param name="handle"></param> | /// <param name="handle"></param> | ||||
/// <param name="handle_data"></param> | /// <param name="handle_data"></param> | ||||
/// <param name="graph_mode"></param> | /// <param name="graph_mode"></param> | ||||
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
{ | { | ||||
if (!graph_mode) | if (!graph_mode) | ||||
return; | return; | ||||
@@ -144,6 +151,47 @@ namespace Tensorflow | |||||
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | ||||
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | ||||
} | } | ||||
//tensor.HandleData = handle_data; | |||||
//if (!graph_mode) | |||||
// return; | |||||
//var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||||
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||||
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||||
//var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||||
//{ | |||||
// if (!s.UnknownRank) | |||||
// { | |||||
// return s.Dim.Select(d => (int)d.Size).ToArray(); | |||||
// } | |||||
// else | |||||
// { | |||||
// return Memory<int>.Empty; | |||||
// } | |||||
//}).ToArray(); | |||||
//List<MemoryHandle> handles = new(); | |||||
//IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||||
//foreach(var (i, m) in enumerate(converted_shapes)) | |||||
//{ | |||||
// if(m.IsEmpty) | |||||
// { | |||||
// shapes_with_ptr[i] = IntPtr.Zero; | |||||
// } | |||||
// else | |||||
// { | |||||
// var handle = m.Pin(); | |||||
// handles.Add(handle); | |||||
// shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||||
// } | |||||
//} | |||||
//Status status = new(); | |||||
//// TODO(Rinne): enable it. | |||||
//c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||||
// shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||||
//handles = null; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -162,24 +210,6 @@ namespace Tensorflow | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
private static HandleData get_eager_safe_handle_data(Tensor handle) | |||||
{ | |||||
if (handle.Handle == null) | |||||
{ | |||||
var data = new HandleData(); | |||||
data.ShapeAndType.Add(new HandleShapeAndType | |||||
{ | |||||
Shape = handle.shape.as_shape_proto(), | |||||
Dtype = handle.dtype.as_datatype_enum() | |||||
}); | |||||
return data; | |||||
} | |||||
else | |||||
{ | |||||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Copies an existing variable to a new graph, with no initializer. | /// Copies an existing variable to a new graph, with no initializer. | ||||
/// </summary> | /// </summary> | ||||
@@ -231,5 +261,60 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
} | } | ||||
public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) | |||||
{ | |||||
if(dtype == dtypes.variant) | |||||
{ | |||||
var handle_data = get_eager_safe_handle_data(handle); | |||||
if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) | |||||
{ | |||||
tensor.HandleData = new HandleData() | |||||
{ | |||||
IsSet = true | |||||
}; | |||||
tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); | |||||
} | |||||
} | |||||
} | |||||
public static HandleData get_eager_safe_handle_data(Tensor handle) | |||||
{ | |||||
if (handle.Handle == null) | |||||
{ | |||||
var data = new HandleData(); | |||||
data.ShapeAndType.Add(new HandleShapeAndType | |||||
{ | |||||
Shape = handle.shape.as_shape_proto(), | |||||
Dtype = handle.dtype.as_datatype_enum() | |||||
}); | |||||
return data; | |||||
} | |||||
else | |||||
{ | |||||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||||
} | |||||
//if(handle is EagerTensor) | |||||
//{ | |||||
// return handle.HandleData; | |||||
//} | |||||
//else | |||||
//{ | |||||
// return handle_data_util.get_resource_handle_data(handle); | |||||
//} | |||||
} | |||||
public static void variable_accessed(IVariableV1 variable) | |||||
{ | |||||
if (ops.get_default_graph() is FuncGraph func_graph) | |||||
{ | |||||
func_graph.watch_variable(variable); | |||||
} | |||||
if (variable.Trainable) | |||||
{ | |||||
foreach (var tape in tf.GetTapeSet()) | |||||
tape.VariableAccessed(variable); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/framework/allocation_description.proto | // source: tensorflow/core/framework/allocation_description.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -43,23 +43,31 @@ namespace Tensorflow { | |||||
} | } | ||||
#region Messages | #region Messages | ||||
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> { | |||||
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<AllocationDescription> _parser = new pb::MessageParser<AllocationDescription>(() => new AllocationDescription()); | private static readonly pb::MessageParser<AllocationDescription> _parser = new pb::MessageParser<AllocationDescription>(() => new AllocationDescription()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<AllocationDescription> Parser { get { return _parser; } } | public static pb::MessageParser<AllocationDescription> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AllocationDescription() { | public AllocationDescription() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -67,6 +75,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AllocationDescription(AllocationDescription other) : this() { | public AllocationDescription(AllocationDescription other) : this() { | ||||
requestedBytes_ = other.requestedBytes_; | requestedBytes_ = other.requestedBytes_; | ||||
allocatedBytes_ = other.allocatedBytes_; | allocatedBytes_ = other.allocatedBytes_; | ||||
@@ -78,6 +87,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AllocationDescription Clone() { | public AllocationDescription Clone() { | ||||
return new AllocationDescription(this); | return new AllocationDescription(this); | ||||
} | } | ||||
@@ -89,6 +99,7 @@ namespace Tensorflow { | |||||
/// Total number of bytes requested | /// Total number of bytes requested | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long RequestedBytes { | public long RequestedBytes { | ||||
get { return requestedBytes_; } | get { return requestedBytes_; } | ||||
set { | set { | ||||
@@ -103,6 +114,7 @@ namespace Tensorflow { | |||||
/// Total number of bytes allocated if known | /// Total number of bytes allocated if known | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long AllocatedBytes { | public long AllocatedBytes { | ||||
get { return allocatedBytes_; } | get { return allocatedBytes_; } | ||||
set { | set { | ||||
@@ -117,6 +129,7 @@ namespace Tensorflow { | |||||
/// Name of the allocator used | /// Name of the allocator used | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string AllocatorName { | public string AllocatorName { | ||||
get { return allocatorName_; } | get { return allocatorName_; } | ||||
set { | set { | ||||
@@ -131,6 +144,7 @@ namespace Tensorflow { | |||||
/// Identifier of the allocated buffer if known | /// Identifier of the allocated buffer if known | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long AllocationId { | public long AllocationId { | ||||
get { return allocationId_; } | get { return allocationId_; } | ||||
set { | set { | ||||
@@ -145,6 +159,7 @@ namespace Tensorflow { | |||||
/// Set if this tensor only has one remaining reference | /// Set if this tensor only has one remaining reference | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool HasSingleReference { | public bool HasSingleReference { | ||||
get { return hasSingleReference_; } | get { return hasSingleReference_; } | ||||
set { | set { | ||||
@@ -159,6 +174,7 @@ namespace Tensorflow { | |||||
/// Address of the allocation. | /// Address of the allocation. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ulong Ptr { | public ulong Ptr { | ||||
get { return ptr_; } | get { return ptr_; } | ||||
set { | set { | ||||
@@ -167,11 +183,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as AllocationDescription); | return Equals(other as AllocationDescription); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(AllocationDescription other) { | public bool Equals(AllocationDescription other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -189,6 +207,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode(); | if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode(); | ||||
@@ -204,12 +223,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (RequestedBytes != 0L) { | if (RequestedBytes != 0L) { | ||||
output.WriteRawTag(8); | output.WriteRawTag(8); | ||||
output.WriteInt64(RequestedBytes); | output.WriteInt64(RequestedBytes); | ||||
@@ -237,9 +261,45 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (RequestedBytes != 0L) { | |||||
output.WriteRawTag(8); | |||||
output.WriteInt64(RequestedBytes); | |||||
} | |||||
if (AllocatedBytes != 0L) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt64(AllocatedBytes); | |||||
} | |||||
if (AllocatorName.Length != 0) { | |||||
output.WriteRawTag(26); | |||||
output.WriteString(AllocatorName); | |||||
} | |||||
if (AllocationId != 0L) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt64(AllocationId); | |||||
} | |||||
if (HasSingleReference != false) { | |||||
output.WriteRawTag(40); | |||||
output.WriteBool(HasSingleReference); | |||||
} | |||||
if (Ptr != 0UL) { | |||||
output.WriteRawTag(48); | |||||
output.WriteUInt64(Ptr); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (RequestedBytes != 0L) { | if (RequestedBytes != 0L) { | ||||
@@ -267,6 +327,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(AllocationDescription other) { | public void MergeFrom(AllocationDescription other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -293,7 +354,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -326,7 +391,47 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 8: { | |||||
RequestedBytes = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
AllocatedBytes = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 26: { | |||||
AllocatorName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
AllocationId = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 40: { | |||||
HasSingleReference = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 48: { | |||||
Ptr = input.ReadUInt64(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/framework/attr_value.proto | // source: tensorflow/core/framework/attr_value.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -63,23 +63,31 @@ namespace Tensorflow { | |||||
/// Comment indicates the corresponding attr type. Only the field matching the | /// Comment indicates the corresponding attr type. Only the field matching the | ||||
/// attr type may be filled. | /// attr type may be filled. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class AttrValue : pb::IMessage<AttrValue> { | |||||
public sealed partial class AttrValue : pb::IMessage<AttrValue> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<AttrValue> _parser = new pb::MessageParser<AttrValue>(() => new AttrValue()); | private static readonly pb::MessageParser<AttrValue> _parser = new pb::MessageParser<AttrValue>(() => new AttrValue()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<AttrValue> Parser { get { return _parser; } } | public static pb::MessageParser<AttrValue> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AttrValue() { | public AttrValue() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -87,6 +95,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AttrValue(AttrValue other) : this() { | public AttrValue(AttrValue other) : this() { | ||||
switch (other.ValueCase) { | switch (other.ValueCase) { | ||||
case ValueOneofCase.S: | case ValueOneofCase.S: | ||||
@@ -125,6 +134,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public AttrValue Clone() { | public AttrValue Clone() { | ||||
return new AttrValue(this); | return new AttrValue(this); | ||||
} | } | ||||
@@ -135,6 +145,7 @@ namespace Tensorflow { | |||||
/// "string" | /// "string" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pb::ByteString S { | public pb::ByteString S { | ||||
get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } | get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } | ||||
set { | set { | ||||
@@ -149,6 +160,7 @@ namespace Tensorflow { | |||||
/// "int" | /// "int" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long I { | public long I { | ||||
get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } | get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } | ||||
set { | set { | ||||
@@ -163,6 +175,7 @@ namespace Tensorflow { | |||||
/// "float" | /// "float" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public float F { | public float F { | ||||
get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } | get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } | ||||
set { | set { | ||||
@@ -177,6 +190,7 @@ namespace Tensorflow { | |||||
/// "bool" | /// "bool" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool B { | public bool B { | ||||
get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } | get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } | ||||
set { | set { | ||||
@@ -191,6 +205,7 @@ namespace Tensorflow { | |||||
/// "type" | /// "type" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.DataType Type { | public global::Tensorflow.DataType Type { | ||||
get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; } | get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; } | ||||
set { | set { | ||||
@@ -205,6 +220,7 @@ namespace Tensorflow { | |||||
/// "shape" | /// "shape" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.TensorShapeProto Shape { | public global::Tensorflow.TensorShapeProto Shape { | ||||
get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } | get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } | ||||
set { | set { | ||||
@@ -219,6 +235,7 @@ namespace Tensorflow { | |||||
/// "tensor" | /// "tensor" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.TensorProto Tensor { | public global::Tensorflow.TensorProto Tensor { | ||||
get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } | get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } | ||||
set { | set { | ||||
@@ -233,6 +250,7 @@ namespace Tensorflow { | |||||
/// any "list(...)" | /// any "list(...)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.AttrValue.Types.ListValue List { | public global::Tensorflow.AttrValue.Types.ListValue List { | ||||
get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } | get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } | ||||
set { | set { | ||||
@@ -250,6 +268,7 @@ namespace Tensorflow { | |||||
/// that attr in the instantiation. | /// that attr in the instantiation. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.NameAttrList Func { | public global::Tensorflow.NameAttrList Func { | ||||
get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } | get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } | ||||
set { | set { | ||||
@@ -270,6 +289,7 @@ namespace Tensorflow { | |||||
/// given the value "bar". | /// given the value "bar". | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Placeholder { | public string Placeholder { | ||||
get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } | get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } | ||||
set { | set { | ||||
@@ -295,22 +315,26 @@ namespace Tensorflow { | |||||
} | } | ||||
private ValueOneofCase valueCase_ = ValueOneofCase.None; | private ValueOneofCase valueCase_ = ValueOneofCase.None; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ValueOneofCase ValueCase { | public ValueOneofCase ValueCase { | ||||
get { return valueCase_; } | get { return valueCase_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void ClearValue() { | public void ClearValue() { | ||||
valueCase_ = ValueOneofCase.None; | valueCase_ = ValueOneofCase.None; | ||||
value_ = null; | value_ = null; | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as AttrValue); | return Equals(other as AttrValue); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(AttrValue other) { | public bool Equals(AttrValue other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -333,6 +357,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); | if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); | ||||
@@ -353,12 +378,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (valueCase_ == ValueOneofCase.List) { | if (valueCase_ == ValueOneofCase.List) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteMessage(List); | output.WriteMessage(List); | ||||
@@ -402,9 +432,61 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (valueCase_ == ValueOneofCase.List) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(List); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.S) { | |||||
output.WriteRawTag(18); | |||||
output.WriteBytes(S); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.I) { | |||||
output.WriteRawTag(24); | |||||
output.WriteInt64(I); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.F) { | |||||
output.WriteRawTag(37); | |||||
output.WriteFloat(F); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.B) { | |||||
output.WriteRawTag(40); | |||||
output.WriteBool(B); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.Type) { | |||||
output.WriteRawTag(48); | |||||
output.WriteEnum((int) Type); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.Shape) { | |||||
output.WriteRawTag(58); | |||||
output.WriteMessage(Shape); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.Tensor) { | |||||
output.WriteRawTag(66); | |||||
output.WriteMessage(Tensor); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.Placeholder) { | |||||
output.WriteRawTag(74); | |||||
output.WriteString(Placeholder); | |||||
} | |||||
if (valueCase_ == ValueOneofCase.Func) { | |||||
output.WriteRawTag(82); | |||||
output.WriteMessage(Func); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (valueCase_ == ValueOneofCase.S) { | if (valueCase_ == ValueOneofCase.S) { | ||||
@@ -444,6 +526,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(AttrValue other) { | public void MergeFrom(AttrValue other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -497,7 +580,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -567,32 +654,118 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue(); | |||||
if (valueCase_ == ValueOneofCase.List) { | |||||
subBuilder.MergeFrom(List); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
List = subBuilder; | |||||
break; | |||||
} | |||||
case 18: { | |||||
S = input.ReadBytes(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
I = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 37: { | |||||
F = input.ReadFloat(); | |||||
break; | |||||
} | |||||
case 40: { | |||||
B = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 48: { | |||||
value_ = input.ReadEnum(); | |||||
valueCase_ = ValueOneofCase.Type; | |||||
break; | |||||
} | |||||
case 58: { | |||||
global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); | |||||
if (valueCase_ == ValueOneofCase.Shape) { | |||||
subBuilder.MergeFrom(Shape); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
Shape = subBuilder; | |||||
break; | |||||
} | |||||
case 66: { | |||||
global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto(); | |||||
if (valueCase_ == ValueOneofCase.Tensor) { | |||||
subBuilder.MergeFrom(Tensor); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
Tensor = subBuilder; | |||||
break; | |||||
} | |||||
case 74: { | |||||
Placeholder = input.ReadString(); | |||||
break; | |||||
} | |||||
case 82: { | |||||
global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList(); | |||||
if (valueCase_ == ValueOneofCase.Func) { | |||||
subBuilder.MergeFrom(Func); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
Func = subBuilder; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
#region Nested types | #region Nested types | ||||
/// <summary>Container for nested types declared in the AttrValue message type.</summary> | /// <summary>Container for nested types declared in the AttrValue message type.</summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static partial class Types { | public static partial class Types { | ||||
/// <summary> | /// <summary> | ||||
/// LINT.IfChange | /// LINT.IfChange | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class ListValue : pb::IMessage<ListValue> { | |||||
public sealed partial class ListValue : pb::IMessage<ListValue> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<ListValue> _parser = new pb::MessageParser<ListValue>(() => new ListValue()); | private static readonly pb::MessageParser<ListValue> _parser = new pb::MessageParser<ListValue>(() => new ListValue()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<ListValue> Parser { get { return _parser; } } | public static pb::MessageParser<ListValue> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } | get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ListValue() { | public ListValue() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -600,6 +773,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ListValue(ListValue other) : this() { | public ListValue(ListValue other) : this() { | ||||
s_ = other.s_.Clone(); | s_ = other.s_.Clone(); | ||||
i_ = other.i_.Clone(); | i_ = other.i_.Clone(); | ||||
@@ -613,6 +787,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ListValue Clone() { | public ListValue Clone() { | ||||
return new ListValue(this); | return new ListValue(this); | ||||
} | } | ||||
@@ -626,6 +801,7 @@ namespace Tensorflow { | |||||
/// "list(string)" | /// "list(string)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<pb::ByteString> S { | public pbc::RepeatedField<pb::ByteString> S { | ||||
get { return s_; } | get { return s_; } | ||||
} | } | ||||
@@ -639,6 +815,7 @@ namespace Tensorflow { | |||||
/// "list(int)" | /// "list(int)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<long> I { | public pbc::RepeatedField<long> I { | ||||
get { return i_; } | get { return i_; } | ||||
} | } | ||||
@@ -652,6 +829,7 @@ namespace Tensorflow { | |||||
/// "list(float)" | /// "list(float)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<float> F { | public pbc::RepeatedField<float> F { | ||||
get { return f_; } | get { return f_; } | ||||
} | } | ||||
@@ -665,6 +843,7 @@ namespace Tensorflow { | |||||
/// "list(bool)" | /// "list(bool)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<bool> B { | public pbc::RepeatedField<bool> B { | ||||
get { return b_; } | get { return b_; } | ||||
} | } | ||||
@@ -678,6 +857,7 @@ namespace Tensorflow { | |||||
/// "list(type)" | /// "list(type)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.DataType> Type { | public pbc::RepeatedField<global::Tensorflow.DataType> Type { | ||||
get { return type_; } | get { return type_; } | ||||
} | } | ||||
@@ -691,6 +871,7 @@ namespace Tensorflow { | |||||
/// "list(shape)" | /// "list(shape)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.TensorShapeProto> Shape { | public pbc::RepeatedField<global::Tensorflow.TensorShapeProto> Shape { | ||||
get { return shape_; } | get { return shape_; } | ||||
} | } | ||||
@@ -704,6 +885,7 @@ namespace Tensorflow { | |||||
/// "list(tensor)" | /// "list(tensor)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.TensorProto> Tensor { | public pbc::RepeatedField<global::Tensorflow.TensorProto> Tensor { | ||||
get { return tensor_; } | get { return tensor_; } | ||||
} | } | ||||
@@ -717,16 +899,19 @@ namespace Tensorflow { | |||||
/// "list(attr)" | /// "list(attr)" | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.NameAttrList> Func { | public pbc::RepeatedField<global::Tensorflow.NameAttrList> Func { | ||||
get { return func_; } | get { return func_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as ListValue); | return Equals(other as ListValue); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(ListValue other) { | public bool Equals(ListValue other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -746,6 +931,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= s_.GetHashCode(); | hash ^= s_.GetHashCode(); | ||||
@@ -763,12 +949,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
s_.WriteTo(output, _repeated_s_codec); | s_.WriteTo(output, _repeated_s_codec); | ||||
i_.WriteTo(output, _repeated_i_codec); | i_.WriteTo(output, _repeated_i_codec); | ||||
f_.WriteTo(output, _repeated_f_codec); | f_.WriteTo(output, _repeated_f_codec); | ||||
@@ -780,9 +971,29 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
s_.WriteTo(ref output, _repeated_s_codec); | |||||
i_.WriteTo(ref output, _repeated_i_codec); | |||||
f_.WriteTo(ref output, _repeated_f_codec); | |||||
b_.WriteTo(ref output, _repeated_b_codec); | |||||
type_.WriteTo(ref output, _repeated_type_codec); | |||||
shape_.WriteTo(ref output, _repeated_shape_codec); | |||||
tensor_.WriteTo(ref output, _repeated_tensor_codec); | |||||
func_.WriteTo(ref output, _repeated_func_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += s_.CalculateSize(_repeated_s_codec); | size += s_.CalculateSize(_repeated_s_codec); | ||||
@@ -800,6 +1011,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(ListValue other) { | public void MergeFrom(ListValue other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -816,7 +1028,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -861,7 +1077,59 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 18: { | |||||
s_.AddEntriesFrom(ref input, _repeated_s_codec); | |||||
break; | |||||
} | |||||
case 26: | |||||
case 24: { | |||||
i_.AddEntriesFrom(ref input, _repeated_i_codec); | |||||
break; | |||||
} | |||||
case 34: | |||||
case 37: { | |||||
f_.AddEntriesFrom(ref input, _repeated_f_codec); | |||||
break; | |||||
} | |||||
case 42: | |||||
case 40: { | |||||
b_.AddEntriesFrom(ref input, _repeated_b_codec); | |||||
break; | |||||
} | |||||
case 50: | |||||
case 48: { | |||||
type_.AddEntriesFrom(ref input, _repeated_type_codec); | |||||
break; | |||||
} | |||||
case 58: { | |||||
shape_.AddEntriesFrom(ref input, _repeated_shape_codec); | |||||
break; | |||||
} | |||||
case 66: { | |||||
tensor_.AddEntriesFrom(ref input, _repeated_tensor_codec); | |||||
break; | |||||
} | |||||
case 74: { | |||||
func_.AddEntriesFrom(ref input, _repeated_func_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -874,23 +1142,31 @@ namespace Tensorflow { | |||||
/// A list of attr names and their values. The whole list is attached | /// A list of attr names and their values. The whole list is attached | ||||
/// with a string name. E.g., MatMul[T=float]. | /// with a string name. E.g., MatMul[T=float]. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class NameAttrList : pb::IMessage<NameAttrList> { | |||||
public sealed partial class NameAttrList : pb::IMessage<NameAttrList> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<NameAttrList> _parser = new pb::MessageParser<NameAttrList>(() => new NameAttrList()); | private static readonly pb::MessageParser<NameAttrList> _parser = new pb::MessageParser<NameAttrList>(() => new NameAttrList()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<NameAttrList> Parser { get { return _parser; } } | public static pb::MessageParser<NameAttrList> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public NameAttrList() { | public NameAttrList() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -898,6 +1174,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public NameAttrList(NameAttrList other) : this() { | public NameAttrList(NameAttrList other) : this() { | ||||
name_ = other.name_; | name_ = other.name_; | ||||
attr_ = other.attr_.Clone(); | attr_ = other.attr_.Clone(); | ||||
@@ -905,6 +1182,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public NameAttrList Clone() { | public NameAttrList Clone() { | ||||
return new NameAttrList(this); | return new NameAttrList(this); | ||||
} | } | ||||
@@ -913,6 +1191,7 @@ namespace Tensorflow { | |||||
public const int NameFieldNumber = 1; | public const int NameFieldNumber = 1; | ||||
private string name_ = ""; | private string name_ = ""; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Name { | public string Name { | ||||
get { return name_; } | get { return name_; } | ||||
set { | set { | ||||
@@ -926,16 +1205,19 @@ namespace Tensorflow { | |||||
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); | = new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); | ||||
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | ||||
get { return attr_; } | get { return attr_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as NameAttrList); | return Equals(other as NameAttrList); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(NameAttrList other) { | public bool Equals(NameAttrList other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -949,6 +1231,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | if (Name.Length != 0) hash ^= Name.GetHashCode(); | ||||
@@ -960,12 +1243,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(Name); | output.WriteString(Name); | ||||
@@ -974,9 +1262,26 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (Name.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Name); | |||||
} | |||||
attr_.WriteTo(ref output, _map_attr_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | } | ||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
@@ -990,6 +1295,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(NameAttrList other) { | public void MergeFrom(NameAttrList other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -1002,7 +1308,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -1019,7 +1329,31 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
Name = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
attr_.AddEntriesFrom(ref input, _map_attr_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/python/training/checkpoint_state.proto | // source: tensorflow/python/training/checkpoint_state.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -43,23 +43,31 @@ namespace Tensorflow { | |||||
/// <summary> | /// <summary> | ||||
/// Protocol buffer representing the checkpoint state. | /// Protocol buffer representing the checkpoint state. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> { | |||||
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CheckpointState() { | public CheckpointState() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -67,6 +75,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CheckpointState(CheckpointState other) : this() { | public CheckpointState(CheckpointState other) : this() { | ||||
modelCheckpointPath_ = other.modelCheckpointPath_; | modelCheckpointPath_ = other.modelCheckpointPath_; | ||||
allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | ||||
@@ -76,6 +85,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CheckpointState Clone() { | public CheckpointState Clone() { | ||||
return new CheckpointState(this); | return new CheckpointState(this); | ||||
} | } | ||||
@@ -87,6 +97,7 @@ namespace Tensorflow { | |||||
/// Path to the most-recent model checkpoint. | /// Path to the most-recent model checkpoint. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string ModelCheckpointPath { | public string ModelCheckpointPath { | ||||
get { return modelCheckpointPath_; } | get { return modelCheckpointPath_; } | ||||
set { | set { | ||||
@@ -106,6 +117,7 @@ namespace Tensorflow { | |||||
/// this list. | /// this list. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> AllModelCheckpointPaths { | public pbc::RepeatedField<string> AllModelCheckpointPaths { | ||||
get { return allModelCheckpointPaths_; } | get { return allModelCheckpointPaths_; } | ||||
} | } | ||||
@@ -120,6 +132,7 @@ namespace Tensorflow { | |||||
/// when each checkpoint was created. | /// when each checkpoint was created. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | ||||
get { return allModelCheckpointTimestamps_; } | get { return allModelCheckpointTimestamps_; } | ||||
} | } | ||||
@@ -132,6 +145,7 @@ namespace Tensorflow { | |||||
/// checkpoint. | /// checkpoint. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public double LastPreservedTimestamp { | public double LastPreservedTimestamp { | ||||
get { return lastPreservedTimestamp_; } | get { return lastPreservedTimestamp_; } | ||||
set { | set { | ||||
@@ -140,11 +154,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as CheckpointState); | return Equals(other as CheckpointState); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CheckpointState other) { | public bool Equals(CheckpointState other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -160,6 +176,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | ||||
@@ -173,12 +190,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (ModelCheckpointPath.Length != 0) { | if (ModelCheckpointPath.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(ModelCheckpointPath); | output.WriteString(ModelCheckpointPath); | ||||
@@ -192,9 +214,31 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (ModelCheckpointPath.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(ModelCheckpointPath); | |||||
} | |||||
allModelCheckpointPaths_.WriteTo(ref output, _repeated_allModelCheckpointPaths_codec); | |||||
allModelCheckpointTimestamps_.WriteTo(ref output, _repeated_allModelCheckpointTimestamps_codec); | |||||
if (LastPreservedTimestamp != 0D) { | |||||
output.WriteRawTag(33); | |||||
output.WriteDouble(LastPreservedTimestamp); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (ModelCheckpointPath.Length != 0) { | if (ModelCheckpointPath.Length != 0) { | ||||
@@ -212,6 +256,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CheckpointState other) { | public void MergeFrom(CheckpointState other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -228,7 +273,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -254,7 +303,40 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
ModelCheckpointPath = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
allModelCheckpointPaths_.AddEntriesFrom(ref input, _repeated_allModelCheckpointPaths_codec); | |||||
break; | |||||
} | |||||
case 26: | |||||
case 25: { | |||||
allModelCheckpointTimestamps_.AddEntriesFrom(ref input, _repeated_allModelCheckpointTimestamps_codec); | |||||
break; | |||||
} | |||||
case 33: { | |||||
LastPreservedTimestamp = input.ReadDouble(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/protobuf/cluster.proto | // source: tensorflow/core/protobuf/cluster.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -47,23 +47,31 @@ namespace Tensorflow { | |||||
/// <summary> | /// <summary> | ||||
/// Defines a single job in a TensorFlow cluster. | /// Defines a single job in a TensorFlow cluster. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class JobDef : pb::IMessage<JobDef> { | |||||
public sealed partial class JobDef : pb::IMessage<JobDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<JobDef> _parser = new pb::MessageParser<JobDef>(() => new JobDef()); | private static readonly pb::MessageParser<JobDef> _parser = new pb::MessageParser<JobDef>(() => new JobDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<JobDef> Parser { get { return _parser; } } | public static pb::MessageParser<JobDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public JobDef() { | public JobDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -71,6 +79,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public JobDef(JobDef other) : this() { | public JobDef(JobDef other) : this() { | ||||
name_ = other.name_; | name_ = other.name_; | ||||
tasks_ = other.tasks_.Clone(); | tasks_ = other.tasks_.Clone(); | ||||
@@ -78,6 +87,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public JobDef Clone() { | public JobDef Clone() { | ||||
return new JobDef(this); | return new JobDef(this); | ||||
} | } | ||||
@@ -89,6 +99,7 @@ namespace Tensorflow { | |||||
/// The name of this job. | /// The name of this job. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Name { | public string Name { | ||||
get { return name_; } | get { return name_; } | ||||
set { | set { | ||||
@@ -109,16 +120,19 @@ namespace Tensorflow { | |||||
/// "/job:worker/task:7" will be assigned to "example.org:2222". | /// "/job:worker/task:7" will be assigned to "example.org:2222". | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::MapField<int, string> Tasks { | public pbc::MapField<int, string> Tasks { | ||||
get { return tasks_; } | get { return tasks_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as JobDef); | return Equals(other as JobDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(JobDef other) { | public bool Equals(JobDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -132,6 +146,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | if (Name.Length != 0) hash ^= Name.GetHashCode(); | ||||
@@ -143,12 +158,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(Name); | output.WriteString(Name); | ||||
@@ -157,9 +177,26 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (Name.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Name); | |||||
} | |||||
tasks_.WriteTo(ref output, _map_tasks_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
@@ -173,6 +210,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(JobDef other) { | public void MergeFrom(JobDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -185,7 +223,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -202,30 +244,62 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
Name = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
tasks_.AddEntriesFrom(ref input, _map_tasks_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Defines a TensorFlow cluster as a set of jobs. | /// Defines a TensorFlow cluster as a set of jobs. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class ClusterDef : pb::IMessage<ClusterDef> { | |||||
public sealed partial class ClusterDef : pb::IMessage<ClusterDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<ClusterDef> _parser = new pb::MessageParser<ClusterDef>(() => new ClusterDef()); | private static readonly pb::MessageParser<ClusterDef> _parser = new pb::MessageParser<ClusterDef>(() => new ClusterDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<ClusterDef> Parser { get { return _parser; } } | public static pb::MessageParser<ClusterDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ClusterDef() { | public ClusterDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -233,12 +307,14 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ClusterDef(ClusterDef other) : this() { | public ClusterDef(ClusterDef other) : this() { | ||||
job_ = other.job_.Clone(); | job_ = other.job_.Clone(); | ||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ClusterDef Clone() { | public ClusterDef Clone() { | ||||
return new ClusterDef(this); | return new ClusterDef(this); | ||||
} | } | ||||
@@ -252,16 +328,19 @@ namespace Tensorflow { | |||||
/// The jobs that comprise the cluster. | /// The jobs that comprise the cluster. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.JobDef> Job { | public pbc::RepeatedField<global::Tensorflow.JobDef> Job { | ||||
get { return job_; } | get { return job_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as ClusterDef); | return Equals(other as ClusterDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(ClusterDef other) { | public bool Equals(ClusterDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -274,6 +353,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= job_.GetHashCode(); | hash ^= job_.GetHashCode(); | ||||
@@ -284,19 +364,37 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
job_.WriteTo(output, _repeated_job_codec); | job_.WriteTo(output, _repeated_job_codec); | ||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
job_.WriteTo(ref output, _repeated_job_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += job_.CalculateSize(_repeated_job_codec); | size += job_.CalculateSize(_repeated_job_codec); | ||||
@@ -307,6 +405,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(ClusterDef other) { | public void MergeFrom(ClusterDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -316,7 +415,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -329,7 +432,27 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
job_.AddEntriesFrom(ref input, _repeated_job_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/protobuf/control_flow.proto | // source: tensorflow/core/protobuf/control_flow.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -64,23 +64,31 @@ namespace Tensorflow { | |||||
/// <summary> | /// <summary> | ||||
/// Protocol buffer representing the values in ControlFlowContext. | /// Protocol buffer representing the values in ControlFlowContext. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class ValuesDef : pb::IMessage<ValuesDef> { | |||||
public sealed partial class ValuesDef : pb::IMessage<ValuesDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<ValuesDef> _parser = new pb::MessageParser<ValuesDef>(() => new ValuesDef()); | private static readonly pb::MessageParser<ValuesDef> _parser = new pb::MessageParser<ValuesDef>(() => new ValuesDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<ValuesDef> Parser { get { return _parser; } } | public static pb::MessageParser<ValuesDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ValuesDef() { | public ValuesDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -88,6 +96,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ValuesDef(ValuesDef other) : this() { | public ValuesDef(ValuesDef other) : this() { | ||||
values_ = other.values_.Clone(); | values_ = other.values_.Clone(); | ||||
externalValues_ = other.externalValues_.Clone(); | externalValues_ = other.externalValues_.Clone(); | ||||
@@ -95,6 +104,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ValuesDef Clone() { | public ValuesDef Clone() { | ||||
return new ValuesDef(this); | return new ValuesDef(this); | ||||
} | } | ||||
@@ -108,6 +118,7 @@ namespace Tensorflow { | |||||
/// Value names that have been seen in this context. | /// Value names that have been seen in this context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> Values { | public pbc::RepeatedField<string> Values { | ||||
get { return values_; } | get { return values_; } | ||||
} | } | ||||
@@ -121,16 +132,19 @@ namespace Tensorflow { | |||||
/// Value names referenced by but external to this context. | /// Value names referenced by but external to this context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::MapField<string, string> ExternalValues { | public pbc::MapField<string, string> ExternalValues { | ||||
get { return externalValues_; } | get { return externalValues_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as ValuesDef); | return Equals(other as ValuesDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(ValuesDef other) { | public bool Equals(ValuesDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -144,6 +158,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= values_.GetHashCode(); | hash ^= values_.GetHashCode(); | ||||
@@ -155,20 +170,39 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
values_.WriteTo(output, _repeated_values_codec); | values_.WriteTo(output, _repeated_values_codec); | ||||
externalValues_.WriteTo(output, _map_externalValues_codec); | externalValues_.WriteTo(output, _map_externalValues_codec); | ||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
values_.WriteTo(ref output, _repeated_values_codec); | |||||
externalValues_.WriteTo(ref output, _map_externalValues_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += values_.CalculateSize(_repeated_values_codec); | size += values_.CalculateSize(_repeated_values_codec); | ||||
@@ -180,6 +214,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(ValuesDef other) { | public void MergeFrom(ValuesDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -190,7 +225,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -207,7 +246,31 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
values_.AddEntriesFrom(ref input, _repeated_values_codec); | |||||
break; | |||||
} | |||||
case 18: { | |||||
externalValues_.AddEntriesFrom(ref input, _map_externalValues_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -215,23 +278,31 @@ namespace Tensorflow { | |||||
/// Container for any kind of control flow context. Any other control flow | /// Container for any kind of control flow context. Any other control flow | ||||
/// contexts that are added below should also be added here. | /// contexts that are added below should also be added here. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> { | |||||
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<ControlFlowContextDef> _parser = new pb::MessageParser<ControlFlowContextDef>(() => new ControlFlowContextDef()); | private static readonly pb::MessageParser<ControlFlowContextDef> _parser = new pb::MessageParser<ControlFlowContextDef>(() => new ControlFlowContextDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<ControlFlowContextDef> Parser { get { return _parser; } } | public static pb::MessageParser<ControlFlowContextDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ControlFlowContextDef() { | public ControlFlowContextDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -239,6 +310,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ControlFlowContextDef(ControlFlowContextDef other) : this() { | public ControlFlowContextDef(ControlFlowContextDef other) : this() { | ||||
switch (other.CtxtCase) { | switch (other.CtxtCase) { | ||||
case CtxtOneofCase.CondCtxt: | case CtxtOneofCase.CondCtxt: | ||||
@@ -253,6 +325,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ControlFlowContextDef Clone() { | public ControlFlowContextDef Clone() { | ||||
return new ControlFlowContextDef(this); | return new ControlFlowContextDef(this); | ||||
} | } | ||||
@@ -260,6 +333,7 @@ namespace Tensorflow { | |||||
/// <summary>Field number for the "cond_ctxt" field.</summary> | /// <summary>Field number for the "cond_ctxt" field.</summary> | ||||
public const int CondCtxtFieldNumber = 1; | public const int CondCtxtFieldNumber = 1; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.CondContextDef CondCtxt { | public global::Tensorflow.CondContextDef CondCtxt { | ||||
get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } | get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } | ||||
set { | set { | ||||
@@ -271,6 +345,7 @@ namespace Tensorflow { | |||||
/// <summary>Field number for the "while_ctxt" field.</summary> | /// <summary>Field number for the "while_ctxt" field.</summary> | ||||
public const int WhileCtxtFieldNumber = 2; | public const int WhileCtxtFieldNumber = 2; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.WhileContextDef WhileCtxt { | public global::Tensorflow.WhileContextDef WhileCtxt { | ||||
get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } | get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } | ||||
set { | set { | ||||
@@ -288,22 +363,26 @@ namespace Tensorflow { | |||||
} | } | ||||
private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; | private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CtxtOneofCase CtxtCase { | public CtxtOneofCase CtxtCase { | ||||
get { return ctxtCase_; } | get { return ctxtCase_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void ClearCtxt() { | public void ClearCtxt() { | ||||
ctxtCase_ = CtxtOneofCase.None; | ctxtCase_ = CtxtOneofCase.None; | ||||
ctxt_ = null; | ctxt_ = null; | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as ControlFlowContextDef); | return Equals(other as ControlFlowContextDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(ControlFlowContextDef other) { | public bool Equals(ControlFlowContextDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -318,6 +397,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); | if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); | ||||
@@ -330,12 +410,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteMessage(CondCtxt); | output.WriteMessage(CondCtxt); | ||||
@@ -347,9 +432,29 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(CondCtxt); | |||||
} | |||||
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||||
output.WriteRawTag(18); | |||||
output.WriteMessage(WhileCtxt); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | } | ||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | ||||
@@ -365,6 +470,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(ControlFlowContextDef other) { | public void MergeFrom(ControlFlowContextDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -388,7 +494,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -415,30 +525,72 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); | |||||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||||
subBuilder.MergeFrom(CondCtxt); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
CondCtxt = subBuilder; | |||||
break; | |||||
} | |||||
case 18: { | |||||
global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); | |||||
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||||
subBuilder.MergeFrom(WhileCtxt); | |||||
} | |||||
input.ReadMessage(subBuilder); | |||||
WhileCtxt = subBuilder; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Protocol buffer representing a CondContext object. | /// Protocol buffer representing a CondContext object. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class CondContextDef : pb::IMessage<CondContextDef> { | |||||
public sealed partial class CondContextDef : pb::IMessage<CondContextDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CondContextDef> _parser = new pb::MessageParser<CondContextDef>(() => new CondContextDef()); | private static readonly pb::MessageParser<CondContextDef> _parser = new pb::MessageParser<CondContextDef>(() => new CondContextDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CondContextDef> Parser { get { return _parser; } } | public static pb::MessageParser<CondContextDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } | get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CondContextDef() { | public CondContextDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -446,6 +598,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CondContextDef(CondContextDef other) : this() { | public CondContextDef(CondContextDef other) : this() { | ||||
contextName_ = other.contextName_; | contextName_ = other.contextName_; | ||||
predName_ = other.predName_; | predName_ = other.predName_; | ||||
@@ -457,6 +610,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CondContextDef Clone() { | public CondContextDef Clone() { | ||||
return new CondContextDef(this); | return new CondContextDef(this); | ||||
} | } | ||||
@@ -468,6 +622,7 @@ namespace Tensorflow { | |||||
/// Name of the context. | /// Name of the context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string ContextName { | public string ContextName { | ||||
get { return contextName_; } | get { return contextName_; } | ||||
set { | set { | ||||
@@ -482,6 +637,7 @@ namespace Tensorflow { | |||||
/// Name of the pred tensor. | /// Name of the pred tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PredName { | public string PredName { | ||||
get { return predName_; } | get { return predName_; } | ||||
set { | set { | ||||
@@ -496,6 +652,7 @@ namespace Tensorflow { | |||||
/// Name of the pivot tensor. | /// Name of the pivot tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PivotName { | public string PivotName { | ||||
get { return pivotName_; } | get { return pivotName_; } | ||||
set { | set { | ||||
@@ -510,6 +667,7 @@ namespace Tensorflow { | |||||
/// Branch prediction. 0 or 1. | /// Branch prediction. 0 or 1. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int Branch { | public int Branch { | ||||
get { return branch_; } | get { return branch_; } | ||||
set { | set { | ||||
@@ -524,6 +682,7 @@ namespace Tensorflow { | |||||
/// Values and external values in control flow context. | /// Values and external values in control flow context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.ValuesDef ValuesDef { | public global::Tensorflow.ValuesDef ValuesDef { | ||||
get { return valuesDef_; } | get { return valuesDef_; } | ||||
set { | set { | ||||
@@ -540,16 +699,19 @@ namespace Tensorflow { | |||||
/// Contexts contained inside this context (e.g. nested conds). | /// Contexts contained inside this context (e.g. nested conds). | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | ||||
get { return nestedContexts_; } | get { return nestedContexts_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as CondContextDef); | return Equals(other as CondContextDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CondContextDef other) { | public bool Equals(CondContextDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -567,6 +729,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | ||||
@@ -582,12 +745,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (ContextName.Length != 0) { | if (ContextName.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(ContextName); | output.WriteString(ContextName); | ||||
@@ -612,9 +780,42 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (ContextName.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(ContextName); | |||||
} | |||||
if (PredName.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(PredName); | |||||
} | |||||
if (PivotName.Length != 0) { | |||||
output.WriteRawTag(26); | |||||
output.WriteString(PivotName); | |||||
} | |||||
if (Branch != 0) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt32(Branch); | |||||
} | |||||
if (valuesDef_ != null) { | |||||
output.WriteRawTag(42); | |||||
output.WriteMessage(ValuesDef); | |||||
} | |||||
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (ContextName.Length != 0) { | if (ContextName.Length != 0) { | ||||
@@ -640,6 +841,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CondContextDef other) { | public void MergeFrom(CondContextDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -667,7 +869,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -703,30 +909,81 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
ContextName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
PredName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 26: { | |||||
PivotName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
Branch = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 42: { | |||||
if (valuesDef_ == null) { | |||||
ValuesDef = new global::Tensorflow.ValuesDef(); | |||||
} | |||||
input.ReadMessage(ValuesDef); | |||||
break; | |||||
} | |||||
case 50: { | |||||
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Protocol buffer representing a WhileContext object. | /// Protocol buffer representing a WhileContext object. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> { | |||||
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<WhileContextDef> _parser = new pb::MessageParser<WhileContextDef>(() => new WhileContextDef()); | private static readonly pb::MessageParser<WhileContextDef> _parser = new pb::MessageParser<WhileContextDef>(() => new WhileContextDef()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<WhileContextDef> Parser { get { return _parser; } } | public static pb::MessageParser<WhileContextDef> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } | get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public WhileContextDef() { | public WhileContextDef() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -734,6 +991,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public WhileContextDef(WhileContextDef other) : this() { | public WhileContextDef(WhileContextDef other) : this() { | ||||
contextName_ = other.contextName_; | contextName_ = other.contextName_; | ||||
parallelIterations_ = other.parallelIterations_; | parallelIterations_ = other.parallelIterations_; | ||||
@@ -751,6 +1009,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public WhileContextDef Clone() { | public WhileContextDef Clone() { | ||||
return new WhileContextDef(this); | return new WhileContextDef(this); | ||||
} | } | ||||
@@ -762,6 +1021,7 @@ namespace Tensorflow { | |||||
/// Name of the context. | /// Name of the context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string ContextName { | public string ContextName { | ||||
get { return contextName_; } | get { return contextName_; } | ||||
set { | set { | ||||
@@ -776,6 +1036,7 @@ namespace Tensorflow { | |||||
/// The number of iterations allowed to run in parallel. | /// The number of iterations allowed to run in parallel. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int ParallelIterations { | public int ParallelIterations { | ||||
get { return parallelIterations_; } | get { return parallelIterations_; } | ||||
set { | set { | ||||
@@ -790,6 +1051,7 @@ namespace Tensorflow { | |||||
/// Whether backprop is enabled for this while loop. | /// Whether backprop is enabled for this while loop. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool BackProp { | public bool BackProp { | ||||
get { return backProp_; } | get { return backProp_; } | ||||
set { | set { | ||||
@@ -804,6 +1066,7 @@ namespace Tensorflow { | |||||
/// Whether GPU-CPU memory swap is enabled for this loop. | /// Whether GPU-CPU memory swap is enabled for this loop. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool SwapMemory { | public bool SwapMemory { | ||||
get { return swapMemory_; } | get { return swapMemory_; } | ||||
set { | set { | ||||
@@ -818,6 +1081,7 @@ namespace Tensorflow { | |||||
/// Name of the pivot tensor. | /// Name of the pivot tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PivotName { | public string PivotName { | ||||
get { return pivotName_; } | get { return pivotName_; } | ||||
set { | set { | ||||
@@ -832,6 +1096,7 @@ namespace Tensorflow { | |||||
/// Name of the pivot_for_pred tensor. | /// Name of the pivot_for_pred tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PivotForPredName { | public string PivotForPredName { | ||||
get { return pivotForPredName_; } | get { return pivotForPredName_; } | ||||
set { | set { | ||||
@@ -846,6 +1111,7 @@ namespace Tensorflow { | |||||
/// Name of the pivot_for_body tensor. | /// Name of the pivot_for_body tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PivotForBodyName { | public string PivotForBodyName { | ||||
get { return pivotForBodyName_; } | get { return pivotForBodyName_; } | ||||
set { | set { | ||||
@@ -862,6 +1128,7 @@ namespace Tensorflow { | |||||
/// List of names for exit tensors. | /// List of names for exit tensors. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> LoopExitNames { | public pbc::RepeatedField<string> LoopExitNames { | ||||
get { return loopExitNames_; } | get { return loopExitNames_; } | ||||
} | } | ||||
@@ -875,6 +1142,7 @@ namespace Tensorflow { | |||||
/// List of names for enter tensors. | /// List of names for enter tensors. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> LoopEnterNames { | public pbc::RepeatedField<string> LoopEnterNames { | ||||
get { return loopEnterNames_; } | get { return loopEnterNames_; } | ||||
} | } | ||||
@@ -886,6 +1154,7 @@ namespace Tensorflow { | |||||
/// Values and external values in control flow context. | /// Values and external values in control flow context. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.ValuesDef ValuesDef { | public global::Tensorflow.ValuesDef ValuesDef { | ||||
get { return valuesDef_; } | get { return valuesDef_; } | ||||
set { | set { | ||||
@@ -900,6 +1169,7 @@ namespace Tensorflow { | |||||
/// Optional name of the maximum_iterations tensor. | /// Optional name of the maximum_iterations tensor. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string MaximumIterationsName { | public string MaximumIterationsName { | ||||
get { return maximumIterationsName_; } | get { return maximumIterationsName_; } | ||||
set { | set { | ||||
@@ -916,16 +1186,19 @@ namespace Tensorflow { | |||||
/// Contexts contained inside this context (e.g. nested whiles). | /// Contexts contained inside this context (e.g. nested whiles). | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | ||||
get { return nestedContexts_; } | get { return nestedContexts_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as WhileContextDef); | return Equals(other as WhileContextDef); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(WhileContextDef other) { | public bool Equals(WhileContextDef other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -949,6 +1222,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | ||||
@@ -970,12 +1244,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (ContextName.Length != 0) { | if (ContextName.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(ContextName); | output.WriteString(ContextName); | ||||
@@ -1018,9 +1297,60 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (ContextName.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(ContextName); | |||||
} | |||||
if (ParallelIterations != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(ParallelIterations); | |||||
} | |||||
if (BackProp != false) { | |||||
output.WriteRawTag(24); | |||||
output.WriteBool(BackProp); | |||||
} | |||||
if (SwapMemory != false) { | |||||
output.WriteRawTag(32); | |||||
output.WriteBool(SwapMemory); | |||||
} | |||||
if (PivotName.Length != 0) { | |||||
output.WriteRawTag(42); | |||||
output.WriteString(PivotName); | |||||
} | |||||
if (PivotForPredName.Length != 0) { | |||||
output.WriteRawTag(50); | |||||
output.WriteString(PivotForPredName); | |||||
} | |||||
if (PivotForBodyName.Length != 0) { | |||||
output.WriteRawTag(58); | |||||
output.WriteString(PivotForBodyName); | |||||
} | |||||
loopExitNames_.WriteTo(ref output, _repeated_loopExitNames_codec); | |||||
if (valuesDef_ != null) { | |||||
output.WriteRawTag(74); | |||||
output.WriteMessage(ValuesDef); | |||||
} | |||||
loopEnterNames_.WriteTo(ref output, _repeated_loopEnterNames_codec); | |||||
if (MaximumIterationsName.Length != 0) { | |||||
output.WriteRawTag(90); | |||||
output.WriteString(MaximumIterationsName); | |||||
} | |||||
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | } | ||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (ContextName.Length != 0) { | if (ContextName.Length != 0) { | ||||
@@ -1060,6 +1390,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(WhileContextDef other) { | public void MergeFrom(WhileContextDef other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -1101,7 +1432,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -1161,7 +1496,74 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
ContextName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
ParallelIterations = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
BackProp = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
SwapMemory = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 42: { | |||||
PivotName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 50: { | |||||
PivotForPredName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 58: { | |||||
PivotForBodyName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 66: { | |||||
loopExitNames_.AddEntriesFrom(ref input, _repeated_loopExitNames_codec); | |||||
break; | |||||
} | |||||
case 74: { | |||||
if (valuesDef_ == null) { | |||||
ValuesDef = new global::Tensorflow.ValuesDef(); | |||||
} | |||||
input.ReadMessage(ValuesDef); | |||||
break; | |||||
} | |||||
case 82: { | |||||
loopEnterNames_.AddEntriesFrom(ref input, _repeated_loopEnterNames_codec); | |||||
break; | |||||
} | |||||
case 90: { | |||||
MaximumIterationsName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 98: { | |||||
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -0,0 +1,791 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: tensorflow/core/protobuf/coordination_config.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Tensorflow { | |||||
/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/coordination_config.proto</summary> | |||||
public static partial class CoordinationConfigReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for tensorflow/core/protobuf/coordination_config.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static CoordinationConfigReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"CjJ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX2NvbmZp", | |||||
"Zy5wcm90bxIKdGVuc29yZmxvdyIxCg5Db29yZGluYXRlZEpvYhIMCgRuYW1l", | |||||
"GAEgASgJEhEKCW51bV90YXNrcxgCIAEoBSLdAgoZQ29vcmRpbmF0aW9uU2Vy", | |||||
"dmljZUNvbmZpZxIUCgxzZXJ2aWNlX3R5cGUYASABKAkSFgoOc2VydmljZV9s", | |||||
"ZWFkZXIYAiABKAkSGwoTZW5hYmxlX2hlYWx0aF9jaGVjaxgDIAEoCBImCh5j", | |||||
"bHVzdGVyX3JlZ2lzdGVyX3RpbWVvdXRfaW5fbXMYBCABKAMSHwoXaGVhcnRi", | |||||
"ZWF0X3RpbWVvdXRfaW5fbXMYBSABKAMSOAoUY29vcmRpbmF0ZWRfam9iX2xp", | |||||
"c3QYCiADKAsyGi50ZW5zb3JmbG93LkNvb3JkaW5hdGVkSm9iEiYKHnNodXRk", | |||||
"b3duX2JhcnJpZXJfdGltZW91dF9pbl9tcxgHIAEoAxIqCiJhZ2VudF9kZXN0", | |||||
"cnVjdGlvbl93aXRob3V0X3NodXRkb3duGAggASgIEhgKEHJlY292ZXJhYmxl", | |||||
"X2pvYnMYCSADKAlKBAgGEAdCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl", | |||||
"X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { }, | |||||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedJob), global::Tensorflow.CoordinatedJob.Parser, new[]{ "Name", "NumTasks" }, null, null, null, null), | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceConfig), global::Tensorflow.CoordinationServiceConfig.Parser, new[]{ "ServiceType", "ServiceLeader", "EnableHealthCheck", "ClusterRegisterTimeoutInMs", "HeartbeatTimeoutInMs", "CoordinatedJobList", "ShutdownBarrierTimeoutInMs", "AgentDestructionWithoutShutdown", "RecoverableJobs" }, null, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
/// <summary> | |||||
/// Represents a job type and the number of tasks under this job. | |||||
/// For example, ("worker", 20) implies that there will be 20 worker tasks. | |||||
/// </summary> | |||||
public sealed partial class CoordinatedJob : pb::IMessage<CoordinatedJob> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CoordinatedJob> _parser = new pb::MessageParser<CoordinatedJob>(() => new CoordinatedJob()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CoordinatedJob> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinatedJob() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinatedJob(CoordinatedJob other) : this() { | |||||
name_ = other.name_; | |||||
numTasks_ = other.numTasks_; | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinatedJob Clone() { | |||||
return new CoordinatedJob(this); | |||||
} | |||||
/// <summary>Field number for the "name" field.</summary> | |||||
public const int NameFieldNumber = 1; | |||||
private string name_ = ""; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Name { | |||||
get { return name_; } | |||||
set { | |||||
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
} | |||||
} | |||||
/// <summary>Field number for the "num_tasks" field.</summary> | |||||
public const int NumTasksFieldNumber = 2; | |||||
private int numTasks_; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int NumTasks { | |||||
get { return numTasks_; } | |||||
set { | |||||
numTasks_ = value; | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as CoordinatedJob); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CoordinatedJob other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (Name != other.Name) return false; | |||||
if (NumTasks != other.NumTasks) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||||
if (NumTasks != 0) hash ^= NumTasks.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (Name.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Name); | |||||
} | |||||
if (NumTasks != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(NumTasks); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (Name.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Name); | |||||
} | |||||
if (NumTasks != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(NumTasks); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (Name.Length != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); | |||||
} | |||||
if (NumTasks != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumTasks); | |||||
} | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CoordinatedJob other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.Name.Length != 0) { | |||||
Name = other.Name; | |||||
} | |||||
if (other.NumTasks != 0) { | |||||
NumTasks = other.NumTasks; | |||||
} | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
Name = input.ReadString(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
NumTasks = input.ReadInt32(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
Name = input.ReadString(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
NumTasks = input.ReadInt32(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
/// <summary> | |||||
/// Coordination service configuration parameters. | |||||
/// The system picks appropriate values for fields that are not set. | |||||
/// </summary> | |||||
public sealed partial class CoordinationServiceConfig : pb::IMessage<CoordinationServiceConfig> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CoordinationServiceConfig> _parser = new pb::MessageParser<CoordinationServiceConfig>(() => new CoordinationServiceConfig()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CoordinationServiceConfig> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[1]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinationServiceConfig() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinationServiceConfig(CoordinationServiceConfig other) : this() { | |||||
serviceType_ = other.serviceType_; | |||||
serviceLeader_ = other.serviceLeader_; | |||||
enableHealthCheck_ = other.enableHealthCheck_; | |||||
clusterRegisterTimeoutInMs_ = other.clusterRegisterTimeoutInMs_; | |||||
heartbeatTimeoutInMs_ = other.heartbeatTimeoutInMs_; | |||||
coordinatedJobList_ = other.coordinatedJobList_.Clone(); | |||||
shutdownBarrierTimeoutInMs_ = other.shutdownBarrierTimeoutInMs_; | |||||
agentDestructionWithoutShutdown_ = other.agentDestructionWithoutShutdown_; | |||||
recoverableJobs_ = other.recoverableJobs_.Clone(); | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CoordinationServiceConfig Clone() { | |||||
return new CoordinationServiceConfig(this); | |||||
} | |||||
/// <summary>Field number for the "service_type" field.</summary> | |||||
public const int ServiceTypeFieldNumber = 1; | |||||
private string serviceType_ = ""; | |||||
/// <summary> | |||||
/// Type of coordination service implementation to enable. | |||||
/// For example, setting the service type as "standalone" starts a service | |||||
/// instance on the leader task to provide the coordination services such as | |||||
/// heartbeats and consistent key-value store. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string ServiceType { | |||||
get { return serviceType_; } | |||||
set { | |||||
serviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
} | |||||
} | |||||
/// <summary>Field number for the "service_leader" field.</summary> | |||||
public const int ServiceLeaderFieldNumber = 2; | |||||
private string serviceLeader_ = ""; | |||||
/// <summary> | |||||
/// Address where the coordination service instance is hosted. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string ServiceLeader { | |||||
get { return serviceLeader_; } | |||||
set { | |||||
serviceLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
} | |||||
} | |||||
/// <summary>Field number for the "enable_health_check" field.</summary> | |||||
public const int EnableHealthCheckFieldNumber = 3; | |||||
private bool enableHealthCheck_; | |||||
/// <summary> | |||||
/// Whether to enable the health check mechanism. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool EnableHealthCheck { | |||||
get { return enableHealthCheck_; } | |||||
set { | |||||
enableHealthCheck_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "cluster_register_timeout_in_ms" field.</summary> | |||||
public const int ClusterRegisterTimeoutInMsFieldNumber = 4; | |||||
private long clusterRegisterTimeoutInMs_; | |||||
/// <summary> | |||||
/// Maximum wait time for all members in the cluster to be registered. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long ClusterRegisterTimeoutInMs { | |||||
get { return clusterRegisterTimeoutInMs_; } | |||||
set { | |||||
clusterRegisterTimeoutInMs_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "heartbeat_timeout_in_ms" field.</summary> | |||||
public const int HeartbeatTimeoutInMsFieldNumber = 5; | |||||
private long heartbeatTimeoutInMs_; | |||||
/// <summary> | |||||
/// Heartbeat timeout, if a task does not record heartbeat in this time | |||||
/// window, it will be considered disconnected. | |||||
/// Note: This is also used as a grace period to accept any heartbeats after | |||||
/// the agent has disconnected, to account for the lag time between the service | |||||
/// recording the state change and the agent stopping heartbeats. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long HeartbeatTimeoutInMs { | |||||
get { return heartbeatTimeoutInMs_; } | |||||
set { | |||||
heartbeatTimeoutInMs_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "coordinated_job_list" field.</summary> | |||||
public const int CoordinatedJobListFieldNumber = 10; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.CoordinatedJob> _repeated_coordinatedJobList_codec | |||||
= pb::FieldCodec.ForMessage(82, global::Tensorflow.CoordinatedJob.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.CoordinatedJob> coordinatedJobList_ = new pbc::RepeatedField<global::Tensorflow.CoordinatedJob>(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.CoordinatedJob> CoordinatedJobList { | |||||
get { return coordinatedJobList_; } | |||||
} | |||||
/// <summary>Field number for the "shutdown_barrier_timeout_in_ms" field.</summary> | |||||
public const int ShutdownBarrierTimeoutInMsFieldNumber = 7; | |||||
private long shutdownBarrierTimeoutInMs_; | |||||
/// <summary> | |||||
/// Denotes how long to wait for all coordination agents to reach the barriers | |||||
/// (after the first shutdown request) before disconnecting together. If | |||||
/// set to 0, no barrier is imposed upon shutdown and each worker can | |||||
/// disconnect individually. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long ShutdownBarrierTimeoutInMs { | |||||
get { return shutdownBarrierTimeoutInMs_; } | |||||
set { | |||||
shutdownBarrierTimeoutInMs_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "agent_destruction_without_shutdown" field.</summary> | |||||
public const int AgentDestructionWithoutShutdownFieldNumber = 8; | |||||
private bool agentDestructionWithoutShutdown_; | |||||
/// <summary> | |||||
/// If set, agents do not make an explicit Shutdown() call. Service will only | |||||
/// find out about the disconnecte agent via stale heartbeats. Used for | |||||
/// testing. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool AgentDestructionWithoutShutdown { | |||||
get { return agentDestructionWithoutShutdown_; } | |||||
set { | |||||
agentDestructionWithoutShutdown_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "recoverable_jobs" field.</summary> | |||||
public const int RecoverableJobsFieldNumber = 9; | |||||
private static readonly pb::FieldCodec<string> _repeated_recoverableJobs_codec | |||||
= pb::FieldCodec.ForString(74); | |||||
private readonly pbc::RepeatedField<string> recoverableJobs_ = new pbc::RepeatedField<string>(); | |||||
/// <summary> | |||||
/// The list of jobs which are recoverable. If a task in this list fails, | |||||
/// it will not propagate error to other tasks. | |||||
/// If empty, no jobs will be recoverable and every task failure will cause | |||||
/// error propagation to other tasks. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> RecoverableJobs { | |||||
get { return recoverableJobs_; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as CoordinationServiceConfig); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CoordinationServiceConfig other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (ServiceType != other.ServiceType) return false; | |||||
if (ServiceLeader != other.ServiceLeader) return false; | |||||
if (EnableHealthCheck != other.EnableHealthCheck) return false; | |||||
if (ClusterRegisterTimeoutInMs != other.ClusterRegisterTimeoutInMs) return false; | |||||
if (HeartbeatTimeoutInMs != other.HeartbeatTimeoutInMs) return false; | |||||
if(!coordinatedJobList_.Equals(other.coordinatedJobList_)) return false; | |||||
if (ShutdownBarrierTimeoutInMs != other.ShutdownBarrierTimeoutInMs) return false; | |||||
if (AgentDestructionWithoutShutdown != other.AgentDestructionWithoutShutdown) return false; | |||||
if(!recoverableJobs_.Equals(other.recoverableJobs_)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (ServiceType.Length != 0) hash ^= ServiceType.GetHashCode(); | |||||
if (ServiceLeader.Length != 0) hash ^= ServiceLeader.GetHashCode(); | |||||
if (EnableHealthCheck != false) hash ^= EnableHealthCheck.GetHashCode(); | |||||
if (ClusterRegisterTimeoutInMs != 0L) hash ^= ClusterRegisterTimeoutInMs.GetHashCode(); | |||||
if (HeartbeatTimeoutInMs != 0L) hash ^= HeartbeatTimeoutInMs.GetHashCode(); | |||||
hash ^= coordinatedJobList_.GetHashCode(); | |||||
if (ShutdownBarrierTimeoutInMs != 0L) hash ^= ShutdownBarrierTimeoutInMs.GetHashCode(); | |||||
if (AgentDestructionWithoutShutdown != false) hash ^= AgentDestructionWithoutShutdown.GetHashCode(); | |||||
hash ^= recoverableJobs_.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (ServiceType.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(ServiceType); | |||||
} | |||||
if (ServiceLeader.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(ServiceLeader); | |||||
} | |||||
if (EnableHealthCheck != false) { | |||||
output.WriteRawTag(24); | |||||
output.WriteBool(EnableHealthCheck); | |||||
} | |||||
if (ClusterRegisterTimeoutInMs != 0L) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt64(ClusterRegisterTimeoutInMs); | |||||
} | |||||
if (HeartbeatTimeoutInMs != 0L) { | |||||
output.WriteRawTag(40); | |||||
output.WriteInt64(HeartbeatTimeoutInMs); | |||||
} | |||||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||||
output.WriteRawTag(56); | |||||
output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||||
} | |||||
if (AgentDestructionWithoutShutdown != false) { | |||||
output.WriteRawTag(64); | |||||
output.WriteBool(AgentDestructionWithoutShutdown); | |||||
} | |||||
recoverableJobs_.WriteTo(output, _repeated_recoverableJobs_codec); | |||||
coordinatedJobList_.WriteTo(output, _repeated_coordinatedJobList_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (ServiceType.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(ServiceType); | |||||
} | |||||
if (ServiceLeader.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(ServiceLeader); | |||||
} | |||||
if (EnableHealthCheck != false) { | |||||
output.WriteRawTag(24); | |||||
output.WriteBool(EnableHealthCheck); | |||||
} | |||||
if (ClusterRegisterTimeoutInMs != 0L) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt64(ClusterRegisterTimeoutInMs); | |||||
} | |||||
if (HeartbeatTimeoutInMs != 0L) { | |||||
output.WriteRawTag(40); | |||||
output.WriteInt64(HeartbeatTimeoutInMs); | |||||
} | |||||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||||
output.WriteRawTag(56); | |||||
output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||||
} | |||||
if (AgentDestructionWithoutShutdown != false) { | |||||
output.WriteRawTag(64); | |||||
output.WriteBool(AgentDestructionWithoutShutdown); | |||||
} | |||||
recoverableJobs_.WriteTo(ref output, _repeated_recoverableJobs_codec); | |||||
coordinatedJobList_.WriteTo(ref output, _repeated_coordinatedJobList_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (ServiceType.Length != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceType); | |||||
} | |||||
if (ServiceLeader.Length != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceLeader); | |||||
} | |||||
if (EnableHealthCheck != false) { | |||||
size += 1 + 1; | |||||
} | |||||
if (ClusterRegisterTimeoutInMs != 0L) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClusterRegisterTimeoutInMs); | |||||
} | |||||
if (HeartbeatTimeoutInMs != 0L) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatTimeoutInMs); | |||||
} | |||||
size += coordinatedJobList_.CalculateSize(_repeated_coordinatedJobList_codec); | |||||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownBarrierTimeoutInMs); | |||||
} | |||||
if (AgentDestructionWithoutShutdown != false) { | |||||
size += 1 + 1; | |||||
} | |||||
size += recoverableJobs_.CalculateSize(_repeated_recoverableJobs_codec); | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CoordinationServiceConfig other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.ServiceType.Length != 0) { | |||||
ServiceType = other.ServiceType; | |||||
} | |||||
if (other.ServiceLeader.Length != 0) { | |||||
ServiceLeader = other.ServiceLeader; | |||||
} | |||||
if (other.EnableHealthCheck != false) { | |||||
EnableHealthCheck = other.EnableHealthCheck; | |||||
} | |||||
if (other.ClusterRegisterTimeoutInMs != 0L) { | |||||
ClusterRegisterTimeoutInMs = other.ClusterRegisterTimeoutInMs; | |||||
} | |||||
if (other.HeartbeatTimeoutInMs != 0L) { | |||||
HeartbeatTimeoutInMs = other.HeartbeatTimeoutInMs; | |||||
} | |||||
coordinatedJobList_.Add(other.coordinatedJobList_); | |||||
if (other.ShutdownBarrierTimeoutInMs != 0L) { | |||||
ShutdownBarrierTimeoutInMs = other.ShutdownBarrierTimeoutInMs; | |||||
} | |||||
if (other.AgentDestructionWithoutShutdown != false) { | |||||
AgentDestructionWithoutShutdown = other.AgentDestructionWithoutShutdown; | |||||
} | |||||
recoverableJobs_.Add(other.recoverableJobs_); | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
ServiceType = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
ServiceLeader = input.ReadString(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
EnableHealthCheck = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 40: { | |||||
HeartbeatTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 56: { | |||||
ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 64: { | |||||
AgentDestructionWithoutShutdown = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 74: { | |||||
recoverableJobs_.AddEntriesFrom(input, _repeated_recoverableJobs_codec); | |||||
break; | |||||
} | |||||
case 82: { | |||||
coordinatedJobList_.AddEntriesFrom(input, _repeated_coordinatedJobList_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
ServiceType = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
ServiceLeader = input.ReadString(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
EnableHealthCheck = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 40: { | |||||
HeartbeatTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 56: { | |||||
ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 64: { | |||||
AgentDestructionWithoutShutdown = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 74: { | |||||
recoverableJobs_.AddEntriesFrom(ref input, _repeated_recoverableJobs_codec); | |||||
break; | |||||
} | |||||
case 82: { | |||||
coordinatedJobList_.AddEntriesFrom(ref input, _repeated_coordinatedJobList_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/python/framework/cpp_shape_inference.proto | // source: tensorflow/python/framework/cpp_shape_inference.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -55,23 +55,31 @@ namespace Tensorflow { | |||||
} | } | ||||
#region Messages | #region Messages | ||||
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> { | |||||
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CppShapeInferenceResult> _parser = new pb::MessageParser<CppShapeInferenceResult>(() => new CppShapeInferenceResult()); | private static readonly pb::MessageParser<CppShapeInferenceResult> _parser = new pb::MessageParser<CppShapeInferenceResult>(() => new CppShapeInferenceResult()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CppShapeInferenceResult> Parser { get { return _parser; } } | public static pb::MessageParser<CppShapeInferenceResult> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceResult() { | public CppShapeInferenceResult() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -79,6 +87,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceResult(CppShapeInferenceResult other) : this() { | public CppShapeInferenceResult(CppShapeInferenceResult other) : this() { | ||||
shape_ = other.shape_ != null ? other.shape_.Clone() : null; | shape_ = other.shape_ != null ? other.shape_.Clone() : null; | ||||
handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null; | handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null; | ||||
@@ -86,6 +95,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceResult Clone() { | public CppShapeInferenceResult Clone() { | ||||
return new CppShapeInferenceResult(this); | return new CppShapeInferenceResult(this); | ||||
} | } | ||||
@@ -94,6 +104,7 @@ namespace Tensorflow { | |||||
public const int ShapeFieldNumber = 1; | public const int ShapeFieldNumber = 1; | ||||
private global::Tensorflow.TensorShapeProto shape_; | private global::Tensorflow.TensorShapeProto shape_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.TensorShapeProto Shape { | public global::Tensorflow.TensorShapeProto Shape { | ||||
get { return shape_; } | get { return shape_; } | ||||
set { | set { | ||||
@@ -105,6 +116,7 @@ namespace Tensorflow { | |||||
public const int HandleDataFieldNumber = 4; | public const int HandleDataFieldNumber = 4; | ||||
private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_; | private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { | public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { | ||||
get { return handleData_; } | get { return handleData_; } | ||||
set { | set { | ||||
@@ -113,11 +125,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as CppShapeInferenceResult); | return Equals(other as CppShapeInferenceResult); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CppShapeInferenceResult other) { | public bool Equals(CppShapeInferenceResult other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -131,6 +145,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (shape_ != null) hash ^= Shape.GetHashCode(); | if (shape_ != null) hash ^= Shape.GetHashCode(); | ||||
@@ -142,12 +157,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (shape_ != null) { | if (shape_ != null) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteMessage(Shape); | output.WriteMessage(Shape); | ||||
@@ -159,9 +179,29 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (shape_ != null) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(Shape); | |||||
} | |||||
if (handleData_ != null) { | |||||
output.WriteRawTag(34); | |||||
output.WriteMessage(HandleData); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (shape_ != null) { | if (shape_ != null) { | ||||
@@ -177,6 +217,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CppShapeInferenceResult other) { | public void MergeFrom(CppShapeInferenceResult other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -197,7 +238,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -220,29 +265,68 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
if (shape_ == null) { | |||||
Shape = new global::Tensorflow.TensorShapeProto(); | |||||
} | |||||
input.ReadMessage(Shape); | |||||
break; | |||||
} | |||||
case 34: { | |||||
if (handleData_ == null) { | |||||
HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); | |||||
} | |||||
input.ReadMessage(HandleData); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
#region Nested types | #region Nested types | ||||
/// <summary>Container for nested types declared in the CppShapeInferenceResult message type.</summary> | /// <summary>Container for nested types declared in the CppShapeInferenceResult message type.</summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static partial class Types { | public static partial class Types { | ||||
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> { | |||||
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<HandleShapeAndType> _parser = new pb::MessageParser<HandleShapeAndType>(() => new HandleShapeAndType()); | private static readonly pb::MessageParser<HandleShapeAndType> _parser = new pb::MessageParser<HandleShapeAndType>(() => new HandleShapeAndType()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<HandleShapeAndType> Parser { get { return _parser; } } | public static pb::MessageParser<HandleShapeAndType> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; } | get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleShapeAndType() { | public HandleShapeAndType() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -250,6 +334,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleShapeAndType(HandleShapeAndType other) : this() { | public HandleShapeAndType(HandleShapeAndType other) : this() { | ||||
shape_ = other.shape_ != null ? other.shape_.Clone() : null; | shape_ = other.shape_ != null ? other.shape_.Clone() : null; | ||||
dtype_ = other.dtype_; | dtype_ = other.dtype_; | ||||
@@ -258,6 +343,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleShapeAndType Clone() { | public HandleShapeAndType Clone() { | ||||
return new HandleShapeAndType(this); | return new HandleShapeAndType(this); | ||||
} | } | ||||
@@ -266,6 +352,7 @@ namespace Tensorflow { | |||||
public const int ShapeFieldNumber = 1; | public const int ShapeFieldNumber = 1; | ||||
private global::Tensorflow.TensorShapeProto shape_; | private global::Tensorflow.TensorShapeProto shape_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.TensorShapeProto Shape { | public global::Tensorflow.TensorShapeProto Shape { | ||||
get { return shape_; } | get { return shape_; } | ||||
set { | set { | ||||
@@ -277,6 +364,7 @@ namespace Tensorflow { | |||||
public const int DtypeFieldNumber = 2; | public const int DtypeFieldNumber = 2; | ||||
private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; | private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.DataType Dtype { | public global::Tensorflow.DataType Dtype { | ||||
get { return dtype_; } | get { return dtype_; } | ||||
set { | set { | ||||
@@ -288,6 +376,7 @@ namespace Tensorflow { | |||||
public const int TypeFieldNumber = 4; | public const int TypeFieldNumber = 4; | ||||
private global::Tensorflow.FullTypeDef type_; | private global::Tensorflow.FullTypeDef type_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.FullTypeDef Type { | public global::Tensorflow.FullTypeDef Type { | ||||
get { return type_; } | get { return type_; } | ||||
set { | set { | ||||
@@ -296,11 +385,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as HandleShapeAndType); | return Equals(other as HandleShapeAndType); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(HandleShapeAndType other) { | public bool Equals(HandleShapeAndType other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -315,6 +406,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (shape_ != null) hash ^= Shape.GetHashCode(); | if (shape_ != null) hash ^= Shape.GetHashCode(); | ||||
@@ -327,12 +419,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (shape_ != null) { | if (shape_ != null) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteMessage(Shape); | output.WriteMessage(Shape); | ||||
@@ -348,9 +445,33 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (shape_ != null) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(Shape); | |||||
} | |||||
if (Dtype != global::Tensorflow.DataType.DtInvalid) { | |||||
output.WriteRawTag(16); | |||||
output.WriteEnum((int) Dtype); | |||||
} | |||||
if (type_ != null) { | |||||
output.WriteRawTag(34); | |||||
output.WriteMessage(Type); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (shape_ != null) { | if (shape_ != null) { | ||||
@@ -369,6 +490,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(HandleShapeAndType other) { | public void MergeFrom(HandleShapeAndType other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -392,7 +514,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -419,27 +545,69 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
if (shape_ == null) { | |||||
Shape = new global::Tensorflow.TensorShapeProto(); | |||||
} | |||||
input.ReadMessage(Shape); | |||||
break; | |||||
} | |||||
case 16: { | |||||
Dtype = (global::Tensorflow.DataType) input.ReadEnum(); | |||||
break; | |||||
} | |||||
case 34: { | |||||
if (type_ == null) { | |||||
Type = new global::Tensorflow.FullTypeDef(); | |||||
} | |||||
input.ReadMessage(Type); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
public sealed partial class HandleData : pb::IMessage<HandleData> { | |||||
public sealed partial class HandleData : pb::IMessage<HandleData> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<HandleData> _parser = new pb::MessageParser<HandleData>(() => new HandleData()); | private static readonly pb::MessageParser<HandleData> _parser = new pb::MessageParser<HandleData>(() => new HandleData()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<HandleData> Parser { get { return _parser; } } | public static pb::MessageParser<HandleData> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; } | get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleData() { | public HandleData() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -447,6 +615,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleData(HandleData other) : this() { | public HandleData(HandleData other) : this() { | ||||
isSet_ = other.isSet_; | isSet_ = other.isSet_; | ||||
shapeAndType_ = other.shapeAndType_.Clone(); | shapeAndType_ = other.shapeAndType_.Clone(); | ||||
@@ -454,6 +623,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public HandleData Clone() { | public HandleData Clone() { | ||||
return new HandleData(this); | return new HandleData(this); | ||||
} | } | ||||
@@ -462,6 +632,7 @@ namespace Tensorflow { | |||||
public const int IsSetFieldNumber = 1; | public const int IsSetFieldNumber = 1; | ||||
private bool isSet_; | private bool isSet_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool IsSet { | public bool IsSet { | ||||
get { return isSet_; } | get { return isSet_; } | ||||
set { | set { | ||||
@@ -478,16 +649,19 @@ namespace Tensorflow { | |||||
/// Only valid if <is_set>. | /// Only valid if <is_set>. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | ||||
get { return shapeAndType_; } | get { return shapeAndType_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as HandleData); | return Equals(other as HandleData); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(HandleData other) { | public bool Equals(HandleData other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -501,6 +675,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (IsSet != false) hash ^= IsSet.GetHashCode(); | if (IsSet != false) hash ^= IsSet.GetHashCode(); | ||||
@@ -512,12 +687,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (IsSet != false) { | if (IsSet != false) { | ||||
output.WriteRawTag(8); | output.WriteRawTag(8); | ||||
output.WriteBool(IsSet); | output.WriteBool(IsSet); | ||||
@@ -526,9 +706,26 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (IsSet != false) { | |||||
output.WriteRawTag(8); | |||||
output.WriteBool(IsSet); | |||||
} | |||||
shapeAndType_.WriteTo(ref output, _repeated_shapeAndType_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (IsSet != false) { | if (IsSet != false) { | ||||
@@ -542,6 +739,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(HandleData other) { | public void MergeFrom(HandleData other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -554,7 +752,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -571,8 +773,32 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 8: { | |||||
IsSet = input.ReadBool(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
shapeAndType_.AddEntriesFrom(ref input, _repeated_shapeAndType_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
} | } | ||||
@@ -580,23 +806,31 @@ namespace Tensorflow { | |||||
} | } | ||||
public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> { | |||||
public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<CppShapeInferenceInputsNeeded> _parser = new pb::MessageParser<CppShapeInferenceInputsNeeded>(() => new CppShapeInferenceInputsNeeded()); | private static readonly pb::MessageParser<CppShapeInferenceInputsNeeded> _parser = new pb::MessageParser<CppShapeInferenceInputsNeeded>(() => new CppShapeInferenceInputsNeeded()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<CppShapeInferenceInputsNeeded> Parser { get { return _parser; } } | public static pb::MessageParser<CppShapeInferenceInputsNeeded> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceInputsNeeded() { | public CppShapeInferenceInputsNeeded() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -604,6 +838,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() { | public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() { | ||||
inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone(); | inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone(); | ||||
inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone(); | inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone(); | ||||
@@ -611,6 +846,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public CppShapeInferenceInputsNeeded Clone() { | public CppShapeInferenceInputsNeeded Clone() { | ||||
return new CppShapeInferenceInputsNeeded(this); | return new CppShapeInferenceInputsNeeded(this); | ||||
} | } | ||||
@@ -621,6 +857,7 @@ namespace Tensorflow { | |||||
= pb::FieldCodec.ForInt32(10); | = pb::FieldCodec.ForInt32(10); | ||||
private readonly pbc::RepeatedField<int> inputTensorsNeeded_ = new pbc::RepeatedField<int>(); | private readonly pbc::RepeatedField<int> inputTensorsNeeded_ = new pbc::RepeatedField<int>(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<int> InputTensorsNeeded { | public pbc::RepeatedField<int> InputTensorsNeeded { | ||||
get { return inputTensorsNeeded_; } | get { return inputTensorsNeeded_; } | ||||
} | } | ||||
@@ -631,16 +868,19 @@ namespace Tensorflow { | |||||
= pb::FieldCodec.ForInt32(18); | = pb::FieldCodec.ForInt32(18); | ||||
private readonly pbc::RepeatedField<int> inputTensorsAsShapesNeeded_ = new pbc::RepeatedField<int>(); | private readonly pbc::RepeatedField<int> inputTensorsAsShapesNeeded_ = new pbc::RepeatedField<int>(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<int> InputTensorsAsShapesNeeded { | public pbc::RepeatedField<int> InputTensorsAsShapesNeeded { | ||||
get { return inputTensorsAsShapesNeeded_; } | get { return inputTensorsAsShapesNeeded_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as CppShapeInferenceInputsNeeded); | return Equals(other as CppShapeInferenceInputsNeeded); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(CppShapeInferenceInputsNeeded other) { | public bool Equals(CppShapeInferenceInputsNeeded other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -654,6 +894,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= inputTensorsNeeded_.GetHashCode(); | hash ^= inputTensorsNeeded_.GetHashCode(); | ||||
@@ -665,20 +906,39 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec); | inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec); | ||||
inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec); | inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec); | ||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
inputTensorsNeeded_.WriteTo(ref output, _repeated_inputTensorsNeeded_codec); | |||||
inputTensorsAsShapesNeeded_.WriteTo(ref output, _repeated_inputTensorsAsShapesNeeded_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | } | ||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec); | size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec); | ||||
@@ -690,6 +950,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(CppShapeInferenceInputsNeeded other) { | public void MergeFrom(CppShapeInferenceInputsNeeded other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -700,7 +961,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -719,7 +984,33 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: | |||||
case 8: { | |||||
inputTensorsNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsNeeded_codec); | |||||
break; | |||||
} | |||||
case 18: | |||||
case 16: { | |||||
inputTensorsAsShapesNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsAsShapesNeeded_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/protobuf/debug.proto | // source: tensorflow/core/protobuf/debug.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -55,23 +55,31 @@ namespace Tensorflow { | |||||
/// <summary> | /// <summary> | ||||
/// Option for watching a node in TensorFlow Debugger (tfdbg). | /// Option for watching a node in TensorFlow Debugger (tfdbg). | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> { | |||||
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DebugTensorWatch> _parser = new pb::MessageParser<DebugTensorWatch>(() => new DebugTensorWatch()); | private static readonly pb::MessageParser<DebugTensorWatch> _parser = new pb::MessageParser<DebugTensorWatch>(() => new DebugTensorWatch()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DebugTensorWatch> Parser { get { return _parser; } } | public static pb::MessageParser<DebugTensorWatch> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugTensorWatch() { | public DebugTensorWatch() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -79,6 +87,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugTensorWatch(DebugTensorWatch other) : this() { | public DebugTensorWatch(DebugTensorWatch other) : this() { | ||||
nodeName_ = other.nodeName_; | nodeName_ = other.nodeName_; | ||||
outputSlot_ = other.outputSlot_; | outputSlot_ = other.outputSlot_; | ||||
@@ -89,6 +98,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugTensorWatch Clone() { | public DebugTensorWatch Clone() { | ||||
return new DebugTensorWatch(this); | return new DebugTensorWatch(this); | ||||
} | } | ||||
@@ -102,6 +112,7 @@ namespace Tensorflow { | |||||
/// general. | /// general. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string NodeName { | public string NodeName { | ||||
get { return nodeName_; } | get { return nodeName_; } | ||||
set { | set { | ||||
@@ -120,6 +131,7 @@ namespace Tensorflow { | |||||
/// errors currently. | /// errors currently. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int OutputSlot { | public int OutputSlot { | ||||
get { return outputSlot_; } | get { return outputSlot_; } | ||||
set { | set { | ||||
@@ -138,6 +150,7 @@ namespace Tensorflow { | |||||
/// e.g., {"DebugIdentity", "DebugNanCount"} | /// e.g., {"DebugIdentity", "DebugNanCount"} | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> DebugOps { | public pbc::RepeatedField<string> DebugOps { | ||||
get { return debugOps_; } | get { return debugOps_; } | ||||
} | } | ||||
@@ -170,6 +183,7 @@ namespace Tensorflow { | |||||
/// TODO(cais): More visible documentation of this in g3docs. | /// TODO(cais): More visible documentation of this in g3docs. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> DebugUrls { | public pbc::RepeatedField<string> DebugUrls { | ||||
get { return debugUrls_; } | get { return debugUrls_; } | ||||
} | } | ||||
@@ -182,6 +196,7 @@ namespace Tensorflow { | |||||
/// incompatibility). Instead, just log the failure. | /// incompatibility). Instead, just log the failure. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool TolerateDebugOpCreationFailures { | public bool TolerateDebugOpCreationFailures { | ||||
get { return tolerateDebugOpCreationFailures_; } | get { return tolerateDebugOpCreationFailures_; } | ||||
set { | set { | ||||
@@ -190,11 +205,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DebugTensorWatch); | return Equals(other as DebugTensorWatch); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DebugTensorWatch other) { | public bool Equals(DebugTensorWatch other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -211,6 +228,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); | if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); | ||||
@@ -225,12 +243,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (NodeName.Length != 0) { | if (NodeName.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(NodeName); | output.WriteString(NodeName); | ||||
@@ -248,9 +271,35 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (NodeName.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(NodeName); | |||||
} | |||||
if (OutputSlot != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(OutputSlot); | |||||
} | |||||
debugOps_.WriteTo(ref output, _repeated_debugOps_codec); | |||||
debugUrls_.WriteTo(ref output, _repeated_debugUrls_codec); | |||||
if (TolerateDebugOpCreationFailures != false) { | |||||
output.WriteRawTag(40); | |||||
output.WriteBool(TolerateDebugOpCreationFailures); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (NodeName.Length != 0) { | if (NodeName.Length != 0) { | ||||
@@ -271,6 +320,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DebugTensorWatch other) { | public void MergeFrom(DebugTensorWatch other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -290,7 +340,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -319,30 +373,74 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
NodeName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
OutputSlot = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 26: { | |||||
debugOps_.AddEntriesFrom(ref input, _repeated_debugOps_codec); | |||||
break; | |||||
} | |||||
case 34: { | |||||
debugUrls_.AddEntriesFrom(ref input, _repeated_debugUrls_codec); | |||||
break; | |||||
} | |||||
case 40: { | |||||
TolerateDebugOpCreationFailures = input.ReadBool(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). | /// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class DebugOptions : pb::IMessage<DebugOptions> { | |||||
public sealed partial class DebugOptions : pb::IMessage<DebugOptions> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DebugOptions> _parser = new pb::MessageParser<DebugOptions>(() => new DebugOptions()); | private static readonly pb::MessageParser<DebugOptions> _parser = new pb::MessageParser<DebugOptions>(() => new DebugOptions()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DebugOptions> Parser { get { return _parser; } } | public static pb::MessageParser<DebugOptions> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugOptions() { | public DebugOptions() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -350,6 +448,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugOptions(DebugOptions other) : this() { | public DebugOptions(DebugOptions other) : this() { | ||||
debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone(); | debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone(); | ||||
globalStep_ = other.globalStep_; | globalStep_ = other.globalStep_; | ||||
@@ -358,6 +457,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebugOptions Clone() { | public DebugOptions Clone() { | ||||
return new DebugOptions(this); | return new DebugOptions(this); | ||||
} | } | ||||
@@ -371,6 +471,7 @@ namespace Tensorflow { | |||||
/// Debugging options | /// Debugging options | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.DebugTensorWatch> DebugTensorWatchOpts { | public pbc::RepeatedField<global::Tensorflow.DebugTensorWatch> DebugTensorWatchOpts { | ||||
get { return debugTensorWatchOpts_; } | get { return debugTensorWatchOpts_; } | ||||
} | } | ||||
@@ -384,6 +485,7 @@ namespace Tensorflow { | |||||
/// step count. | /// step count. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long GlobalStep { | public long GlobalStep { | ||||
get { return globalStep_; } | get { return globalStep_; } | ||||
set { | set { | ||||
@@ -401,6 +503,7 @@ namespace Tensorflow { | |||||
/// are cleaned up from the disk after each Session.run. | /// are cleaned up from the disk after each Session.run. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool ResetDiskByteUsage { | public bool ResetDiskByteUsage { | ||||
get { return resetDiskByteUsage_; } | get { return resetDiskByteUsage_; } | ||||
set { | set { | ||||
@@ -409,11 +512,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DebugOptions); | return Equals(other as DebugOptions); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DebugOptions other) { | public bool Equals(DebugOptions other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -428,6 +533,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= debugTensorWatchOpts_.GetHashCode(); | hash ^= debugTensorWatchOpts_.GetHashCode(); | ||||
@@ -440,12 +546,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec); | debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec); | ||||
if (GlobalStep != 0L) { | if (GlobalStep != 0L) { | ||||
output.WriteRawTag(80); | output.WriteRawTag(80); | ||||
@@ -458,9 +569,30 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
debugTensorWatchOpts_.WriteTo(ref output, _repeated_debugTensorWatchOpts_codec); | |||||
if (GlobalStep != 0L) { | |||||
output.WriteRawTag(80); | |||||
output.WriteInt64(GlobalStep); | |||||
} | |||||
if (ResetDiskByteUsage != false) { | |||||
output.WriteRawTag(88); | |||||
output.WriteBool(ResetDiskByteUsage); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec); | size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec); | ||||
@@ -477,6 +609,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DebugOptions other) { | public void MergeFrom(DebugOptions other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -492,7 +625,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -513,27 +650,63 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 34: { | |||||
debugTensorWatchOpts_.AddEntriesFrom(ref input, _repeated_debugTensorWatchOpts_codec); | |||||
break; | |||||
} | |||||
case 80: { | |||||
GlobalStep = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 88: { | |||||
ResetDiskByteUsage = input.ReadBool(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> { | |||||
public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DebuggedSourceFile> _parser = new pb::MessageParser<DebuggedSourceFile>(() => new DebuggedSourceFile()); | private static readonly pb::MessageParser<DebuggedSourceFile> _parser = new pb::MessageParser<DebuggedSourceFile>(() => new DebuggedSourceFile()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DebuggedSourceFile> Parser { get { return _parser; } } | public static pb::MessageParser<DebuggedSourceFile> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; } | get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFile() { | public DebuggedSourceFile() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -541,6 +714,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFile(DebuggedSourceFile other) : this() { | public DebuggedSourceFile(DebuggedSourceFile other) : this() { | ||||
host_ = other.host_; | host_ = other.host_; | ||||
filePath_ = other.filePath_; | filePath_ = other.filePath_; | ||||
@@ -551,6 +725,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFile Clone() { | public DebuggedSourceFile Clone() { | ||||
return new DebuggedSourceFile(this); | return new DebuggedSourceFile(this); | ||||
} | } | ||||
@@ -562,6 +737,7 @@ namespace Tensorflow { | |||||
/// The host name on which a source code file is located. | /// The host name on which a source code file is located. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Host { | public string Host { | ||||
get { return host_; } | get { return host_; } | ||||
set { | set { | ||||
@@ -576,6 +752,7 @@ namespace Tensorflow { | |||||
/// Path to the source code file. | /// Path to the source code file. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string FilePath { | public string FilePath { | ||||
get { return filePath_; } | get { return filePath_; } | ||||
set { | set { | ||||
@@ -590,6 +767,7 @@ namespace Tensorflow { | |||||
/// The timestamp at which the source code file is last modified. | /// The timestamp at which the source code file is last modified. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long LastModified { | public long LastModified { | ||||
get { return lastModified_; } | get { return lastModified_; } | ||||
set { | set { | ||||
@@ -604,6 +782,7 @@ namespace Tensorflow { | |||||
/// Byte size of the file. | /// Byte size of the file. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long Bytes { | public long Bytes { | ||||
get { return bytes_; } | get { return bytes_; } | ||||
set { | set { | ||||
@@ -620,16 +799,19 @@ namespace Tensorflow { | |||||
/// Line-by-line content of the source code file. | /// Line-by-line content of the source code file. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<string> Lines { | public pbc::RepeatedField<string> Lines { | ||||
get { return lines_; } | get { return lines_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DebuggedSourceFile); | return Equals(other as DebuggedSourceFile); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DebuggedSourceFile other) { | public bool Equals(DebuggedSourceFile other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -646,6 +828,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (Host.Length != 0) hash ^= Host.GetHashCode(); | if (Host.Length != 0) hash ^= Host.GetHashCode(); | ||||
@@ -660,12 +843,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (Host.Length != 0) { | if (Host.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(Host); | output.WriteString(Host); | ||||
@@ -686,9 +874,38 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (Host.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Host); | |||||
} | |||||
if (FilePath.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(FilePath); | |||||
} | |||||
if (LastModified != 0L) { | |||||
output.WriteRawTag(24); | |||||
output.WriteInt64(LastModified); | |||||
} | |||||
if (Bytes != 0L) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt64(Bytes); | |||||
} | |||||
lines_.WriteTo(ref output, _repeated_lines_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (Host.Length != 0) { | if (Host.Length != 0) { | ||||
@@ -711,6 +928,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DebuggedSourceFile other) { | public void MergeFrom(DebuggedSourceFile other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -732,7 +950,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -761,27 +983,71 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
Host = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
FilePath = input.ReadString(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
LastModified = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
Bytes = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 42: { | |||||
lines_.AddEntriesFrom(ref input, _repeated_lines_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> { | |||||
public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DebuggedSourceFiles> _parser = new pb::MessageParser<DebuggedSourceFiles>(() => new DebuggedSourceFiles()); | private static readonly pb::MessageParser<DebuggedSourceFiles> _parser = new pb::MessageParser<DebuggedSourceFiles>(() => new DebuggedSourceFiles()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DebuggedSourceFiles> Parser { get { return _parser; } } | public static pb::MessageParser<DebuggedSourceFiles> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; } | get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFiles() { | public DebuggedSourceFiles() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -789,12 +1055,14 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFiles(DebuggedSourceFiles other) : this() { | public DebuggedSourceFiles(DebuggedSourceFiles other) : this() { | ||||
sourceFiles_ = other.sourceFiles_.Clone(); | sourceFiles_ = other.sourceFiles_.Clone(); | ||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DebuggedSourceFiles Clone() { | public DebuggedSourceFiles Clone() { | ||||
return new DebuggedSourceFiles(this); | return new DebuggedSourceFiles(this); | ||||
} | } | ||||
@@ -808,16 +1076,19 @@ namespace Tensorflow { | |||||
/// A collection of source code files. | /// A collection of source code files. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.DebuggedSourceFile> SourceFiles { | public pbc::RepeatedField<global::Tensorflow.DebuggedSourceFile> SourceFiles { | ||||
get { return sourceFiles_; } | get { return sourceFiles_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DebuggedSourceFiles); | return Equals(other as DebuggedSourceFiles); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DebuggedSourceFiles other) { | public bool Equals(DebuggedSourceFiles other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -830,6 +1101,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= sourceFiles_.GetHashCode(); | hash ^= sourceFiles_.GetHashCode(); | ||||
@@ -840,19 +1112,37 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec); | sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec); | ||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
sourceFiles_.WriteTo(ref output, _repeated_sourceFiles_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec); | size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec); | ||||
@@ -863,6 +1153,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DebuggedSourceFiles other) { | public void MergeFrom(DebuggedSourceFiles other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -872,7 +1163,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -885,7 +1180,27 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
sourceFiles_.AddEntriesFrom(ref input, _repeated_sourceFiles_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -2,7 +2,7 @@ | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | // Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
// source: tensorflow/core/framework/device_attributes.proto | // source: tensorflow/core/framework/device_attributes.proto | ||||
// </auto-generated> | // </auto-generated> | ||||
#pragma warning disable 1591, 0612, 3021 | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | #region Designer generated code | ||||
using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
@@ -30,44 +30,53 @@ namespace Tensorflow { | |||||
"OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl", | "OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl", | ||||
"cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo", | "cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo", | ||||
"BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm", | "BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm", | ||||
"bG93LkxvY2FsTGlua3MirAEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||||
"bG93LkxvY2FsTGlua3MiwwEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||||
"IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB", | "IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB", | ||||
"KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs", | "KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs", | ||||
"aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k", | "aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k", | ||||
"ZXNjGAcgASgJQpEBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCFkRldmlj", | |||||
"ZUF0dHJpYnV0ZXNQcm90b3NQAVpYZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay9kZXZpY2Vf", | |||||
"YXR0cmlidXRlc19nb19wcm90b/gBAWIGcHJvdG8z")); | |||||
"ZXNjGAcgASgJEhUKDXhsYV9nbG9iYWxfaWQYCCABKANCkQEKGG9yZy50ZW5z", | |||||
"b3JmbG93LmZyYW1ld29ya0IWRGV2aWNlQXR0cmlidXRlc1Byb3Rvc1ABWlhn", | |||||
"aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", | |||||
"L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVzX2dvX3Byb3Rv+AEB", | |||||
"YgZwcm90bzM=")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | ||||
new pbr::FileDescriptor[] { }, | new pbr::FileDescriptor[] { }, | ||||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | ||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null), | new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null), | ||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null), | new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null), | ||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null), | new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null), | ||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc" }, null, null, null, null) | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc", "XlaGlobalId" }, null, null, null, null) | |||||
})); | })); | ||||
} | } | ||||
#endregion | #endregion | ||||
} | } | ||||
#region Messages | #region Messages | ||||
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> { | |||||
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<InterconnectLink> _parser = new pb::MessageParser<InterconnectLink>(() => new InterconnectLink()); | private static readonly pb::MessageParser<InterconnectLink> _parser = new pb::MessageParser<InterconnectLink>(() => new InterconnectLink()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<InterconnectLink> Parser { get { return _parser; } } | public static pb::MessageParser<InterconnectLink> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; } | get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public InterconnectLink() { | public InterconnectLink() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -75,6 +84,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public InterconnectLink(InterconnectLink other) : this() { | public InterconnectLink(InterconnectLink other) : this() { | ||||
deviceId_ = other.deviceId_; | deviceId_ = other.deviceId_; | ||||
type_ = other.type_; | type_ = other.type_; | ||||
@@ -83,6 +93,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public InterconnectLink Clone() { | public InterconnectLink Clone() { | ||||
return new InterconnectLink(this); | return new InterconnectLink(this); | ||||
} | } | ||||
@@ -91,6 +102,7 @@ namespace Tensorflow { | |||||
public const int DeviceIdFieldNumber = 1; | public const int DeviceIdFieldNumber = 1; | ||||
private int deviceId_; | private int deviceId_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int DeviceId { | public int DeviceId { | ||||
get { return deviceId_; } | get { return deviceId_; } | ||||
set { | set { | ||||
@@ -102,6 +114,7 @@ namespace Tensorflow { | |||||
public const int TypeFieldNumber = 2; | public const int TypeFieldNumber = 2; | ||||
private string type_ = ""; | private string type_ = ""; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Type { | public string Type { | ||||
get { return type_; } | get { return type_; } | ||||
set { | set { | ||||
@@ -113,6 +126,7 @@ namespace Tensorflow { | |||||
public const int StrengthFieldNumber = 3; | public const int StrengthFieldNumber = 3; | ||||
private int strength_; | private int strength_; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int Strength { | public int Strength { | ||||
get { return strength_; } | get { return strength_; } | ||||
set { | set { | ||||
@@ -121,11 +135,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as InterconnectLink); | return Equals(other as InterconnectLink); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(InterconnectLink other) { | public bool Equals(InterconnectLink other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -140,6 +156,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (DeviceId != 0) hash ^= DeviceId.GetHashCode(); | if (DeviceId != 0) hash ^= DeviceId.GetHashCode(); | ||||
@@ -152,12 +169,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (DeviceId != 0) { | if (DeviceId != 0) { | ||||
output.WriteRawTag(8); | output.WriteRawTag(8); | ||||
output.WriteInt32(DeviceId); | output.WriteInt32(DeviceId); | ||||
@@ -173,9 +195,33 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (DeviceId != 0) { | |||||
output.WriteRawTag(8); | |||||
output.WriteInt32(DeviceId); | |||||
} | |||||
if (Type.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(Type); | |||||
} | |||||
if (Strength != 0) { | |||||
output.WriteRawTag(24); | |||||
output.WriteInt32(Strength); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (DeviceId != 0) { | if (DeviceId != 0) { | ||||
@@ -194,6 +240,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(InterconnectLink other) { | public void MergeFrom(InterconnectLink other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -211,7 +258,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -232,27 +283,63 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 8: { | |||||
DeviceId = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
Type = input.ReadString(); | |||||
break; | |||||
} | |||||
case 24: { | |||||
Strength = input.ReadInt32(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | } | ||||
public sealed partial class LocalLinks : pb::IMessage<LocalLinks> { | |||||
public sealed partial class LocalLinks : pb::IMessage<LocalLinks> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<LocalLinks> _parser = new pb::MessageParser<LocalLinks>(() => new LocalLinks()); | private static readonly pb::MessageParser<LocalLinks> _parser = new pb::MessageParser<LocalLinks>(() => new LocalLinks()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<LocalLinks> Parser { get { return _parser; } } | public static pb::MessageParser<LocalLinks> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; } | get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public LocalLinks() { | public LocalLinks() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -260,12 +347,14 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public LocalLinks(LocalLinks other) : this() { | public LocalLinks(LocalLinks other) : this() { | ||||
link_ = other.link_.Clone(); | link_ = other.link_.Clone(); | ||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public LocalLinks Clone() { | public LocalLinks Clone() { | ||||
return new LocalLinks(this); | return new LocalLinks(this); | ||||
} | } | ||||
@@ -276,16 +365,19 @@ namespace Tensorflow { | |||||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser); | ||||
private readonly pbc::RepeatedField<global::Tensorflow.InterconnectLink> link_ = new pbc::RepeatedField<global::Tensorflow.InterconnectLink>(); | private readonly pbc::RepeatedField<global::Tensorflow.InterconnectLink> link_ = new pbc::RepeatedField<global::Tensorflow.InterconnectLink>(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public pbc::RepeatedField<global::Tensorflow.InterconnectLink> Link { | public pbc::RepeatedField<global::Tensorflow.InterconnectLink> Link { | ||||
get { return link_; } | get { return link_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as LocalLinks); | return Equals(other as LocalLinks); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(LocalLinks other) { | public bool Equals(LocalLinks other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -298,6 +390,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
hash ^= link_.GetHashCode(); | hash ^= link_.GetHashCode(); | ||||
@@ -308,19 +401,37 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
link_.WriteTo(output, _repeated_link_codec); | link_.WriteTo(output, _repeated_link_codec); | ||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
link_.WriteTo(ref output, _repeated_link_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
size += link_.CalculateSize(_repeated_link_codec); | size += link_.CalculateSize(_repeated_link_codec); | ||||
@@ -331,6 +442,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(LocalLinks other) { | public void MergeFrom(LocalLinks other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -340,7 +452,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -353,27 +469,55 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
link_.AddEntriesFrom(ref input, _repeated_link_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> { | |||||
public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DeviceLocality> _parser = new pb::MessageParser<DeviceLocality>(() => new DeviceLocality()); | private static readonly pb::MessageParser<DeviceLocality> _parser = new pb::MessageParser<DeviceLocality>(() => new DeviceLocality()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DeviceLocality> Parser { get { return _parser; } } | public static pb::MessageParser<DeviceLocality> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; } | get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceLocality() { | public DeviceLocality() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -381,6 +525,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceLocality(DeviceLocality other) : this() { | public DeviceLocality(DeviceLocality other) : this() { | ||||
busId_ = other.busId_; | busId_ = other.busId_; | ||||
numaNode_ = other.numaNode_; | numaNode_ = other.numaNode_; | ||||
@@ -389,6 +534,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceLocality Clone() { | public DeviceLocality Clone() { | ||||
return new DeviceLocality(this); | return new DeviceLocality(this); | ||||
} | } | ||||
@@ -401,6 +547,7 @@ namespace Tensorflow { | |||||
/// no specific locality. Specific localities are indexed from 1. | /// no specific locality. Specific localities are indexed from 1. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int BusId { | public int BusId { | ||||
get { return busId_; } | get { return busId_; } | ||||
set { | set { | ||||
@@ -415,6 +562,7 @@ namespace Tensorflow { | |||||
/// Optional NUMA locality of device. | /// Optional NUMA locality of device. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int NumaNode { | public int NumaNode { | ||||
get { return numaNode_; } | get { return numaNode_; } | ||||
set { | set { | ||||
@@ -429,6 +577,7 @@ namespace Tensorflow { | |||||
/// Optional local interconnect links to other devices. | /// Optional local interconnect links to other devices. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.LocalLinks Links { | public global::Tensorflow.LocalLinks Links { | ||||
get { return links_; } | get { return links_; } | ||||
set { | set { | ||||
@@ -437,11 +586,13 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DeviceLocality); | return Equals(other as DeviceLocality); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DeviceLocality other) { | public bool Equals(DeviceLocality other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -456,6 +607,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (BusId != 0) hash ^= BusId.GetHashCode(); | if (BusId != 0) hash ^= BusId.GetHashCode(); | ||||
@@ -468,12 +620,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (BusId != 0) { | if (BusId != 0) { | ||||
output.WriteRawTag(8); | output.WriteRawTag(8); | ||||
output.WriteInt32(BusId); | output.WriteInt32(BusId); | ||||
@@ -489,9 +646,33 @@ namespace Tensorflow { | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | } | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (BusId != 0) { | |||||
output.WriteRawTag(8); | |||||
output.WriteInt32(BusId); | |||||
} | |||||
if (NumaNode != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(NumaNode); | |||||
} | |||||
if (links_ != null) { | |||||
output.WriteRawTag(26); | |||||
output.WriteMessage(Links); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (BusId != 0) { | if (BusId != 0) { | ||||
@@ -510,6 +691,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DeviceLocality other) { | public void MergeFrom(DeviceLocality other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -530,7 +712,11 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -554,27 +740,66 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 8: { | |||||
BusId = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
NumaNode = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 26: { | |||||
if (links_ == null) { | |||||
Links = new global::Tensorflow.LocalLinks(); | |||||
} | |||||
input.ReadMessage(Links); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
#endif | |||||
} | } | ||||
public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> { | |||||
public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<DeviceAttributes> _parser = new pb::MessageParser<DeviceAttributes>(() => new DeviceAttributes()); | private static readonly pb::MessageParser<DeviceAttributes> _parser = new pb::MessageParser<DeviceAttributes>(() => new DeviceAttributes()); | ||||
private pb::UnknownFieldSet _unknownFields; | private pb::UnknownFieldSet _unknownFields; | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<DeviceAttributes> Parser { get { return _parser; } } | public static pb::MessageParser<DeviceAttributes> Parser { get { return _parser; } } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | public static pbr::MessageDescriptor Descriptor { | ||||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; } | get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | pbr::MessageDescriptor pb::IMessage.Descriptor { | ||||
get { return Descriptor; } | get { return Descriptor; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceAttributes() { | public DeviceAttributes() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
@@ -582,6 +807,7 @@ namespace Tensorflow { | |||||
partial void OnConstruction(); | partial void OnConstruction(); | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceAttributes(DeviceAttributes other) : this() { | public DeviceAttributes(DeviceAttributes other) : this() { | ||||
name_ = other.name_; | name_ = other.name_; | ||||
deviceType_ = other.deviceType_; | deviceType_ = other.deviceType_; | ||||
@@ -589,10 +815,12 @@ namespace Tensorflow { | |||||
locality_ = other.locality_ != null ? other.locality_.Clone() : null; | locality_ = other.locality_ != null ? other.locality_.Clone() : null; | ||||
incarnation_ = other.incarnation_; | incarnation_ = other.incarnation_; | ||||
physicalDeviceDesc_ = other.physicalDeviceDesc_; | physicalDeviceDesc_ = other.physicalDeviceDesc_; | ||||
xlaGlobalId_ = other.xlaGlobalId_; | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public DeviceAttributes Clone() { | public DeviceAttributes Clone() { | ||||
return new DeviceAttributes(this); | return new DeviceAttributes(this); | ||||
} | } | ||||
@@ -604,6 +832,7 @@ namespace Tensorflow { | |||||
/// Fully specified name of the device within a cluster. | /// Fully specified name of the device within a cluster. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string Name { | public string Name { | ||||
get { return name_; } | get { return name_; } | ||||
set { | set { | ||||
@@ -618,6 +847,7 @@ namespace Tensorflow { | |||||
/// String representation of device_type. | /// String representation of device_type. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string DeviceType { | public string DeviceType { | ||||
get { return deviceType_; } | get { return deviceType_; } | ||||
set { | set { | ||||
@@ -632,6 +862,7 @@ namespace Tensorflow { | |||||
/// Memory capacity of device in bytes. | /// Memory capacity of device in bytes. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long MemoryLimit { | public long MemoryLimit { | ||||
get { return memoryLimit_; } | get { return memoryLimit_; } | ||||
set { | set { | ||||
@@ -647,6 +878,7 @@ namespace Tensorflow { | |||||
/// for supporting efficient data transfers. | /// for supporting efficient data transfers. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Tensorflow.DeviceLocality Locality { | public global::Tensorflow.DeviceLocality Locality { | ||||
get { return locality_; } | get { return locality_; } | ||||
set { | set { | ||||
@@ -662,6 +894,7 @@ namespace Tensorflow { | |||||
/// initialized. "incarnation" should never be 0. | /// initialized. "incarnation" should never be 0. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public ulong Incarnation { | public ulong Incarnation { | ||||
get { return incarnation_; } | get { return incarnation_; } | ||||
set { | set { | ||||
@@ -676,6 +909,7 @@ namespace Tensorflow { | |||||
/// String representation of the physical device that this device maps to. | /// String representation of the physical device that this device maps to. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public string PhysicalDeviceDesc { | public string PhysicalDeviceDesc { | ||||
get { return physicalDeviceDesc_; } | get { return physicalDeviceDesc_; } | ||||
set { | set { | ||||
@@ -683,12 +917,31 @@ namespace Tensorflow { | |||||
} | } | ||||
} | } | ||||
/// <summary>Field number for the "xla_global_id" field.</summary> | |||||
public const int XlaGlobalIdFieldNumber = 8; | |||||
private long xlaGlobalId_; | |||||
/// <summary> | |||||
/// A physical device ID for use in XLA DeviceAssignments, unique across | |||||
/// clients in a multi-client setup. Set to -1 if unavailable, non-negative | |||||
/// otherwise. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public long XlaGlobalId { | |||||
get { return xlaGlobalId_; } | |||||
set { | |||||
xlaGlobalId_ = value; | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | public override bool Equals(object other) { | ||||
return Equals(other as DeviceAttributes); | return Equals(other as DeviceAttributes); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(DeviceAttributes other) { | public bool Equals(DeviceAttributes other) { | ||||
if (ReferenceEquals(other, null)) { | if (ReferenceEquals(other, null)) { | ||||
return false; | return false; | ||||
@@ -702,10 +955,12 @@ namespace Tensorflow { | |||||
if (!object.Equals(Locality, other.Locality)) return false; | if (!object.Equals(Locality, other.Locality)) return false; | ||||
if (Incarnation != other.Incarnation) return false; | if (Incarnation != other.Incarnation) return false; | ||||
if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false; | if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false; | ||||
if (XlaGlobalId != other.XlaGlobalId) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | return Equals(_unknownFields, other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | public override int GetHashCode() { | ||||
int hash = 1; | int hash = 1; | ||||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | if (Name.Length != 0) hash ^= Name.GetHashCode(); | ||||
@@ -714,6 +969,7 @@ namespace Tensorflow { | |||||
if (locality_ != null) hash ^= Locality.GetHashCode(); | if (locality_ != null) hash ^= Locality.GetHashCode(); | ||||
if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); | if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); | ||||
if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode(); | if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode(); | ||||
if (XlaGlobalId != 0L) hash ^= XlaGlobalId.GetHashCode(); | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
hash ^= _unknownFields.GetHashCode(); | hash ^= _unknownFields.GetHashCode(); | ||||
} | } | ||||
@@ -721,12 +977,17 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | public override string ToString() { | ||||
return pb::JsonFormatter.ToDiagnosticString(this); | return pb::JsonFormatter.ToDiagnosticString(this); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
output.WriteRawTag(10); | output.WriteRawTag(10); | ||||
output.WriteString(Name); | output.WriteString(Name); | ||||
@@ -751,12 +1012,56 @@ namespace Tensorflow { | |||||
output.WriteRawTag(58); | output.WriteRawTag(58); | ||||
output.WriteString(PhysicalDeviceDesc); | output.WriteString(PhysicalDeviceDesc); | ||||
} | } | ||||
if (XlaGlobalId != 0L) { | |||||
output.WriteRawTag(64); | |||||
output.WriteInt64(XlaGlobalId); | |||||
} | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
_unknownFields.WriteTo(output); | _unknownFields.WriteTo(output); | ||||
} | } | ||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (Name.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(Name); | |||||
} | |||||
if (DeviceType.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(DeviceType); | |||||
} | |||||
if (MemoryLimit != 0L) { | |||||
output.WriteRawTag(32); | |||||
output.WriteInt64(MemoryLimit); | |||||
} | |||||
if (locality_ != null) { | |||||
output.WriteRawTag(42); | |||||
output.WriteMessage(Locality); | |||||
} | |||||
if (Incarnation != 0UL) { | |||||
output.WriteRawTag(49); | |||||
output.WriteFixed64(Incarnation); | |||||
} | |||||
if (PhysicalDeviceDesc.Length != 0) { | |||||
output.WriteRawTag(58); | |||||
output.WriteString(PhysicalDeviceDesc); | |||||
} | |||||
if (XlaGlobalId != 0L) { | |||||
output.WriteRawTag(64); | |||||
output.WriteInt64(XlaGlobalId); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | } | ||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | public int CalculateSize() { | ||||
int size = 0; | int size = 0; | ||||
if (Name.Length != 0) { | if (Name.Length != 0) { | ||||
@@ -777,6 +1082,9 @@ namespace Tensorflow { | |||||
if (PhysicalDeviceDesc.Length != 0) { | if (PhysicalDeviceDesc.Length != 0) { | ||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc); | size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc); | ||||
} | } | ||||
if (XlaGlobalId != 0L) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaGlobalId); | |||||
} | |||||
if (_unknownFields != null) { | if (_unknownFields != null) { | ||||
size += _unknownFields.CalculateSize(); | size += _unknownFields.CalculateSize(); | ||||
} | } | ||||
@@ -784,6 +1092,7 @@ namespace Tensorflow { | |||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(DeviceAttributes other) { | public void MergeFrom(DeviceAttributes other) { | ||||
if (other == null) { | if (other == null) { | ||||
return; | return; | ||||
@@ -809,11 +1118,18 @@ namespace Tensorflow { | |||||
if (other.PhysicalDeviceDesc.Length != 0) { | if (other.PhysicalDeviceDesc.Length != 0) { | ||||
PhysicalDeviceDesc = other.PhysicalDeviceDesc; | PhysicalDeviceDesc = other.PhysicalDeviceDesc; | ||||
} | } | ||||
if (other.XlaGlobalId != 0L) { | |||||
XlaGlobalId = other.XlaGlobalId; | |||||
} | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | uint tag; | ||||
while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
switch(tag) { | switch(tag) { | ||||
@@ -847,9 +1163,60 @@ namespace Tensorflow { | |||||
PhysicalDeviceDesc = input.ReadString(); | PhysicalDeviceDesc = input.ReadString(); | ||||
break; | break; | ||||
} | } | ||||
case 64: { | |||||
XlaGlobalId = input.ReadInt64(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
Name = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
DeviceType = input.ReadString(); | |||||
break; | |||||
} | |||||
case 32: { | |||||
MemoryLimit = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 42: { | |||||
if (locality_ == null) { | |||||
Locality = new global::Tensorflow.DeviceLocality(); | |||||
} | |||||
input.ReadMessage(Locality); | |||||
break; | |||||
} | |||||
case 49: { | |||||
Incarnation = input.ReadFixed64(); | |||||
break; | |||||
} | |||||
case 58: { | |||||
PhysicalDeviceDesc = input.ReadString(); | |||||
break; | |||||
} | |||||
case 64: { | |||||
XlaGlobalId = input.ReadInt64(); | |||||
break; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
#endif | |||||
} | } | ||||
@@ -0,0 +1,340 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: tensorflow/compiler/xla/service/cpu/executable.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021, 8981 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Xla.Cpu { | |||||
/// <summary>Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||||
public static partial class ExecutableReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static ExecutableReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"CjR0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS9leGVjdXRh", | |||||
"YmxlLnByb3RvEgd4bGEuY3B1Gjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9z", | |||||
"ZXJ2aWNlL2NwdS94bGFfZnJhbWV3b3JrLnByb3RvGil0ZW5zb3JmbG93L2Nv", | |||||
"bXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90byLXAQocWGxhUnVudGltZUNw", | |||||
"dUV4ZWN1dGFibGVQcm90bxI+ChZ4bGFfcnVudGltZV9leGVjdXRhYmxlGAEg", | |||||
"ASgLMh4ueGxhLlhsYVJ1bnRpbWVFeGVjdXRhYmxlUHJvdG8SQAoVeGxhX2Zy", | |||||
"YW1ld29ya19tYXBwaW5nGAIgASgLMiEueGxhLmNwdS5YbGFGcmFtZXdvcmtN", | |||||
"YXBwaW5nUHJvdG8SNQoRYnVmZmVyX2Fzc2lnbm1lbnQYAyABKAsyGi54bGEu", | |||||
"QnVmZmVyQXNzaWdubWVudFByb3Rv")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { global::Xla.Cpu.XlaFrameworkReflection.Descriptor, global::Xla.HloReflection.Descriptor, }, | |||||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaRuntimeCpuExecutableProto), global::Xla.Cpu.XlaRuntimeCpuExecutableProto.Parser, new[]{ "XlaRuntimeExecutable", "XlaFrameworkMapping", "BufferAssignment" }, null, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
public sealed partial class XlaRuntimeCpuExecutableProto : pb::IMessage<XlaRuntimeCpuExecutableProto> | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
, pb::IBufferMessage | |||||
#endif | |||||
{ | |||||
private static readonly pb::MessageParser<XlaRuntimeCpuExecutableProto> _parser = new pb::MessageParser<XlaRuntimeCpuExecutableProto>(() => new XlaRuntimeCpuExecutableProto()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pb::MessageParser<XlaRuntimeCpuExecutableProto> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Xla.Cpu.ExecutableReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public XlaRuntimeCpuExecutableProto() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public XlaRuntimeCpuExecutableProto(XlaRuntimeCpuExecutableProto other) : this() { | |||||
xlaRuntimeExecutable_ = other.xlaRuntimeExecutable_ != null ? other.xlaRuntimeExecutable_.Clone() : null; | |||||
xlaFrameworkMapping_ = other.xlaFrameworkMapping_ != null ? other.xlaFrameworkMapping_.Clone() : null; | |||||
bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null; | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public XlaRuntimeCpuExecutableProto Clone() { | |||||
return new XlaRuntimeCpuExecutableProto(this); | |||||
} | |||||
/// <summary>Field number for the "xla_runtime_executable" field.</summary> | |||||
public const int XlaRuntimeExecutableFieldNumber = 1; | |||||
private global::Xla.XlaRuntimeExecutableProto xlaRuntimeExecutable_; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Xla.XlaRuntimeExecutableProto XlaRuntimeExecutable { | |||||
get { return xlaRuntimeExecutable_; } | |||||
set { | |||||
xlaRuntimeExecutable_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "xla_framework_mapping" field.</summary> | |||||
public const int XlaFrameworkMappingFieldNumber = 2; | |||||
private global::Xla.Cpu.XlaFrameworkMappingProto xlaFrameworkMapping_; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Xla.Cpu.XlaFrameworkMappingProto XlaFrameworkMapping { | |||||
get { return xlaFrameworkMapping_; } | |||||
set { | |||||
xlaFrameworkMapping_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "buffer_assignment" field.</summary> | |||||
public const int BufferAssignmentFieldNumber = 3; | |||||
private global::Xla.BufferAssignmentProto bufferAssignment_; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public global::Xla.BufferAssignmentProto BufferAssignment { | |||||
get { return bufferAssignment_; } | |||||
set { | |||||
bufferAssignment_ = value; | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as XlaRuntimeCpuExecutableProto); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public bool Equals(XlaRuntimeCpuExecutableProto other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (!object.Equals(XlaRuntimeExecutable, other.XlaRuntimeExecutable)) return false; | |||||
if (!object.Equals(XlaFrameworkMapping, other.XlaFrameworkMapping)) return false; | |||||
if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (xlaRuntimeExecutable_ != null) hash ^= XlaRuntimeExecutable.GetHashCode(); | |||||
if (xlaFrameworkMapping_ != null) hash ^= XlaFrameworkMapping.GetHashCode(); | |||||
if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
output.WriteRawMessage(this); | |||||
#else | |||||
if (xlaRuntimeExecutable_ != null) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(XlaRuntimeExecutable); | |||||
} | |||||
if (xlaFrameworkMapping_ != null) { | |||||
output.WriteRawTag(18); | |||||
output.WriteMessage(XlaFrameworkMapping); | |||||
} | |||||
if (bufferAssignment_ != null) { | |||||
output.WriteRawTag(26); | |||||
output.WriteMessage(BufferAssignment); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||||
if (xlaRuntimeExecutable_ != null) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(XlaRuntimeExecutable); | |||||
} | |||||
if (xlaFrameworkMapping_ != null) { | |||||
output.WriteRawTag(18); | |||||
output.WriteMessage(XlaFrameworkMapping); | |||||
} | |||||
if (bufferAssignment_ != null) { | |||||
output.WriteRawTag(26); | |||||
output.WriteMessage(BufferAssignment); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(ref output); | |||||
} | |||||
} | |||||
#endif | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (xlaRuntimeExecutable_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaRuntimeExecutable); | |||||
} | |||||
if (xlaFrameworkMapping_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaFrameworkMapping); | |||||
} | |||||
if (bufferAssignment_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment); | |||||
} | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(XlaRuntimeCpuExecutableProto other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.xlaRuntimeExecutable_ != null) { | |||||
if (xlaRuntimeExecutable_ == null) { | |||||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||||
} | |||||
XlaRuntimeExecutable.MergeFrom(other.XlaRuntimeExecutable); | |||||
} | |||||
if (other.xlaFrameworkMapping_ != null) { | |||||
if (xlaFrameworkMapping_ == null) { | |||||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||||
} | |||||
XlaFrameworkMapping.MergeFrom(other.XlaFrameworkMapping); | |||||
} | |||||
if (other.bufferAssignment_ != null) { | |||||
if (bufferAssignment_ == null) { | |||||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||||
} | |||||
BufferAssignment.MergeFrom(other.BufferAssignment); | |||||
} | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
input.ReadRawMessage(this); | |||||
#else | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
if (xlaRuntimeExecutable_ == null) { | |||||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||||
} | |||||
input.ReadMessage(XlaRuntimeExecutable); | |||||
break; | |||||
} | |||||
case 18: { | |||||
if (xlaFrameworkMapping_ == null) { | |||||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||||
} | |||||
input.ReadMessage(XlaFrameworkMapping); | |||||
break; | |||||
} | |||||
case 26: { | |||||
if (bufferAssignment_ == null) { | |||||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||||
} | |||||
input.ReadMessage(BufferAssignment); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||||
break; | |||||
case 10: { | |||||
if (xlaRuntimeExecutable_ == null) { | |||||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||||
} | |||||
input.ReadMessage(XlaRuntimeExecutable); | |||||
break; | |||||
} | |||||
case 18: { | |||||
if (xlaFrameworkMapping_ == null) { | |||||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||||
} | |||||
input.ReadMessage(XlaFrameworkMapping); | |||||
break; | |||||
} | |||||
case 26: { | |||||
if (bufferAssignment_ == null) { | |||||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||||
} | |||||
input.ReadMessage(BufferAssignment); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |