@@ -1,7 +1,7 @@ | |||
| |||
Microsoft Visual Studio Solution File, Format Version 12.00 | |||
# Visual Studio Version 16 | |||
VisualStudioVersion = 16.0.31624.102 | |||
# Visual Studio Version 17 | |||
VisualStudioVersion = 17.4.33213.308 | |||
MinimumVisualStudioVersion = 10.0.40219.1 | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||
EndProject | |||
@@ -0,0 +1,17 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class c_api | |||
{ | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
} | |||
} |
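A note on how these new TFC_* bindings are meant to be used: TFC_SetAttr takes a serialized AttrValue wrapped in a buffer and reports failures through the status handle. The sketch below illustrates the expected call pattern; the Buffer/Status wrapper members and the implicit SafeHandle conversions are assumptions based on how the rest of the binding calls c_api, not something this PR introduces.

using Google.Protobuf;

// Sketch only: wrapper names and implicit conversions are assumed, not confirmed by this PR.
static void SetIntAttr(Graph graph, IntPtr op, string attr_name, long value)
{
    var attr_value = new AttrValue { I = value };             // any Google.Protobuf AttrValue
    var buffer = new Buffer(attr_value.ToByteArray());        // serialized proto wrapped in a TF_Buffer
    var status = new Status();
    c_api.TFC_SetAttr(graph, op, attr_name, buffer, status);  // graph/buffer/status convert to their SafeHandles
    status.Check(true);                                       // throws if the C call reported an error
}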
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -45,6 +46,23 @@ namespace Tensorflow | |||
{ | |||
return as_text(bytes_or_text, encoding); | |||
} | |||
public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||
{ | |||
return bytes; | |||
} | |||
public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||
{ | |||
return ByteString.CopyFrom(bytes); | |||
} | |||
public ByteString as_bytes(string text, Encoding encoding = null) | |||
{ | |||
if(encoding is null) | |||
{ | |||
encoding = Encoding.UTF8; | |||
} | |||
return ByteString.CopyFrom(encoding.GetBytes(text)); | |||
} | |||
} | |||
public bool executing_eagerly() | |||
@@ -54,6 +54,6 @@ namespace Tensorflow | |||
Dictionary<string, Tensor> input_map = null, | |||
string[] return_elements = null, | |||
string name = null, | |||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||
} | |||
} |
@@ -14,6 +14,8 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
@@ -79,5 +81,10 @@ namespace Tensorflow | |||
num_split: num_split, | |||
axis: axis, | |||
name: name); | |||
public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
{ | |||
return gen_ops.ensure_shape(x, shape, name); | |||
} | |||
} | |||
} |
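The new tf.ensure_shape is a thin wrapper over gen_ops.ensure_shape: the input tensor is returned unchanged and a runtime check fails if its shape is incompatible with the requested one. A minimal usage sketch, assuming the constant and Shape overloads behave as elsewhere in the binding:

using static Tensorflow.Binding;

var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });
var y = tf.ensure_shape(x, new Shape(2, 2));   // passes through: the shapes are compatible
// tf.ensure_shape(x, new Shape(3, 3));        // would fail at runtime with a shape mismatch error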
@@ -61,7 +61,7 @@ namespace Tensorflow | |||
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
/// <summary> | |||
/// Set `num_dims` to -1 to represent "unknown rank". | |||
@@ -22,6 +22,7 @@ using System.ComponentModel; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Linq; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -107,6 +107,12 @@ namespace Tensorflow | |||
} | |||
} | |||
public void Release() | |||
{ | |||
_handle.Dispose(); | |||
_handle = null; | |||
} | |||
public override string ToString() | |||
=> $"0x{_handle.DangerousGetHandle():x16}"; | |||
@@ -25,5 +25,32 @@ namespace Tensorflow | |||
public IntPtr data; | |||
public ulong length; | |||
public IntPtr data_deallocator; | |||
public unsafe Span<T> AsSpan<T>() where T: unmanaged | |||
{ | |||
if(length > int.MaxValue) | |||
{ | |||
throw new ValueError($"The length {length} is too large to use in the span."); | |||
} | |||
return new Span<T>(data.ToPointer(), (int)length); | |||
} | |||
public unsafe byte[] ToByteArray() | |||
{ | |||
byte[] res = new byte[length]; | |||
if(length > int.MaxValue) | |||
{ | |||
byte* root = (byte*)data; | |||
for(ulong i = 0; i < length; i++) | |||
{ | |||
res[i] = *(root++); | |||
} | |||
} | |||
else | |||
{ | |||
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); | |||
} | |||
return res; | |||
} | |||
} | |||
} |
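AsSpan<T> gives zero-copy access to the native buffer (limited to lengths that fit in an int), while ToByteArray always copies and falls back to an element-wise copy for payloads larger than int.MaxValue. A hedged sketch of the intended use, assuming the buffer holds a serialized GraphDef:

// Sketch: copy the protobuf payload out of a TF_Buffer returned by the C API and parse it.
static GraphDef ParseGraphDef(TF_Buffer buf)
{
    byte[] payload = buf.ToByteArray();          // safe copy, still valid after the buffer is released
    return GraphDef.Parser.ParseFrom(payload);
    // Zero-copy alternative while the native buffer is alive (length must fit in an int):
    // Span<byte> view = buf.AsSpan<byte>();
}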
@@ -161,7 +161,7 @@ public static class CheckPointUtils | |||
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
{ | |||
return full_list.TakeWhile(x => | |||
return full_list.Where(x => | |||
{ | |||
var saveables = x.gather_saveables_for_checkpoint(); | |||
return saveables is not null && saveables.Count > 0; | |||
@@ -1,10 +1,12 @@ | |||
using System; | |||
using OneOf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using Tensorflow.Common.Extensions; | |||
using pbc = global::Google.Protobuf.Collections; | |||
namespace Tensorflow.Checkpoint | |||
@@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||
); | |||
public static class SaveUtil | |||
{ | |||
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
{ | |||
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
@@ -104,7 +106,10 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
var td = trackable_data[i]; | |||
Debug.Assert(td.node_id == i); | |||
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); | |||
TrackableObjectGraph.Types.TrackableObject trackable_object = new(); | |||
trackable_object.SlotVariables.AddRange(td.slot_variable_proto); | |||
trackable_object.Children.AddRange(td.children_proto); | |||
object_graph_proto.Nodes.Add(trackable_object); | |||
} | |||
return object_graph_proto; | |||
} | |||
@@ -117,16 +122,16 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="cache"></param> | |||
/// <param name="object_graph_proto"></param> | |||
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
foreach(var td in tensor_trackables) | |||
{ | |||
// TODO: deal with cache. | |||
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
Trackable trackable = null; | |||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
{ | |||
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
@@ -148,12 +153,12 @@ namespace Tensorflow.Checkpoint | |||
return serialized_tensors; | |||
} | |||
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
var trackable = trackable_data.object_to_save; | |||
// TODO: complete it. Note that `call_with_mapped_captures` is actually of a function (callable) type. | |||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||
if (call_with_mapped_captures) | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -163,8 +168,7 @@ namespace Tensorflow.Checkpoint | |||
ret_tensor_dict = trackable.serialize_to_tensors(); | |||
} | |||
// TODO: deal with the type `SaveSpec` (it is currently never of this type). | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
foreach(var pair in ret_tensor_dict) | |||
{ | |||
var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
@@ -173,10 +177,12 @@ namespace Tensorflow.Checkpoint | |||
tensor_dict[checkpoint_key] = maybe_tensor; | |||
if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||
foreach(var key in maybe_tensor.Keys) | |||
{ | |||
throw new NotImplementedException(); | |||
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||
{ | |||
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||
} | |||
} | |||
if(object_graph_proto is not null) | |||
@@ -200,7 +206,7 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="object_graph_proto"></param> | |||
/// <returns></returns> | |||
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, string> object_names = new(); | |||
@@ -8,6 +8,7 @@ using Tensorflow.Training; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using static Tensorflow.Binding; | |||
using Google.Protobuf; | |||
using OneOf; | |||
namespace Tensorflow.Checkpoint; | |||
@@ -114,14 +115,10 @@ public static class SaveUtilV1 | |||
{ | |||
var trackable = trackable_objects[i]; | |||
Debug.Assert(node_ids[trackable] == i); | |||
TrackableObjectGraph.Types.TrackableObject object_proto; | |||
var object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||
if (slot_variables.TryGetValue(trackable, out var slots)) | |||
{ | |||
object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); | |||
} | |||
else | |||
{ | |||
object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||
object_proto.SlotVariables.AddRange(slots); | |||
} | |||
object_graph_proto.Nodes.Add(object_proto); | |||
foreach (var child in graph_view.list_children(trackable)) | |||
@@ -184,13 +181,13 @@ public static class SaveUtilV1 | |||
// TODO: the TensorFlow Python implementation has extra handling for a callable `saveable_factory`. | |||
List<MySaveableObject> saveables = new(); | |||
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
if (maybe_saveable.TryPickT1(out var s, out var variable)) | |||
{ | |||
saveables.Add(s); | |||
} | |||
else | |||
{ | |||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); | |||
} | |||
foreach (var saveable in saveables) | |||
@@ -222,7 +219,7 @@ public static class SaveUtilV1 | |||
public record class CheckpointFactoryData | |||
( | |||
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||
string name, | |||
string checkpoint_key | |||
); |
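Most of this hunk is a mechanical swap of the hand-rolled Maybe<TA, TB> for the OneOf NuGet type: TryGet<T> becomes TryPickT0/TryPickT1, GetValue<T> becomes AsT0/AsT1, and the constructors become FromT0/FromT1 or the implicit conversions. A small self-contained illustration with placeholder types (the PR itself uses OneOf<BaseResourceVariable, MySaveableObject> and OneOf<Tensor, SaveSpec> in exactly the same way):

using System;
using OneOf;

OneOf<int, string> value = "hello";                 // implicit conversion, as with Maybe's implicit operators

if (value.TryPickT1(out var text, out var number))  // old: value.TryGet<string>(out var text)
    Console.WriteLine(text);                        // "hello"
else
    Console.WriteLine(number);                      // old: value.GetValue<int>()

var wrapped = OneOf<int, string>.FromT0(42);        // old: new Maybe<int, string>(42)
Console.WriteLine(wrapped.IsT0);                    // True
Console.WriteLine(wrapped.AsT0);                    // 42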
@@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
using Newtonsoft.Json; | |||
using Tensorflow.Training; | |||
using OneOf; | |||
namespace Tensorflow.Checkpoint; | |||
@@ -44,12 +45,12 @@ public class TrackableSaver | |||
_graph_view = graph_view; | |||
// TODO: cache when not executing eagerly. | |||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||
} | |||
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
{ | |||
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
@@ -70,9 +71,10 @@ public class TrackableSaver | |||
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
{ | |||
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||
serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||
} | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||
return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
} | |||
@@ -392,6 +394,7 @@ public class CheckpointRestoreCoordinator | |||
/// </summary> | |||
public List<Trackable> AllTrackables => _all_trackables; | |||
public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
// TODO(Rinne): change to weak ref. | |||
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
public int RestoreUid => _restore_uid; | |||
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
@@ -406,7 +409,7 @@ public class CheckpointRestoreCoordinator | |||
// skip the callback. | |||
} | |||
public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
{ | |||
List<Operation> restore_ops = new(); | |||
foreach(var position in positions) | |||
@@ -418,7 +421,7 @@ public class CheckpointRestoreCoordinator | |||
Dictionary<string, BaseResourceVariable> variable_dict = new(); | |||
foreach(var item in tensor_saveables) | |||
{ | |||
if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||
if(item.Value.TryPickT0(out var variable, out var _)) | |||
{ | |||
variable_dict[item.Key] = variable; | |||
} | |||
@@ -15,106 +15,14 @@ using Tensorflow.Graphs; | |||
using System.Xml.Linq; | |||
using System.Diagnostics; | |||
using RestoreFunc = System.Func<object, object>; | |||
using OneOf; | |||
namespace Tensorflow.Checkpoint | |||
{ | |||
public class Maybe<TA, TB> | |||
{ | |||
private TA? _valueA = default(TA); | |||
private TB? _valueB = default(TB); | |||
private Type _type; | |||
private bool _assignedTA; | |||
public Maybe(TA value) | |||
{ | |||
_valueA = value; | |||
_type= typeof(TA); | |||
_assignedTA = true; | |||
} | |||
public Maybe(TB value) | |||
{ | |||
_valueB = value; | |||
_type = typeof(TB); | |||
_assignedTA = false; | |||
} | |||
public Type DataType => _type; | |||
/// <summary> | |||
/// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||
/// It returns | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="res"></param> | |||
/// <returns></returns> | |||
public bool TryGet<T>(out T? res) | |||
{ | |||
if(_valueA is T && _valueB is not T) | |||
{ | |||
res = (T)(object)_valueA; | |||
return _assignedTA; | |||
} | |||
else if(_valueA is not T && _valueB is T) | |||
{ | |||
res = (T)(object)_valueB; | |||
return !_assignedTA; | |||
} | |||
res = default(T); | |||
return false; | |||
} | |||
public bool IsTypeOrDeriveFrom<T>() | |||
{ | |||
if (_valueA is T && _valueB is not T) | |||
{ | |||
return _assignedTA; | |||
} | |||
else if (_valueA is not T && _valueB is T) | |||
{ | |||
return !_assignedTA; | |||
} | |||
else if (_valueA is T && _valueB is T) | |||
{ | |||
return true; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
public T GetValue<T>() | |||
{ | |||
if (_valueA is T && _valueB is not T) | |||
{ | |||
return (T)(object)_valueA; | |||
} | |||
else if (_valueA is not T && _valueB is T) | |||
{ | |||
return (T)(object)_valueB; | |||
} | |||
else if (_valueA is T && _valueB is T) | |||
{ | |||
throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||
} | |||
else | |||
{ | |||
throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||
} | |||
} | |||
public static implicit operator Maybe<TA, TB>(TA a) | |||
{ | |||
return new Maybe<TA, TB>(a); | |||
} | |||
public static implicit operator Maybe<TA, TB>(TB b) | |||
{ | |||
return new Maybe<TA, TB>(b); | |||
} | |||
} | |||
internal class SingleDeviceSaver | |||
{ | |||
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||
private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict) | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict; | |||
} | |||
@@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
x => x.Key, x => x.Value.ToDictionary( | |||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value)) | |||
as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
} | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
x => x.Key, x => x.Value.ToDictionary( | |||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value)) | |||
as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
} | |||
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||
{ | |||
@@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
var slice_spec = slice.Key; | |||
var maybe_tensor = slice.Value; | |||
if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
{ | |||
var tensor_value = spec.tensor; | |||
if (tensor_value is not null) | |||
@@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint | |||
} | |||
else | |||
{ | |||
var tensor = maybe_tensor.GetValue<Tensor>(); | |||
tensor_names.Add(checkpoint_key); | |||
tensors.Add(tensor); | |||
slice_specs.Add(slice_spec); | |||
@@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint | |||
var slice_spec = slice.Key; | |||
var maybe_tensor = slice.Value; | |||
// TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
{ | |||
tensor_dtypes.Add(spec.dtype); | |||
slice_specs.Add(spec.slice_spec); | |||
@@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint | |||
} | |||
else | |||
{ | |||
var tensor = maybe_tensor.GetValue<Tensor>(); | |||
tensor_dtypes.Add(tensor.dtype); | |||
slice_specs.Add(slice_spec); | |||
tensor_names.Add(checkpoint_key); | |||
@@ -256,12 +162,12 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
/// <param name="registered_savers"></param> | |||
/// <param name="call_with_mapped_capture"></param> | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
{ | |||
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||
foreach(var pair in serialized_tensors) | |||
{ | |||
@@ -276,9 +182,9 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
restore_fn = new RestoreFunc(x => | |||
{ | |||
if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||
if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) | |||
{ | |||
return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||
return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>); | |||
} | |||
throw new TypeError($"Expected `IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>` as input, got {x.GetType()}."); | |||
}); | |||
@@ -287,16 +193,7 @@ namespace Tensorflow.Checkpoint | |||
foreach(var item in tensor_dict) | |||
{ | |||
var checkpoint_key = item.Key; | |||
IDictionary<string, Tensor> spec_to_tensor; | |||
if(item.Value.TryGet<Tensor>(out var t)) | |||
{ | |||
spec_to_tensor = new Dictionary<string, Tensor>(); | |||
spec_to_tensor[""] = t; | |||
} | |||
else | |||
{ | |||
spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||
} | |||
var spec_to_tensor = item.Value; | |||
foreach(var spec in spec_to_tensor) | |||
{ | |||
@@ -311,12 +208,20 @@ namespace Tensorflow.Checkpoint | |||
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
// skip part of the device name processing because of the lack of corresponding APIs. | |||
var host_device = tensor.Device; | |||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||
string host_device; | |||
if (tensor.IsT0) | |||
{ | |||
host_device = tensor.AsT0.Device; | |||
} | |||
else | |||
{ | |||
host_device = tensor.AsT1.device; | |||
} | |||
host_device = saveable_object_util.set_cpu0(host_device); | |||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
} | |||
internal_dict[checkpoint_key][slice_spec] = tensor; | |||
} | |||
@@ -412,7 +317,7 @@ namespace Tensorflow.Checkpoint | |||
IDictionary<string, Operation> restore_func() | |||
{ | |||
Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
Dictionary<string, Operation> restore_ops = new(); | |||
@@ -433,29 +338,29 @@ namespace Tensorflow.Checkpoint | |||
var slice_spec = item.Key; | |||
var tensor = item.Value; | |||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
Dictionary<string, Tensor> dict = new(); | |||
dict[slice_spec] = tensor; | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||
} | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||
} | |||
restore_fn_input_count[restore_fn]--; | |||
if (restore_fn_input_count[restore_fn] == 0) | |||
{ | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
foreach (var input in restore_fn_inputs[restore_fn]) | |||
{ | |||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
@@ -538,7 +443,7 @@ namespace Tensorflow.Checkpoint | |||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
{ | |||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
foreach (var saveable in saveables) | |||
{ | |||
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
@@ -1,7 +1,9 @@ | |||
using System; | |||
using OneOf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Security; | |||
using System.Text; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
@@ -49,7 +51,7 @@ public class CheckpointPosition | |||
{ | |||
_checkpoint.AllTrackables.Add(trackable); | |||
_checkpoint.MatchedProtoIds.Add(_proto_id); | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||
{ | |||
// skip the `logging.warning`. | |||
return false; | |||
@@ -61,13 +63,13 @@ public class CheckpointPosition | |||
} | |||
} | |||
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
{ | |||
// skip the registered_saver | |||
if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
new List<CheckpointPosition>(), null); | |||
} | |||
@@ -75,7 +77,7 @@ public class CheckpointPosition | |||
List<Operation> existing_restore_ops; | |||
List<CheckpointPosition> positions = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
{ | |||
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
@@ -109,8 +111,8 @@ public class CheckpointPosition | |||
/// Creates a saveable using the _serialize_to_tensor method. | |||
/// </summary> | |||
/// <param name="saveable_factories"></param> | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
{ | |||
string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
suffix = suffix ?? ""; | |||
@@ -124,23 +126,23 @@ public class CheckpointPosition | |||
var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
// skip the cache. | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
dict[saveable_name] = saveable; | |||
return (new List<Operation>(), dict); | |||
} | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
{ | |||
// TODO(Rinne): implement it. | |||
if(ObjectProto.Attributes is null) | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>()); | |||
} | |||
List<Operation> existing_restore_ops = new(); | |||
HashSet<string> created_compat_names = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
foreach (var serialized_tensor in ObjectProto.Attributes) | |||
{ | |||
Operation existing_op; | |||
@@ -172,12 +174,12 @@ public class CheckpointPosition | |||
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
continue; | |||
} | |||
named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; | |||
} | |||
return (existing_restore_ops, named_saveables); | |||
} | |||
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
{ | |||
var expected_factory_name = serialized_tensor.Name; | |||
@@ -221,7 +223,7 @@ public class CheckpointPosition | |||
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||
visit_queue.Enqueue((this, this.Trackable)); | |||
List<Operation> restore_ops = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
List<CheckpointPosition> positions = new(); | |||
CheckpointPosition current_position = null; | |||
@@ -306,7 +308,7 @@ public class CheckpointPosition | |||
} | |||
} | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
{ | |||
var trackable = this.Trackable; | |||
trackable._maybe_initialize_trackable(); | |||
@@ -318,7 +320,7 @@ public class CheckpointPosition | |||
} | |||
else | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
new List<CheckpointPosition>(), null); | |||
} | |||
} | |||
@@ -14,9 +14,11 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Common.Extensions; | |||
namespace Tensorflow.Contexts | |||
{ | |||
@@ -25,12 +27,93 @@ namespace Tensorflow.Contexts | |||
/// </summary> | |||
public sealed partial class Context | |||
{ | |||
public ConfigProto Config { get; set; } = new ConfigProto | |||
protected Device.PhysicalDevice[] _physical_devices; | |||
protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index; | |||
ConfigProto _config; | |||
public ConfigProto Config | |||
{ | |||
GpuOptions = new GPUOptions | |||
get | |||
{ | |||
_initialize_physical_devices(); | |||
var config = new ConfigProto(); | |||
if(_config is not null) | |||
{ | |||
config.MergeFrom(_config); | |||
} | |||
config.LogDevicePlacement = _log_device_placement; | |||
config.DeviceCount["CPU"] = 0; | |||
config.DeviceCount["GPU"] = 0; | |||
foreach(var dev in _physical_devices) | |||
{ | |||
if (config.DeviceCount.ContainsKey(dev.DeviceType)) | |||
{ | |||
config.DeviceCount[dev.DeviceType] += 1; | |||
} | |||
else | |||
{ | |||
config.DeviceCount[dev.DeviceType] = 1; | |||
} | |||
} | |||
var gpu_options = _compute_gpu_options(); | |||
config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); | |||
return config; | |||
} | |||
set | |||
{ | |||
_config = value; | |||
} | |||
} | |||
protected void _initialize_physical_devices(bool reinitialize = false) | |||
{ | |||
if(!reinitialize && _physical_devices is not null) | |||
{ | |||
return; | |||
} | |||
var devs = list_physical_devices(); | |||
_physical_devices = devs.Select(d => new Device.PhysicalDevice() | |||
{ | |||
DeviceName = d.DeviceName, | |||
DeviceType = d.DeviceType | |||
}).ToArray(); | |||
_physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i)) | |||
.ToDictionary(x => x.Key, x => x.Value); | |||
_import_config(); | |||
} | |||
protected void _import_config() | |||
{ | |||
if(_config is null) | |||
{ | |||
return; | |||
} | |||
if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) | |||
{ | |||
num_cpus = 1; | |||
} | |||
if(num_cpus != 1) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
}; | |||
var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); | |||
if(gpus.Count() == 0) | |||
{ | |||
return; | |||
} | |||
if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) | |||
{ | |||
gpu_count = 0; | |||
} | |||
// TODO(Rinne): implement it. | |||
} | |||
ConfigProto MergeConfig() | |||
{ | |||
@@ -111,6 +111,14 @@ namespace Tensorflow.Contexts | |||
return results.ToArray(); | |||
} | |||
public bool is_custom_device(string device_name) | |||
{ | |||
return false; | |||
// TODO(Rinne): TFE_IsCustomDevice was added to the C API after TF 2.11; enable this once it is bound. | |||
//ensure_initialized(); | |||
//return c_api.TFE_IsCustomDevice(_handle, device_name); | |||
} | |||
public EagerDeviceContext device(string name) | |||
{ | |||
return new EagerDeviceContext(this, name); | |||
@@ -37,7 +37,26 @@ namespace Tensorflow.Contexts | |||
public string ScopeName { get; set; } = ""; | |||
bool initialized = false; | |||
ContextSwitchStack context_switches; | |||
public FunctionCallOptions FunctionCallOptions { get; } | |||
protected FunctionCallOptions _function_call_options; | |||
public FunctionCallOptions FunctionCallOptions | |||
{ | |||
get | |||
{ | |||
if(_function_call_options is null) | |||
{ | |||
var config = Config; | |||
_function_call_options = new FunctionCallOptions() | |||
{ | |||
Config = config | |||
}; | |||
} | |||
return _function_call_options; | |||
} | |||
set | |||
{ | |||
_function_call_options = value; | |||
} | |||
} | |||
SafeContextHandle _handle; | |||
@@ -122,6 +141,11 @@ namespace Tensorflow.Contexts | |||
name : | |||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
public string anonymous_name() | |||
{ | |||
return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
} | |||
public void graph_mode(bool isFunc = false) | |||
=> context_switches.Push(false, isFunc); | |||
@@ -158,6 +182,37 @@ namespace Tensorflow.Contexts | |||
return has_graph_arg; | |||
} | |||
public bool has_function(string name) | |||
{ | |||
ensure_initialized(); | |||
return c_api.TFE_ContextHasFunction(_handle, name); | |||
} | |||
public void add_function(SafeFuncGraphHandle fn) | |||
{ | |||
ensure_initialized(); | |||
Status status = new(); | |||
c_api.TFE_ContextAddFunction(_handle, fn, status); | |||
status.Check(true); | |||
} | |||
public void remove_function(string name) | |||
{ | |||
ensure_initialized(); | |||
Status status = new(); | |||
c_api.TFE_ContextRemoveFunction(_handle, name, status); | |||
status.Check(true); | |||
} | |||
public void add_function_def(FunctionDef fdef) | |||
{ | |||
ensure_initialized(); | |||
var fdef_string = fdef.ToByteArray(); | |||
Status status = new Status(); | |||
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); | |||
status.Check(true); | |||
} | |||
public void restore_mode() | |||
{ | |||
context_switches.Pop(); | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Google.Protobuf; | |||
using Protobuf.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Contexts | |||
@@ -9,10 +10,11 @@ namespace Tensorflow.Contexts | |||
public class FunctionCallOptions | |||
{ | |||
public ConfigProto Config { get; set; } | |||
public string ExecutorType { get; set; } | |||
public string config_proto_serialized() | |||
public ByteString config_proto_serialized() | |||
{ | |||
return Config.ToByteString().ToStringUtf8(); | |||
return Config.ToByteString(); | |||
} | |||
} | |||
} |
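Returning the ByteString directly avoids round-tripping the serialized ConfigProto through a UTF-8 string, which can corrupt arbitrary binary bytes. A short sketch of how a caller converts it when raw bytes are needed:

var opts = new FunctionCallOptions { Config = new ConfigProto() };
ByteString serialized = opts.config_proto_serialized();  // binary-safe protobuf payload
byte[] raw = serialized.ToByteArray();                   // e.g. to feed a "config_proto" attribute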
@@ -12,18 +12,36 @@ namespace Tensorflow.Eager | |||
return HasGradientTape(); | |||
} | |||
private bool ShouldRecord(Tensor[] inputs) | |||
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) | |||
{ | |||
bool should_record = false; | |||
foreach (var tape in tf.GetTapeSet()) | |||
var tape_set = tf.GetTapeSet(); | |||
var input_ids = MakeTensorIDList(tensors); | |||
var input_dtypes = MakeTensorDtypeList(tensors); | |||
bool some_tape_watching = false; | |||
if (tape_set is not null && tape_set.Count > 0) | |||
{ | |||
if (tape.ShouldRecord(inputs)) | |||
foreach (var tape in tape_set) | |||
{ | |||
should_record = true; | |||
break; | |||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
{ | |||
if (tape.Persistent || some_tape_watching) | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
} | |||
some_tape_watching = true; | |||
} | |||
} | |||
} | |||
return should_record; | |||
// skip the forward_accumulators. | |||
if (some_tape_watching) | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
} | |||
else | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; | |||
} | |||
} | |||
} | |||
} |
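TFE_TapeSetPossibleGradientTypes reports whether no tape, exactly one non-persistent tape, or a persistent/nested tape set is watching the given tensors; eager execution uses this tri-state to pick a gradient strategy. A hedged sketch of a caller branching on the constants (the gradients_util names are taken from this diff; the caller itself is hypothetical and assumes the runner is reachable through tf.Runner as elsewhere in the binding):

using static Tensorflow.Binding;

static string DescribeGradientState(Tensor[] inputs)
{
    int grad_type = tf.Runner.TFE_TapeSetPossibleGradientTypes(inputs);
    if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE)
        return "no tape is watching: take the non-differentiable fast path";
    if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
        return "one non-persistent tape is watching: first-order gradients suffice";
    return "persistent or nested tapes are watching: higher-order gradients may be requested";
}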
@@ -13,7 +13,17 @@ namespace Tensorflow.Eager | |||
Tensor[] results, | |||
BackwardFunction backwardFunction = null) | |||
{ | |||
bool should_record = ShouldRecord(inputs); | |||
var input_ids = MakeTensorIDList(inputs); | |||
var input_dtypes = MakeTensorDtypeList(inputs); | |||
bool should_record = false; | |||
foreach (var tape in tf.GetTapeSet()) | |||
{ | |||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
{ | |||
should_record = true; | |||
break; | |||
} | |||
} | |||
if (!should_record) | |||
{ | |||
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager | |||
op_inputs = inputs;*/ | |||
backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | |||
TapeSetRecordOperation(op_name, inputs, results, backwardFunction); | |||
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); | |||
return true; | |||
} | |||
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager | |||
{ | |||
return HasGradientTape(); | |||
} | |||
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
{ | |||
return tensors.Select(x => x.dtype).ToArray(); | |||
} | |||
} | |||
} |
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Eager | |||
@@ -358,7 +358,7 @@ namespace Tensorflow.Eager | |||
break; | |||
case TF_AttrType.TF_ATTR_FUNC: | |||
if (value is ConcreteFunction func) | |||
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
else | |||
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
break; | |||
@@ -1,6 +1,8 @@ | |||
using System; | |||
using OneOf.Types; | |||
using System; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Eager | |||
{ | |||
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager | |||
/// </summary> | |||
public partial class EagerRunner | |||
{ | |||
/// <summary> | |||
/// Computes the gradients of `target` with respect to `sources`, using the operations recorded on `tape`. | |||
/// </summary> | |||
/// <param name="tape"></param> | |||
/// <param name="target"></param> | |||
/// <param name="sources"></param> | |||
/// <param name="output_gradients"></param> | |||
/// <param name="unconnected_gradients">Determines the value returned when the target and | |||
/// sources are unconnected. When 'none', the value returned is null, whereas when | |||
/// 'zero', a zero tensor with the same shape as the sources is returned.</param> | |||
/// <returns></returns> | |||
/// <exception cref="RuntimeError"></exception> | |||
public Tensor[] TFE_TapeGradient(ITape tape, | |||
Tensor[] target, | |||
Tensor[] sources, | |||
Tensor[] output_gradients) | |||
List<Tensor> output_gradients, | |||
Tensor[] sources_raw, | |||
string unconnected_gradients) | |||
{ | |||
var target_vec = target; | |||
var sources_vec = sources; | |||
var sources_set = sources_vec; | |||
if (!tape.Persistent) | |||
{ | |||
var tape_set = tf.GetTapeSet(); | |||
if (tape_set.Contains(tape)) | |||
{ | |||
throw new RuntimeError("gradient() cannot be invoked within the " + | |||
"GradientTape context (i.e., while operations are being " + | |||
"recorded). Either move the call to gradient() to be " + | |||
"outside the 'with tf.GradientTape' block, or " + | |||
"use a persistent tape: " + | |||
"'with tf.GradientTape(persistent=true)'"); | |||
} | |||
} | |||
var target_vec = MakeTensorIDList(target); | |||
var sources_vec = MakeTensorIDList(sources); | |||
HashSet<long> sources_set = new HashSet<long>(sources_vec); | |||
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); | |||
int len = target.Length; | |||
for(int i = 0; i < len; i++) | |||
{ | |||
var target_id = target_vec[i]; | |||
if (sources_set.Contains(target_id)) | |||
{ | |||
var tensor = target[i]; | |||
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); | |||
} | |||
} | |||
List<Tensor> outgrad_vec = new(); | |||
if(output_gradients is not null) | |||
{ | |||
outgrad_vec = output_gradients.ToList(); | |||
} | |||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
var seq_array = target; | |||
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>(); | |||
for (int i = 0; i < target.Length; ++i) | |||
bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
Tensor[] sources_obj = null; | |||
if (unconnected_gradients_zero) | |||
{ | |||
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i])); | |||
sources_obj = MakeTensorList(sources_raw); | |||
} | |||
if (output_gradients != null) | |||
if (result.Length > 0) | |||
{ | |||
throw new NotImplementedException(""); | |||
for(int i = 0; i < result.Length; i++) | |||
{ | |||
if (result[i] is null && unconnected_gradients_zero) | |||
{ | |||
var dtype = sources_obj[i].dtype; | |||
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); | |||
} | |||
} | |||
} | |||
else | |||
return result; | |||
} | |||
Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) | |||
{ | |||
return tensors.ToArray(); | |||
} | |||
long[] MakeTensorIDList(Tensor[] tensors) | |||
{ | |||
int len = tensors.Length; | |||
long[] ids = new long[len]; | |||
for(int i = 0; i < len; i++) | |||
{ | |||
var tensor = tensors[i]; | |||
ids[i] = tensor.Id; | |||
} | |||
return ids; | |||
} | |||
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
{ | |||
int len = tensors.Length; | |||
TF_DataType[] dtypes = new TF_DataType[len]; | |||
for (int i = 0; i < len; i++) | |||
{ | |||
output_gradients = new Tensor[0]; | |||
var tensor = tensors[i]; | |||
dtypes[i] = tensor.dtype; | |||
} | |||
return dtypes; | |||
} | |||
var outgrad_vec = MakeTensorList(output_gradients); | |||
TapeTensor TapeTensorFromTensor(Tensor tensor) | |||
{ | |||
long id = tensor.Id; | |||
var dtype = tensor.dtype; | |||
if (tensor is EagerTensor) | |||
{ | |||
var handle = tensor.EagerTensorHandle; | |||
if (DTypeNeedsHandleData(dtype)) | |||
{ | |||
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); | |||
} | |||
Status status = new(); | |||
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); | |||
long[] dims = new long[num_dims]; | |||
for(int i = 0; i < num_dims; i++) | |||
{ | |||
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
} | |||
Shape tensor_shape = new(dims); | |||
if(status.Code != TF_Code.TF_OK) | |||
{ | |||
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); | |||
} | |||
else | |||
{ | |||
return new TapeTensor(id, dtype, tensor_shape); | |||
} | |||
} | |||
var shape_tuple = tensor.shape.dims; | |||
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) | |||
{ | |||
return new TapeTensor(id, dtype, tensor); | |||
} | |||
long[] l = new long[shape_tuple.Length]; | |||
for(int i = 0; i < shape_tuple.Length; i++) | |||
{ | |||
if (shape_tuple[i] < 0) | |||
{ | |||
l[i] = 0; | |||
} | |||
else | |||
{ | |||
l[i] = shape_tuple[i]; | |||
} | |||
} | |||
return new TapeTensor(id, dtype, new Shape(l)); | |||
} | |||
return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec); | |||
bool DTypeNeedsHandleData(TF_DataType dtype) | |||
{ | |||
return dtype == dtypes.variant || dtype == dtypes.resource; | |||
} | |||
Tensor[] MakeTensorList(Tensor[] tensors) | |||
bool ListContainNone(long[] list) | |||
{ | |||
return tensors; | |||
int len = list.Length; | |||
if(len == 0) | |||
{ | |||
return true; | |||
} | |||
for(int i = 0; i < len; i++) | |||
{ | |||
if (list[i] == -1) | |||
{ | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
} | |||
} |
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager | |||
public partial class EagerRunner | |||
{ | |||
void TapeSetRecordBackprop(string op_type, | |||
Tensor[] input_tensors, | |||
TapeTensor[] output_tensors, | |||
TapeTensor[] output_info, | |||
long[] input_ids, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function) | |||
{ | |||
if (!CouldBackprop()) | |||
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||
foreach (var tape in tf.GetTapeSet()) | |||
{ | |||
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function); | |||
tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function); | |||
} | |||
} | |||
} | |||
@@ -10,18 +10,28 @@ namespace Tensorflow.Eager | |||
public bool TapeSetRecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
Tensor[] output_tensors, | |||
long[] input_ids, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function) | |||
{ | |||
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray(); | |||
var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); | |||
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | |||
backward_function)) | |||
return false; | |||
TapeSetRecordBackprop(op_type, input_tensors, output_info, | |||
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, | |||
backward_function); | |||
return true; | |||
} | |||
public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
Tensor[] input_tensors, BackwardFunction backward_function) | |||
{ | |||
var input_ids = MakeTensorIDList(input_tensors); | |||
var input_dtypes = MakeTensorDtypeList(input_tensors); | |||
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
backward_function); | |||
} | |||
} | |||
} |
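These overloads are the plumbing behind tf.GradientTape: TapeSetRecordOperation stores each executed op (with its input ids/dtypes and a backward function) on every active tape, and TFE_TapeGradient later walks those records to resolve the requested gradients. An end-to-end sketch from the user's point of view, assuming the usual GradientTape surface of this binding (tf.GradientTape, watch, gradient):

using static Tensorflow.Binding;

var x = tf.constant(3.0f);
using var tape = tf.GradientTape();
tape.watch(x);                      // constants are not watched automatically
var y = x * x;                      // recorded via TapeSetRecordOperation
var dy_dx = tape.gradient(y, x);    // resolved via TFE_TapeGradient; evaluates to 6
print(dy_dx);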
@@ -29,7 +29,14 @@ namespace Tensorflow.Eager | |||
Tensor[] TFE_TapeGradient(ITape tape, | |||
Tensor[] target, | |||
Tensor[] sources, | |||
Tensor[] output_gradients); | |||
List<Tensor> output_gradients, | |||
Tensor[] sources_raw, | |||
string unconnected_gradients); | |||
void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
Tensor[] input_tensors, BackwardFunction backward_function); | |||
int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); | |||
bool RecordGradient(string op_name, | |||
Tensor[] inputs, | |||
@@ -0,0 +1,53 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow.Eager | |||
{ | |||
internal static class backprop_util | |||
{ | |||
// TODO: add the quantized dtypes once they are supported. | |||
private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[] | |||
{ | |||
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, | |||
dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 | |||
}); | |||
public static bool IsTrainable(Tensor tensor) | |||
{ | |||
var dtype = _DTypeFromTensor(tensor); | |||
return _trainable_dtypes.Contains(dtype); | |||
} | |||
public static bool IsTrainable(TF_DataType dtype) | |||
{ | |||
return _trainable_dtypes.Contains(dtype); | |||
} | |||
private static TF_DataType _DTypeFromTensor(Tensor tensor) | |||
{ | |||
var dtype = tensor.dtype; | |||
if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) | |||
{ | |||
CppShapeInferenceResult.Types.HandleData handle_data; | |||
if (tensor is EagerTensor) | |||
{ | |||
handle_data = tensor.HandleData; | |||
} | |||
else | |||
{ | |||
handle_data = handle_data_util.get_resource_handle_data(tensor); | |||
} | |||
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && | |||
handle_data.ShapeAndType.Count > 0) | |||
{ | |||
var first_type = handle_data.ShapeAndType[0].Dtype; | |||
if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) | |||
{ | |||
return first_type.as_tf_dtype(); | |||
} | |||
} | |||
} | |||
return dtype; | |||
} | |||
} | |||
} |
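backprop_util.IsTrainable mirrors its TensorFlow Python counterpart: only floating-point, complex, bfloat16, resource and variant dtypes can carry gradients, and for variant tensors the element dtype recorded in the handle data is consulted. A quick sketch (the class is internal, so this is callable only from inside the binding):

// True: float tensors are trainable.
bool a = backprop_util.IsTrainable(TF_DataType.TF_FLOAT);
// False: integer tensors never carry gradients.
bool b = backprop_util.IsTrainable(TF_DataType.TF_INT32);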
@@ -30,6 +30,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | |||
@@ -277,7 +280,7 @@ namespace Tensorflow | |||
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -480,5 +483,8 @@ namespace Tensorflow | |||
IntPtr[] target, int target_size, | |||
IntPtr[] sources, int source_size, | |||
IntPtr[] outputs, int output_size); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); | |||
} | |||
} |
@@ -18,6 +18,10 @@ namespace Tensorflow.Eager | |||
var types = v.Select(t => t.dtype.as_datatype_enum()); | |||
return (types.ToArray(), v.ToArray()); | |||
} | |||
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
{ | |||
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); | |||
} | |||
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
{ | |||
string device_name = ctx.DeviceName; | |||
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Eager | |||
{ | |||
public class TangentInfo | |||
{ | |||
// TODO(Rinne): implement it. | |||
public object Indices { get; set; } | |||
public object Tangents { get; set; } | |||
} | |||
} |
@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class DictionaryExtension | |||
{ | |||
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second) | |||
{ | |||
first = pair.Key; | |||
second = pair.Value; | |||
} | |||
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other) | |||
{ | |||
foreach(var (key, value) in other) | |||
{ | |||
dic[key] = value; | |||
} | |||
} | |||
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue) | |||
{ | |||
if (dic.ContainsKey(key)) | |||
{ | |||
return dic[key]; | |||
} | |||
return defaultValue; | |||
} | |||
} | |||
} |
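These three extension methods mirror Python's dict deconstruction, dict.update and dict.get, which the graph/function code below relies on. A self-contained sketch of how they behave; the dictionary contents are illustrative only:

using System;
using System.Collections.Generic;
using Tensorflow.Common.Extensions;

static class DictionaryExtensionDemo
{
    static void Main()
    {
        var attrs = new Dictionary<string, int> { ["T"] = 1 };

        // Update overwrites existing keys and adds missing ones, like Python's dict.update.
        attrs.Update(new Dictionary<string, int> { ["T"] = 3, ["N"] = 2 });

        // GetOrDefault avoids a KeyNotFoundException for absent keys.
        var rank = attrs.GetOrDefault("rank", -1);   // -1

        // Deconstruct lets foreach bind key and value directly.
        foreach (var (name, value) in attrs)
            Console.WriteLine($"{name} = {value}, default rank = {rank}");
    }
}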
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class NamedTuple | |||
{ | |||
public string Name { get; set; } | |||
public Dictionary<string, object> ValueDict { get; set; } | |||
} | |||
} |
@@ -0,0 +1,13 @@ | |||
using OneOf; | |||
using System; | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class OneofExtension | |||
{ | |||
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src) | |||
{ | |||
return src.Value is T; | |||
} | |||
} | |||
} |
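A small sketch of IsTypeOrDeriveFrom, assuming the OneOf package version referenced by this PR has its OneOf<> structs implement IOneOf (which is what makes the extension callable on a boxed value):

using System;
using OneOf;
using Tensorflow.Common.Extensions;

static class OneofExtensionDemo
{
    static void Main()
    {
        OneOf<int, string> padding = "same";   // implicit conversion from string
        IOneOf boxed = padding;

        Console.WriteLine(boxed.IsTypeOrDeriveFrom<string>()); // True: stored value is a string
        Console.WriteLine(boxed.IsTypeOrDeriveFrom<object>()); // True: the check is the C# `is` operator
        Console.WriteLine(boxed.IsTypeOrDeriveFrom<int>());    // False
    }
}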
@@ -6,8 +6,11 @@ | |||
public class DenseSpec : TypeSpec | |||
{ | |||
protected Shape _shape; | |||
public Shape shape => _shape; | |||
public Shape shape | |||
{ | |||
get { return _shape; } | |||
set { _shape = value; } | |||
} | |||
protected TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
@@ -1,6 +0,0 @@ | |||
namespace Tensorflow.Framework.Models | |||
{ | |||
class ScopedTFFunction | |||
{ | |||
} | |||
} |
@@ -0,0 +1,22 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Framework | |||
{ | |||
internal class ScopedTFFunction | |||
{ | |||
SafeFuncGraphHandle _handle; | |||
string _name; | |||
public ScopedTFFunction(SafeFuncGraphHandle func, string name) | |||
{ | |||
_handle = func; | |||
_name = name; | |||
} | |||
public SafeFuncGraphHandle Get() | |||
{ | |||
return _handle; | |||
} | |||
} | |||
} |
@@ -111,7 +111,17 @@ namespace Tensorflow | |||
public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | |||
public static Buffer tf_buffer(byte[] data) => new Buffer(data); | |||
public static Buffer tf_buffer(byte[] data = null) | |||
{ | |||
if(data is not null) | |||
{ | |||
return new Buffer(data); | |||
} | |||
else | |||
{ | |||
return new Buffer(); | |||
} | |||
} | |||
public static IEnumerable<Operation> new_tf_operations(Graph graph) | |||
{ | |||
@@ -0,0 +1,297 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Common.Extensions; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
namespace Tensorflow.Framework | |||
{ | |||
public class function_def_lib | |||
{ | |||
// TODO(Rinne): process signatures and structured outputs. | |||
public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||
object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||
{ | |||
var func_graph = new FuncGraph(fdef.Signature.Name); | |||
if(input_shapes is null) | |||
{ | |||
if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||
{ | |||
var raw_input_shapes = input_shapes_attr.List.Shape; | |||
input_shapes = new List<TensorShapeProto>(); | |||
foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||
{ | |||
if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
{ | |||
input_shapes.Add(null); | |||
} | |||
else | |||
{ | |||
input_shapes.Add(input_shape); | |||
} | |||
} | |||
} | |||
} | |||
var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||
func_graph.as_default(); | |||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
// TODO(Rinne): func_graph.ControlOutputs | |||
_set_handle_data(func_graph, fdef); | |||
foreach(var node in graph_def.Node) | |||
{ | |||
if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||
{ | |||
var op = func_graph.get_operation_by_name(node.Name); | |||
foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||
{ | |||
op.outputs[output_index].shape = new Shape(shape); | |||
} | |||
} | |||
} | |||
Dictionary<long, string> output_names = new(); | |||
foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||
{ | |||
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||
} | |||
func_graph._output_names = output_names; | |||
func_graph.Exit(); | |||
return func_graph; | |||
} | |||
public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||
{ | |||
var graph_def = new GraphDef() | |||
{ | |||
Versions = new VersionDef() | |||
{ | |||
Producer = versions.GRAPH_DEF_VERSION, | |||
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
} | |||
}; | |||
var default_graph = ops.get_default_graph(); | |||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||
{ | |||
throw new ValueError($"Length of `input_shapes` must match the number " + | |||
$"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||
$"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||
} | |||
foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||
{ | |||
NodeDef node_def = new(); | |||
node_def.Name = arg_def.Name; | |||
node_def.Op = "Placeholder"; | |||
node_def.Attr["dtype"] = new AttrValue() | |||
{ | |||
Type = arg_def.Type | |||
}; | |||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||
{ | |||
var input_shape = input_shapes[i]; | |||
// Skip handling the case where input_shape is not a TensorShapeProto. | |||
AttrValue shape = new AttrValue() | |||
{ | |||
Shape = new TensorShapeProto() | |||
}; | |||
shape.Shape = new TensorShapeProto(input_shape); | |||
node_def.Attr["shape"] = shape; | |||
} | |||
if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||
{ | |||
fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||
} | |||
var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||
foreach(var k in arg_attrs.Keys) | |||
{ | |||
if(k == "_output_shapes") | |||
{ | |||
if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||
} | |||
else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||
{ | |||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||
} | |||
} | |||
else if (k.StartsWith("_")) | |||
{ | |||
if (!node_def.Attr.ContainsKey(k)) | |||
{ | |||
node_def.Attr[k] = new AttrValue(); | |||
} | |||
node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||
} | |||
} | |||
graph_def.Node.Add(node_def); | |||
} | |||
graph_def.Node.AddRange(fdef.NodeDef); | |||
Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||
foreach(var arg_def in fdef.Signature.InputArg) | |||
{ | |||
nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||
string control_name = "^" + arg_def.Name; | |||
nested_to_flat_tensor_name[control_name] = control_name; | |||
} | |||
foreach(var node_def in fdef.NodeDef) | |||
{ | |||
var graph = default_graph; | |||
while (true) | |||
{ | |||
if(graph is null) | |||
{ | |||
break; | |||
} | |||
var f = graph.Functions.GetOrDefault(node_def.Op, null); | |||
if(f is not null && graph.OuterGraph is null) | |||
{ | |||
break; | |||
} | |||
graph = graph.OuterGraph; | |||
} | |||
var op_def = default_graph.GetOpDef(node_def.Op); | |||
foreach(var attr in op_def.Attr) | |||
{ | |||
if(attr.Type == "func") | |||
{ | |||
var fname = node_def.Attr[attr.Name].Func.Name; | |||
if (!is_function(fname)) | |||
{ | |||
throw new ValueError($"Function {fname} was not found. Please make sure " + | |||
$"the FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
else if(attr.Type == "list(func)") | |||
{ | |||
foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||
{ | |||
var fname = fn.Name; | |||
if (!is_function(fname)) | |||
{ | |||
throw new ValueError($"Function {fname} was not found. Please make " + | |||
$"sure the FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
} | |||
} | |||
int flattened_index = 0; | |||
foreach(var arg_def in op_def.OutputArg) | |||
{ | |||
var num_args = _get_num_args(arg_def, node_def); | |||
for(int i = 0; i < num_args; i++) | |||
{ | |||
var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||
var flat_name = $"{node_def.Name}:{flattened_index}"; | |||
nested_to_flat_tensor_name[nested_name] = flat_name; | |||
flattened_index++; | |||
} | |||
} | |||
string control_name = "^" + node_def.Name; | |||
nested_to_flat_tensor_name[control_name] = control_name; | |||
} | |||
foreach(var node_def in graph_def.Node) | |||
{ | |||
for(int i = 0; i < node_def.Input.Count; i++) | |||
{ | |||
node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||
} | |||
} | |||
return (graph_def, nested_to_flat_tensor_name); | |||
} | |||
private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||
{ | |||
foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||
{ | |||
if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
{ | |||
tensor.shape = Shape.Scalar; | |||
var shape_and_type = arg_def.HandleData[0]; | |||
var handle_data = new HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
{ | |||
Shape = shape_and_type.Shape, | |||
Dtype = shape_and_type.Dtype | |||
}); | |||
resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||
} | |||
} | |||
} | |||
private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||
{ | |||
if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||
{ | |||
return node_def.Attr[arg_def.NumberAttr].I; | |||
} | |||
else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||
{ | |||
return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||
} | |||
else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||
{ | |||
return 1; | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||
$"FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
public static bool is_function(string fname) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return tf.Context.has_function(fname); | |||
} | |||
else | |||
{ | |||
var graph = ops.get_default_graph(); | |||
while(graph is not null) | |||
{ | |||
if (graph.IsFunction(fname)) | |||
{ | |||
return true; | |||
} | |||
if(graph.OuterGraph is not null) | |||
{ | |||
graph = graph.OuterGraph; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
} | |||
throw new ValueError("Unexpected runtime behavior, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
} |
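function_def_to_graph is the inverse of serializing a function: it replays a FunctionDef into a fresh FuncGraph, resolving inputs and outputs through the flattened tensor names recorded in the signature. A usage sketch, assuming a ConcreteFunction built elsewhere in this PR's call path; structured signatures are not processed yet (see the TODO above), so nulls are passed:

using Tensorflow;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Graphs;

static class FunctionDefRoundTripSketch
{
    public static FuncGraph Rebuild(ConcreteFunction fn)
    {
        // FunctionDef is produced lazily from the forward EagerDefinedFunction.
        FunctionDef fdef = fn.FunctionDef;

        // Re-import the definition; Inputs/Outputs are looked up through the
        // nested-to-flat tensor name mapping built in function_def_to_graph_def.
        return function_def_lib.function_def_to_graph(fdef, null, null);
    }
}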
@@ -17,6 +17,7 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.OpDef.Types; | |||
@@ -25,9 +26,14 @@ namespace Tensorflow | |||
{ | |||
public class importer | |||
{ | |||
public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||
{ | |||
return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||
} | |||
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
Dictionary<string, Tensor> input_map = null, | |||
string[] return_elements = null, | |||
bool validate_colocation_constraints = true, | |||
string name = null, | |||
OpList producer_op_list = null) | |||
{ | |||
@@ -60,7 +66,7 @@ namespace Tensorflow | |||
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
var status = new Status(); | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||
// need to create a class ImportGraphDefWithResults that implements IDisposable | |||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
status.Check(true); | |||
@@ -107,21 +113,36 @@ namespace Tensorflow | |||
foreach (var new_op in graph._add_new_tf_operations()) | |||
{ | |||
var original_device = new_op.Device; | |||
new_op._set_device(original_device); | |||
} | |||
} | |||
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | |||
string prefix, | |||
Dictionary<string, Tensor> input_map, | |||
string[] return_elements) | |||
string[] return_elements, | |||
bool validate_colocation_constraints) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||
foreach (var input in input_map) | |||
{ | |||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
var input_src = tf.compat.as_str(input.Key); | |||
var input_dst = input.Value; | |||
if (input_src.StartsWith("^")) | |||
{ | |||
var src_name = tf.compat.as_str(input_src.Substring(1)); | |||
var dst_op = input_dst._as_tf_output().oper; | |||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||
} | |||
else | |||
{ | |||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||
src_name = tf.compat.as_str(src_name); | |||
var dst_output = input_dst._as_tf_output(); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||
} | |||
} | |||
if (return_elements == null) | |||
@@ -132,15 +153,16 @@ namespace Tensorflow | |||
if (name.Contains(":")) | |||
{ | |||
var (op_name, index) = _ParseTensorName(name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
op_name = tf.compat.as_str(op_name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||
} | |||
else | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||
} | |||
} | |||
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||
} | |||
private static (string, int) _ParseTensorName(string tensor_name) | |||
@@ -173,6 +195,14 @@ namespace Tensorflow | |||
return graph_def; | |||
} | |||
private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||
{ | |||
var old_graph_def = graph_def; | |||
graph_def = new GraphDef(old_graph_def); | |||
return graph_def; | |||
} | |||
private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | |||
{ | |||
foreach (var attr_def in op_def.Attr) | |||
@@ -240,6 +270,35 @@ namespace Tensorflow | |||
} | |||
} | |||
private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||
{ | |||
var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||
foreach (var node in graph_def.Node) | |||
{ | |||
// Remove any default attr values that aren't in op_def. | |||
if (producer_op_dict.ContainsKey(node.Op)) | |||
{ | |||
var op_def = op_def_registry.GetOpDef(node.Op); | |||
if(op_def is null) | |||
{ | |||
continue; | |||
} | |||
var producer_op_def = producer_op_dict[node.Op]; | |||
foreach (var key in node.Attr.Keys) | |||
{ | |||
if (_FindAttrInOpDef(key, op_def) is null) | |||
{ | |||
var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||
if (attr_def != null && attr_def.DefaultValue != null && | |||
node.Attr[key] == attr_def.DefaultValue) | |||
node.Attr[key].ClearValue(); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | |||
{ | |||
return op_def.Attr.FirstOrDefault(x => x.Name == name); | |||
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Framework | |||
{ | |||
public class versions | |||
{ | |||
public static int GRAPH_DEF_VERSION = 1286; | |||
public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||
} | |||
} |
@@ -1,9 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Train; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
@@ -13,29 +17,46 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
public class ConcreteFunction: Trackable | |||
{ | |||
protected IEnumerable<Tensor> _captured_inputs; | |||
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
protected Dictionary<string, AttrValue> _attrs; | |||
protected FunctionSpec _function_spec; | |||
protected FunctionSpec _pre_initialized_function_spec = null; | |||
protected EagerDefinedFunction _inference_function; | |||
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); | |||
internal FuncGraph func_graph; | |||
internal ForwardBackwardCall forward_backward; | |||
public Tensor[] Inputs => func_graph.Inputs; | |||
public Tensor[] CapturedInputs => func_graph.external_captures; | |||
public string Name => func_graph?.FuncName; | |||
public string Name => _delayed_rewrite_functions.Forward().Name; | |||
public Tensor[] Outputs; | |||
public Tensor[] Outputs => func_graph.Outputs; | |||
public Type ReturnType; | |||
public TensorSpec[] OutputStructure; | |||
public IEnumerable<string> ArgKeywords { get; set; } | |||
public long NumPositionArgs { get; set; } | |||
public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; | |||
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
public ConcreteFunction(string name) | |||
{ | |||
func_graph = new FuncGraph(name); | |||
_captured_inputs = func_graph.external_captures; | |||
_attrs = new Dictionary<string, AttrValue>(); | |||
_set_infer_function(); | |||
} | |||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null) | |||
{ | |||
func_graph = graph; | |||
_captured_inputs = func_graph.external_captures; | |||
ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
_attrs = attrs; | |||
_set_infer_function(); | |||
} | |||
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
@@ -53,6 +74,9 @@ namespace Tensorflow.Functions | |||
new[] { output }, | |||
null); | |||
func_graph.Exit(); | |||
_captured_inputs = func_graph.external_captures; | |||
_attrs = new Dictionary<string, AttrValue>(); | |||
_set_infer_function(); | |||
} | |||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
@@ -73,6 +97,9 @@ namespace Tensorflow.Functions | |||
new[] { output.variant_tensor }, | |||
null); | |||
func_graph.Exit(); | |||
_captured_inputs = func_graph.external_captures; | |||
_attrs = new Dictionary<string, AttrValue>(); | |||
_set_infer_function(); | |||
} | |||
/*public ConcreteFunction(Func<Tensors, Tensors> func, | |||
@@ -130,39 +157,56 @@ namespace Tensorflow.Functions | |||
{ | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
var default_graph = ops.get_default_graph(); | |||
// TODO(Rinne): deal with `default_graph.building_function` | |||
var tempvv = func_graph.Variables; | |||
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) | |||
{ | |||
foreach(var v in this.func_graph.Variables) | |||
{ | |||
resource_variable_ops.variable_accessed(v); | |||
} | |||
} | |||
var tensor_inputs = new Tensors(); | |||
foreach (var (i, arg) in enumerate(args)) | |||
{ | |||
tensor_inputs.Add(arg); | |||
// If we're graph building, shape inference is on. | |||
if (!executing_eagerly) | |||
{ | |||
} | |||
} | |||
tensor_inputs.AddRange(captured_inputs); | |||
if (!executing_eagerly) | |||
{ | |||
// TODO(Rinne): add the check | |||
} | |||
tensor_inputs.AddRange(captured_inputs); | |||
args = tensor_inputs.ToArray(); | |||
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||
var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); | |||
// No tape is watching; skip to running the function. | |||
if (possible_gradient_type == 0 && executing_eagerly) | |||
if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"executor_type", "", | |||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
}; | |||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
return _build_call_outputs(_inference_function.Call(args)); | |||
} | |||
if (forward_backward == null) | |||
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
Tensors flat_outputs = null; | |||
if (executing_eagerly) | |||
{ | |||
flat_outputs = forward_function.Call(args_with_tangents); | |||
} | |||
else | |||
{ | |||
tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){ | |||
{ "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } | |||
}), _ => | |||
{ | |||
flat_outputs = forward_function.Call(args_with_tangents); | |||
}); | |||
} | |||
forward_backward.Record(flat_outputs); | |||
return flat_outputs; | |||
return _build_call_outputs(flat_outputs); | |||
} | |||
public void AddTograph(Graph? g = null) | |||
@@ -171,13 +215,99 @@ namespace Tensorflow.Functions | |||
{ | |||
g = ops.get_default_graph(); | |||
} | |||
// TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||
_delayed_rewrite_functions.Forward().AddToGraph(g); | |||
} | |||
public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
{ | |||
_captured_inputs = captures; | |||
} | |||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
{ | |||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
TangentInfo input_tangents; | |||
if (executing_eagerly) | |||
{ | |||
// TODO(Rinne): check if it needs to be implemented. | |||
input_tangents = new TangentInfo(); | |||
} | |||
else | |||
{ | |||
input_tangents = new TangentInfo(); | |||
} | |||
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||
{ | |||
if(input_tangents.Indices is not null || executing_eagerly) | |||
{ | |||
string cache_key = "first_order"; | |||
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) | |||
{ | |||
functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
_tape_functions_cache[cache_key] = functions; | |||
} | |||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
} | |||
else | |||
{ | |||
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); | |||
} | |||
} | |||
else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
} | |||
internal void set_variables(IEnumerable<IVariableV1> variables) | |||
{ | |||
func_graph.Variables = variables; | |||
} | |||
internal void _set_infer_function() | |||
{ | |||
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
_inference_function = _delayed_rewrite_functions.Forward(); | |||
} | |||
internal void _set_function_spec(FunctionSpec spec) | |||
{ | |||
_function_spec = null; | |||
_pre_initialized_function_spec = spec; | |||
_initialize_function_spec(); | |||
} | |||
internal void _initialize_function_spec() | |||
{ | |||
if(_pre_initialized_function_spec is null) | |||
{ | |||
return; | |||
} | |||
Debug.Assert(_function_spec is null, "already initialized"); | |||
var spec = _pre_initialized_function_spec; | |||
//var args = spec.Fullargspec.DictValue.Fields["args"]; | |||
// TODO(Rinne): self.structured_input_signature | |||
_function_spec = new FunctionSpec() | |||
{ | |||
Fullargspec = spec.Fullargspec, | |||
IsMethod = spec.IsMethod, | |||
InputSignature = spec.InputSignature | |||
}; | |||
} | |||
internal Func<Operation, object[], Tensor[]> _get_gradient_function() | |||
{ | |||
return _delayed_rewrite_functions._rewrite_forward_and_call_backward; | |||
} | |||
private Tensors _build_call_outputs(Tensors result) | |||
{ | |||
// TODO(Rinne): deal with `func_graph.structured_outputs` | |||
return result; | |||
} | |||
public override string ToString() | |||
@@ -1,50 +1,232 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
using Tensorflow.Common.Extensions; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Framework; | |||
using System.Buffers; | |||
using Tensorflow.Gradients; | |||
namespace Tensorflow.Functions | |||
{ | |||
public class EagerDefinedFunction | |||
public class EagerDefinedFunction: IDisposable | |||
{ | |||
public int _num_outputs; | |||
public string Name => _func_graph.FuncName; | |||
FuncGraph _graph; | |||
FunctionDef _definition; | |||
OpDef _signature; | |||
string _name; | |||
internal ScopedTFFunction _c_func; | |||
internal Tensor[] _func_graph_outputs; | |||
internal string _grad_func_name; | |||
internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func; | |||
internal EagerDefinedFunction _grad_func; | |||
internal bool _registered_on_context = false; | |||
public string Name => _name; | |||
public DataType[] OutputTypes { get; protected set; } | |||
public Shape[] OutputShapes { get; protected set; } | |||
public FunctionDef Definition | |||
{ | |||
get | |||
{ | |||
if(_definition is null) | |||
{ | |||
_definition = _get_definition(); | |||
} | |||
return _definition; | |||
} | |||
} | |||
FuncGraph _func_graph; | |||
public EagerDefinedFunction(string name, FuncGraph graph, | |||
public OpDef Signature | |||
{ | |||
get | |||
{ | |||
if (_signature is null) | |||
{ | |||
_signature = Definition.Signature; | |||
} | |||
return _signature; | |||
} | |||
} | |||
public unsafe EagerDefinedFunction(string name, FuncGraph graph, | |||
Tensors inputs, Tensors outputs, | |||
Dictionary<string, string> attrs) | |||
Dictionary<string, AttrValue> attrs) | |||
{ | |||
_num_outputs = outputs.Length; | |||
var input_ops = inputs.Select(x => x.op).ToArray(); | |||
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||
.Select(x => x as Operation).ToArray(); | |||
var output_names = new string[0]; | |||
var graph_output_names = graph._output_names; | |||
string[] output_names; | |||
if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) | |||
{ | |||
output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); | |||
if(output_names.Distinct().Count() != output_names.Length) | |||
{ | |||
output_names = new string[0]; | |||
} | |||
} | |||
else | |||
{ | |||
output_names = new string[0]; | |||
} | |||
_func_graph = new FuncGraph(graph, name, attrs); | |||
_func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
Status status = new Status(); | |||
var fn = c_api.TF_GraphToFunction(graph.c_graph, | |||
name, | |||
false, | |||
operations.Length, | |||
operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), | |||
inputs.Length, | |||
inputs.Select(t => t._as_tf_output()).ToArray(), | |||
outputs.Length, | |||
outputs.Select(t => t._as_tf_output()).ToArray(), | |||
output_names.Length != outputs.Length ? null : output_names, | |||
IntPtr.Zero, // warning: the control outputs are currently ignored entirely. | |||
null, | |||
status); | |||
status.Check(true); | |||
_c_func = new ScopedTFFunction(fn, name); | |||
foreach(var (attr_name, attr_value) in attrs) | |||
{ | |||
var serialized = attr_value.ToByteArray(); | |||
c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); | |||
status.Check(true); | |||
} | |||
var signature = _get_definition().Signature; | |||
_name = signature.Name; | |||
tf_with(ops.init_scope(), s => | |||
{ | |||
tf.Context.add_function(fn); | |||
_registered_on_context = true; | |||
}); | |||
_num_outputs = signature.OutputArg.Count; | |||
OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); | |||
OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||
_func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||
csharp_grad_func = null; | |||
_graph = graph; | |||
} | |||
public Tensors Call(Tensors args) | |||
public unsafe Tensors Call(Tensors args) | |||
{ | |||
// TODO(Rinne): Add arg `CancellationManager`. | |||
// TODO(Rinne): Check the arg length. | |||
var function_call_options = tf.Context.FunctionCallOptions; | |||
string config = ""; // TODO(Rinne): revise it. The commented code below should work but does not, for unclear reasons. | |||
//if (function_call_options.config_proto_serialized().Length == 0) | |||
//{ | |||
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
//} | |||
//else | |||
//{ | |||
// config = function_call_options.config_proto_serialized().ToStringUtf8(); | |||
//} | |||
string executor_type = function_call_options.ExecutorType ?? ""; | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
var attrs = new object[] | |||
{ | |||
"executor_type", "", | |||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
"executor_type", executor_type, | |||
"config_proto", config | |||
}; | |||
var results = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
_func_graph.FuncName, | |||
args, | |||
attrs, | |||
_num_outputs); | |||
Tensor[] outputs; | |||
if (executing_eagerly) | |||
{ | |||
outputs = execute.executes( | |||
Signature.Name, | |||
_num_outputs, | |||
args, | |||
attrs, | |||
tf.Context); | |||
} | |||
else | |||
{ | |||
if(tf.GetTapeSet().Count == 0) | |||
{ | |||
outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
executing_eagerly, config, ""); | |||
} | |||
else | |||
{ | |||
var tape = tf.GetTapeSet().Peek(); | |||
tape.StopRecord(); | |||
outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
executing_eagerly, config, ""); | |||
tape.StartRecord(); | |||
} | |||
} | |||
foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||
{ | |||
handle_data_util.copy_handle_data(func_graph_output, outputs[i]); | |||
} | |||
if (executing_eagerly) | |||
{ | |||
return outputs; | |||
} | |||
else | |||
{ | |||
foreach(var (i, shape) in enumerate(OutputShapes)) | |||
{ | |||
outputs[i].shape = shape; | |||
} | |||
return outputs; | |||
} | |||
} | |||
public void AddToGraph(Graph g = null) | |||
{ | |||
if(g is null && tf.Context.executing_eagerly()) | |||
{ | |||
var ctx = tf.Context; | |||
if (!ctx.has_function(this.Name)) | |||
{ | |||
ctx.add_function_def(Definition); | |||
} | |||
} | |||
else | |||
{ | |||
if (!g.IsFunction(Name)) | |||
{ | |||
g.AddFunction(this); | |||
} | |||
foreach(var f in _graph.Functions.Values) | |||
{ | |||
if (!g.IsFunction(f.Name)) | |||
{ | |||
g.AddFunction(f); | |||
} | |||
} | |||
} | |||
} | |||
return results; | |||
private FunctionDef _get_definition() | |||
{ | |||
var buffer = c_api_util.tf_buffer(); | |||
Status status = new(); | |||
c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); | |||
status.Check(true); | |||
var proto_data = c_api.TF_GetBuffer(buffer); | |||
return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>()); | |||
} | |||
public void Dispose() | |||
{ | |||
tf.Context.remove_function(Name); | |||
} | |||
} | |||
} |
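The class now registers itself with the eager context at construction time and exposes its FunctionDef lazily. A short lifetime sketch, assuming an EagerDefinedFunction and a target Graph produced elsewhere; the helper name is illustrative:

using Tensorflow;
using Tensorflow.Functions;

static class EagerFunctionLifetimeSketch
{
    public static FunctionDef CopyInto(EagerDefinedFunction fn, Graph graph)
    {
        // Copies the underlying TF_Function (and the functions of its source graph)
        // into `graph` unless it is already registered there.
        fn.AddToGraph(graph);

        // Deserialized on first access via TF_FunctionToFunctionDef, then cached.
        return fn.Definition;
    }
}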
@@ -14,12 +14,11 @@ namespace Tensorflow.Functions | |||
} | |||
public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
ForwardAndBackwardFunctions(Tensors inference_args) | |||
{ | |||
var outputs = _func_graph.Outputs; | |||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
= BuildFunctionsForOutputs(outputs, inference_args); | |||
return _forward; | |||
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); | |||
return BuildFunctionsForOutputs(outputs, inference_args); | |||
} | |||
} | |||
} |
@@ -1,23 +1,84 @@ | |||
using System; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Train; | |||
namespace Tensorflow | |||
{ | |||
public class Function: Trackable | |||
public class Function: Trackable, IGenericFunction | |||
{ | |||
#pragma warning disable CS0169 // The field 'Function._handle' is never used | |||
private IntPtr _handle; | |||
#pragma warning restore CS0169 // The field 'Function._handle' is never used | |||
protected Func<Tensor[], Tensor[]> _csharp_function; | |||
protected ConcreteFunction _concrete_variable_creation_fn; | |||
protected bool _autograph; | |||
protected TracingCompiler _variable_creation_fn; | |||
public string Name { get; set; } | |||
public Function() | |||
public Function(Func<Tensor[], Tensor[]> csharp_function, | |||
string name, bool auto_graph = true) | |||
{ | |||
_csharp_function = csharp_function; | |||
Name = name; | |||
_autograph = auto_graph; | |||
} | |||
public virtual Tensors Apply(Tensors inputs) | |||
{ | |||
if (_run_functions_eagerly()) | |||
{ | |||
return _csharp_function(inputs); | |||
} | |||
var result = _call(inputs); | |||
return result; | |||
} | |||
public ConcreteFunction get_concrete_function(params Tensor[] args) | |||
{ | |||
return _get_concrete_function_garbage_collected(args); | |||
} | |||
public Function(string name) | |||
protected virtual Tensors _call(Tensors inputs) | |||
{ | |||
Name = name; | |||
if(_variable_creation_fn is not null) | |||
{ | |||
return _variable_creation_fn.Apply(inputs); | |||
} | |||
_initialize(inputs); | |||
return _concrete_variable_creation_fn.CallFlat(inputs, | |||
_concrete_variable_creation_fn.CapturedInputs); | |||
} | |||
protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn) | |||
{ | |||
var name = nameof(fn); | |||
return new TracingCompiler(fn, name, autograph: _autograph); | |||
} | |||
protected virtual bool _run_functions_eagerly() | |||
{ | |||
return false; | |||
} | |||
protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) | |||
{ | |||
if(_variable_creation_fn is null) | |||
{ | |||
_initialize(args); | |||
// TODO(Rinne): _initialize_uninitialized_variables | |||
} | |||
var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
return concrete; | |||
} | |||
private void _initialize(Tensor[] args) | |||
{ | |||
_variable_creation_fn = _compiler(_csharp_function); | |||
_variable_creation_fn._name = this.Name; | |||
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
} | |||
} | |||
} |
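With the rework above, Function defers to a TracingCompiler on first use and then replays the cached ConcreteFunction. A minimal usage sketch; eager mode is assumed, tf.multiply and tf.constant are existing binding APIs, and everything else here is illustrative:

using Tensorflow;
using static Tensorflow.Binding;

static class FunctionWrapperSketch
{
    public static Tensors SquareOfThree()
    {
        var square = new Function(
            inputs => new Tensor[] { tf.multiply(inputs[0], inputs[0]) },
            name: "square");

        // The first Apply traces the lambda into a ConcreteFunction; subsequent
        // calls with the same number of inputs reuse the cached trace.
        return square.Apply(new Tensors(tf.constant(3.0f)));
    }
}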
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Functions | |||
{ | |||
public interface IGenericFunction | |||
{ | |||
Tensors Apply(Tensors args); | |||
ConcreteFunction get_concrete_function(params Tensor[] args); | |||
} | |||
} |
@@ -3,8 +3,10 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.tensorflow; | |||
@@ -15,17 +17,21 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
public abstract class TapeGradientFunctions | |||
{ | |||
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
string _FORWARD_PREFIX = "__forward_"; | |||
string _BACKWARD_PREFIX = "__backward_"; | |||
string _INFERENCE_PREFIX = "__inference_"; | |||
protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
protected string _FORWARD_PREFIX = "__forward_"; | |||
protected string _BACKWARD_PREFIX = "__backward_"; | |||
protected string _INFERENCE_PREFIX = "__inference_"; | |||
protected FuncGraph _func_graph; | |||
protected EagerDefinedFunction _forward; | |||
protected FuncGraph _forward_graph; | |||
protected List<int> _forwardprop_input_indices; | |||
protected List<int> _forwardprop_output_indices; | |||
protected int _num_forwardprop_outputs; | |||
protected int _num_inference_outputs; | |||
protected int _num_outputs; | |||
protected int _num_trainable_inference_outputs; | |||
protected ConcreteFunction _backward; | |||
BackwardFunction _backward_function_wrapper; | |||
@@ -33,11 +39,25 @@ namespace Tensorflow.Functions | |||
bool need_gradients_for_jvps) | |||
{ | |||
_func_graph = func_graph; | |||
_forward_graph = null; | |||
_forward = null; | |||
_backward = null; | |||
_num_outputs = func_graph.Outputs.Length; | |||
_forwardprop_output_indices = null; | |||
_num_forwardprop_outputs = 0; | |||
_num_inference_outputs = func_graph.Outputs.Length; | |||
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); | |||
} | |||
public EagerDefinedFunction Forward(Tensors inference_args) | |||
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||
{ | |||
return ForwardAndBackwardFunctions(inference_args); | |||
// TODO(Rinne): add input_tangents arg. | |||
if(_forward is null) | |||
{ | |||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
= ForwardAndBackwardFunctions(inference_args); | |||
} | |||
return _forward; | |||
} | |||
/// <summary> | |||
@@ -45,11 +65,16 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
/// <param name="flat_outputs"></param> | |||
/// <param name="inference_args"></param> | |||
public void Record(Tensors flat_outputs, Tensors inference_args) | |||
public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
{ | |||
// TODO(Rinne): add arg `input_tagents`. | |||
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | |||
getBackwardFunction: backward_function); | |||
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) | |||
{ | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException(); | |||
} | |||
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); | |||
} | |||
/// <summary> | |||
@@ -61,66 +86,95 @@ namespace Tensorflow.Functions | |||
/// <returns></returns> | |||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
{ | |||
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) | |||
.ToDictionary(x => x.Item1, x => x.Item2); | |||
var captured_inputs = backward.CapturedInputs; | |||
var remapped_captures = captured_inputs.Select(c => | |||
{ | |||
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) | |||
{ | |||
return value; | |||
} | |||
else | |||
{ | |||
return c; | |||
} | |||
}).ToArray(); | |||
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) | |||
{ | |||
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); | |||
throw new RuntimeError($"Failed to map all backward graph captures to " + | |||
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); | |||
} | |||
Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>(); | |||
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
var recorded_outputs = new Tensors(); | |||
var trainable_recorded_outputs = 0; | |||
foreach (var (output_index, output) in enumerate(outputs)) | |||
int trainable_recorded_outputs = 0; | |||
var skip_positions = new HashSet<int>(); | |||
var relevant_outputs = outputs; | |||
foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
{ | |||
if (trainable_recorded_outputs < backward_function_inputs) | |||
recorded_outputs.Add(output); | |||
if (gradients_util.IsTrainable(output)) | |||
trainable_recorded_outputs += 1; | |||
if (backprop_util.IsTrainable(output)) | |||
trainable_recorded_outputs++; | |||
else | |||
skip_positions.Add(output_index); | |||
if (output.dtype == dtypes.variant) | |||
variant_zeros_like[output_index] = default_gradient.zeros_like(output); | |||
} | |||
if(_backward_function_wrapper == null) | |||
_backward_function_wrapper = (args, unneeded_gradients) => | |||
{ | |||
var capture_mapping = new Dictionary<long, Tensor>(); | |||
foreach (var (i, output) in enumerate(outputs)) | |||
capture_mapping[forward_graph.Outputs[i].Id] = output; | |||
var remapped_captures = new Tensors(); | |||
foreach (var capture in backward.CapturedInputs) | |||
{ | |||
if (capture_mapping.ContainsKey(capture.Id)) | |||
remapped_captures.Add(capture_mapping[capture.Id]); | |||
} | |||
var skip_positions = new List<int>(); | |||
foreach (var (output_index, output) in enumerate(outputs)) | |||
if(backward.Outputs is null || backward.Outputs.Length == 0) | |||
{ | |||
if (!gradients_util.IsTrainable(output)) | |||
skip_positions.Add(output_index); | |||
return backward.FlatStructuredOutputs; | |||
} | |||
_backward_function_wrapper = (args, unneeded_gradients) => | |||
var processed_args = new Tensors(); | |||
int input_index = 0; | |||
foreach (var (output_index, arg) in enumerate(args)) | |||
{ | |||
var processed_args = new Tensors(); | |||
var input_index = 0; | |||
foreach (var (output_index, arg) in enumerate(args)) | |||
if (skip_positions.Contains(output_index)) | |||
continue; | |||
if (arg is null) | |||
{ | |||
var input_placeholder = backward.Inputs[input_index]; | |||
Tensor variant_arg; | |||
if (input_placeholder.dtype == dtypes.variant) | |||
{ | |||
variant_arg = variant_zeros_like[output_index]; | |||
} | |||
else | |||
{ | |||
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); | |||
variant_arg = array_ops.zeros(shape, type); | |||
} | |||
processed_args.Add(variant_arg); | |||
} | |||
else | |||
{ | |||
if (skip_positions.Contains(output_index)) | |||
continue; | |||
if (arg == null) | |||
throw new NotImplementedException(""); | |||
processed_args.Add(arg); | |||
input_index += 1; | |||
if (input_index >= backward_function_inputs) | |||
break; | |||
} | |||
input_index++; | |||
if (input_index >= backward_function_inputs) | |||
break; | |||
} | |||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
foreach (var unneeded_gradient_index in unneeded_gradients) | |||
{ | |||
var index = Convert.ToInt32(unneeded_gradient_index); | |||
if (gradients.Length <= index) | |||
gradients.Insert(index, null); | |||
} | |||
foreach (var unneeded_gradient_index in unneeded_gradients) | |||
{ | |||
var index = Convert.ToInt32(unneeded_gradient_index); | |||
if (gradients.Length <= index) | |||
gradients.Insert(index, null); | |||
} | |||
return gradients; | |||
}; | |||
} | |||
return gradients; | |||
}; | |||
return (_backward_function_wrapper, recorded_outputs); | |||
} | |||
@@ -132,51 +186,66 @@ namespace Tensorflow.Functions | |||
var trainable_indices = new List<int>(); | |||
foreach(var (index, output) in enumerate(outputs)) | |||
{ | |||
if (gradients_util.IsTrainable(output)) | |||
if (backprop_util.IsTrainable(output)) | |||
{ | |||
trainable_outputs.Add(output); | |||
trainable_indices.Add(index); | |||
} | |||
} | |||
var gradients_wrt_outputs = new List<Tensor>(); | |||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
backwards_graph.as_default(); | |||
var gradients_wrt_outputs = new List<Tensor>(); | |||
foreach (var output in trainable_outputs) | |||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
{ | |||
var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); | |||
var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); | |||
gradients_wrt_outputs.Add(gradient_placeholder); | |||
handle_data_util.copy_handle_data(output, gradient_placeholder); | |||
} | |||
// TODO(Rinne): with ops.device(None) | |||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
_func_graph.Inputs, | |||
grad_ys: gradients_wrt_outputs.ToArray(), | |||
src_graph: _func_graph); | |||
_func_graph.Inputs, | |||
grad_ys: gradients_wrt_outputs.ToArray(), | |||
src_graph: _func_graph); | |||
var captures_from_forward = backwards_graph.external_captures | |||
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | |||
.ToArray(); | |||
HashSet<Tensor> existing_outputs = new(_func_graph.Outputs); | |||
foreach(var capture in captures_from_forward) | |||
{ | |||
if (!_func_graph.Outputs.Contains(capture)) | |||
if (!existing_outputs.Contains(capture)) | |||
{ | |||
existing_outputs.Add(capture); | |||
_func_graph.Outputs.Add(capture); | |||
} | |||
} | |||
backwards_graph.Exit(); | |||
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
var backward_function_attr = new Dictionary<string, string>(); | |||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
backwards_graph.Inputs = gradients_wrt_outputs; | |||
backwards_graph.Outputs = gradients_wrt_inputs; | |||
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||
var (wrapped_forward_function, wrapped_backward_function) = | |||
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
//var backward_function_attr = new Dictionary<string, string>(); | |||
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
//var backward_function = new ConcreteFunction(backwards_graph, | |||
// monomorphic_function_utils._parse_func_attrs(backward_function_attr)); | |||
var forward_function_attr = new Dictionary<string, string>(); | |||
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||
//var forward_function_attr = new Dictionary<string, string>(); | |||
//forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
//var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
// _func_graph.Inputs, _func_graph.Outputs, | |||
// monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||
return (forward_function, _func_graph, backward_function, null, 0); | |||
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); | |||
} | |||
public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
ForwardAndBackwardFunctions(Tensors inference_args) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
@@ -0,0 +1,84 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Functions | |||
{ | |||
public class TracingCompiler | |||
{ | |||
Func<Tensor[], Tensor[]> _csharp_function; | |||
//FunctionSpec _function_spec; | |||
internal string _name; | |||
bool _autograph; | |||
Dictionary<string, ConcreteFunction> _function_cache; | |||
Dictionary<string, AttrValue> _function_attributes; | |||
int _tracing_count; | |||
public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null, | |||
Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null, | |||
bool reduce_retracing = false, bool capture_by_value = false) | |||
{ | |||
_csharp_function = csharp_function; | |||
bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); | |||
_name = name; | |||
_autograph = autograph; | |||
_function_attributes = attributes ?? new Dictionary<string, AttrValue>(); | |||
_function_cache = new Dictionary<string, ConcreteFunction>(); | |||
_tracing_count = 0; | |||
} | |||
public Tensor[] Apply(Tensor[] inputs) | |||
{ | |||
// TODO(Rinne): add lock here. | |||
var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); | |||
return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); | |||
} | |||
internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) | |||
{ | |||
var (concrete_function, _) = _maybe_define_concrete_function(args); | |||
return concrete_function; | |||
} | |||
private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) | |||
{ | |||
return _maybe_define_function(args); | |||
} | |||
private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | |||
{ | |||
var lookup_func_key = make_cache_key(args); | |||
if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | |||
{ | |||
return (concrete_function, args); | |||
} | |||
concrete_function = _create_concrete_function(args); | |||
_function_cache[lookup_func_key] = concrete_function; | |||
return (concrete_function, args); | |||
} | |||
private ConcreteFunction _create_concrete_function(Tensor[] args) | |||
{ | |||
_tracing_count++; | |||
int arglen = args.Length; | |||
var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( | |||
_name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), | |||
args, new Dictionary<string, object>(), autograph: _autograph | |||
), _function_attributes); | |||
return concrete_function; | |||
} | |||
private static string make_cache_key(Tensor[] inputs) | |||
{ | |||
//string res = ""; | |||
//foreach (var input in inputs) | |||
//{ | |||
// res += $"{input.name}_{input.Id}"; | |||
//} | |||
return inputs.Length.ToString(); | |||
} | |||
} | |||
} |
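Note that make_cache_key currently hashes only the input arity, so two calls with the same number of tensors share one trace regardless of dtype or shape. A sketch making that behaviour explicit; the lambda and constant values are illustrative:

using Tensorflow;
using Tensorflow.Functions;
using static Tensorflow.Binding;

static class TracingCacheSketch
{
    public static void Demo()
    {
        var compiler = new TracingCompiler(
            xs => new Tensor[] { tf.add(xs[0], xs[0]) }, "double_it");

        var first = compiler.Apply(new[] { tf.constant(1.0f) });

        // Same arity (one input) => cache hit: the lambda is not traced again,
        // even though the constant value differs.
        var second = compiler.Apply(new[] { tf.constant(2.0f) });
    }
}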
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Functions; | |||
namespace Tensorflow | |||
{ | |||
@@ -54,6 +55,9 @@ namespace Tensorflow | |||
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); | |||
} | |||
} |
@@ -0,0 +1,50 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Functions | |||
{ | |||
internal static class composite_tensor_utils | |||
{ | |||
public static List<object> flatten_with_variables(object inputs) | |||
{ | |||
List<object> flat_inputs = new(); | |||
foreach(var value in nest.flatten(inputs)) | |||
{ | |||
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
{ | |||
throw new NotImplementedException("Composite tensors are not yet fully supported."); | |||
} | |||
else | |||
{ | |||
flat_inputs.Add(value); | |||
} | |||
} | |||
return flat_inputs; | |||
} | |||
public static List<object> flatten_with_variables_or_variable_specs(object arg) | |||
{ | |||
List<object> flat_inputs = new(); | |||
foreach(var value in nest.flatten(arg)) | |||
{ | |||
if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
{ | |||
throw new NotImplementedException("Composite tensors are not yet fully supported."); | |||
} | |||
// TODO(Rinne): deal with `VariableSpec`. | |||
else if(value is TypeSpec type_spec && value is not TensorSpec) | |||
{ | |||
throw new NotImplementedException("TypeSpec inputs are not yet fully supported."); | |||
} | |||
else | |||
{ | |||
flat_inputs.Add(value); | |||
} | |||
} | |||
return flat_inputs; | |||
} | |||
} | |||
} |
@@ -0,0 +1,94 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Train; | |||
using Tensorflow.Variables; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
{ | |||
public static class function_saved_model_utils | |||
{ | |||
/// <summary> | |||
/// Restore the tensors captured by <paramref name="concrete_function"/> from <paramref name="inputs"/>. | |||
/// </summary> | |||
/// <param name="concrete_function">the concrete function whose captured inputs are being restored</param> | |||
/// <param name="inputs">a list of tensors or other objects (such as variables) which | |||
/// contain tensors that were originally captured by the function</param> | |||
public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||
{ | |||
var bound_inputs = inputs?.Select(obj => | |||
{ | |||
if(obj is Tensor tensor) | |||
{ | |||
return get_tensor_from_node(tensor); | |||
} | |||
else if(obj is IVariableV1 variable) | |||
{ | |||
return get_tensor_from_node(variable); | |||
} | |||
else | |||
{ | |||
throw new TypeError("Encountered a type error, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
}); | |||
var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); | |||
List<Tensor> captured_inputs_list = new(); | |||
concrete_function.set_variables(bound_variables); | |||
if (bound_inputs is not null) | |||
{ | |||
foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||
{ | |||
if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
captured_inputs_list.Add(bound_input); | |||
concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||
if(internal_capture.dtype == dtypes.resource) | |||
{ | |||
if (resource_variable_ops.is_resource_variable(bound_input)) | |||
{ | |||
handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); | |||
} | |||
else | |||
{ | |||
handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
} | |||
} | |||
concrete_function.func_graph.capture(bound_input); | |||
} | |||
} | |||
} | |||
if(captured_inputs_list.Any(inp => inp is null)) | |||
{ | |||
// TODO(Rinne): add warnings. | |||
} | |||
concrete_function.SetExternalCaptures(captured_inputs_list); | |||
} | |||
public static Tensor get_tensor_from_node(Tensor node) | |||
{ | |||
return node; | |||
} | |||
public static Tensor get_tensor_from_node(IVariableV1 node) | |||
{ | |||
if (resource_variable_ops.is_resource_variable(node)) | |||
{ | |||
return node.Handle; | |||
} | |||
else | |||
{ | |||
throw new TypeError("Encountered a type error, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
} | |||
} |
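
The two `get_tensor_from_node` overloads are the only per-input mapping performed while restoring captures: a tensor is returned as-is and a resource variable contributes its handle. A short sketch, assuming `t` is an existing `Tensor` and `v` an existing `ResourceVariable`:

// Sketch: what restore_captures binds for each supported input kind.
Tensor boundFromTensor = function_saved_model_utils.get_tensor_from_node(t);   // returns t unchanged
Tensor boundFromVariable = function_saved_model_utils.get_tensor_from_node(v); // returns v.Handle
// Any other input type reaches the TypeError branch above.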
@@ -0,0 +1,282 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Common.Extensions; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
using System.Diagnostics; | |||
namespace Tensorflow.Functions | |||
{ | |||
internal static class monomorphic_function_utils | |||
{ | |||
internal static string _FORWARD_PREFIX = "__forward_"; | |||
internal static string _BACKWARD_PREFIX = "__backward_"; | |||
internal static string _INFERENCE_PREFIX = "__inference_"; | |||
internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; | |||
internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
public static string _inference_name(string name) | |||
{ | |||
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
} | |||
public static string _forward_name(string name) | |||
{ | |||
return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; | |||
} | |||
public static string _backward_name(string name) | |||
{ | |||
return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; | |||
} | |||
public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs, | |||
FuncGraph forward_graph, FuncGraph backwards_graph) | |||
{ | |||
string forward_function_name = _forward_name(forward_graph.Name); | |||
Dictionary<string, AttrValue> common_attributes; | |||
if(attrs is null) | |||
{ | |||
common_attributes = new Dictionary<string, AttrValue>(); | |||
} | |||
else | |||
{ | |||
common_attributes = new Dictionary<string, AttrValue>(attrs); | |||
} | |||
if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) | |||
{ | |||
common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); | |||
} | |||
var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
{ | |||
{FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } | |||
}); | |||
backward_function_attr.Update(common_attributes); | |||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
{ | |||
{BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } | |||
}); | |||
forward_function_attr.Update(common_attributes); | |||
var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, | |||
forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); | |||
return (forward_function, backward_function); | |||
} | |||
public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes) | |||
{ | |||
Dictionary<string, AttrValue> attrs = new(); | |||
foreach(var item in attributes) | |||
{ | |||
var key = item.Key; | |||
var value = item.Value; | |||
if (value is AttrValue attr_value) | |||
{ | |||
attrs[key] = attr_value; | |||
} | |||
else if (value is bool b) | |||
{ | |||
attrs[key] = new AttrValue() { B = b }; | |||
} | |||
else if (value is int i) | |||
{ | |||
attrs[key] = new AttrValue() { I = i }; | |||
} | |||
else if (value is float f) | |||
{ | |||
attrs[key] = new AttrValue() { F = f }; | |||
} | |||
else if(value is string s) | |||
{ | |||
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; | |||
} | |||
else if (value is byte[] bytes) | |||
{ | |||
attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Attribute {key} must be bool, int, float, string, bytes, or " + | |||
$"AttrValue. Got {value.GetType()}."); | |||
} | |||
} | |||
return attrs; | |||
} | |||
public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes) | |||
{ | |||
Dictionary<string, AttrValue> attrs = new(); | |||
foreach (var item in attributes) | |||
{ | |||
var key = item.Key; | |||
var value = item.Value; | |||
attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; | |||
} | |||
return attrs; | |||
} | |||
} | |||
public class DelayedRewriteGradientFunctions : TapeGradientFunctions | |||
{ | |||
EagerDefinedFunction _inference_function; | |||
Dictionary<string, AttrValue> _attrs; | |||
int _num_inference_outputs; | |||
Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new(); | |||
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs) | |||
: base(func_graph, false) | |||
{ | |||
_func_graph = func_graph; | |||
_inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), | |||
_func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||
_attrs = attrs; | |||
_num_inference_outputs = _func_graph.Outputs.Length; | |||
} | |||
public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
{ | |||
if (input_tangents is not null) | |||
{ | |||
throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||
$"a class that does not support forwardprop."); | |||
} | |||
return _inference_function; | |||
} | |||
public override void Record(Tensors flat_outputs, Tensors inference_args) | |||
{ | |||
var (backward_function, to_record) = _backward(flat_outputs); | |||
foreach(var tape in tf.GetTapeSet()) | |||
{ | |||
tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||
inference_args, backward_function); | |||
} | |||
} | |||
public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) | |||
{ | |||
if(num_doutputs == -2) | |||
{ | |||
num_doutputs = _num_inference_outputs; | |||
} | |||
if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) | |||
{ | |||
return target; | |||
} | |||
var (forward, backward) = _construct_forward_backward(num_doutputs); | |||
_cached_function_pairs[num_doutputs] = (forward, backward); | |||
return (forward, backward); | |||
} | |||
private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||
{ | |||
Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) | |||
{ | |||
var call_op = outputs[0].op; | |||
return _rewrite_forward_and_call_backward(call_op, args); | |||
} | |||
return (backward_function, outputs); | |||
} | |||
internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) | |||
{ | |||
var (forward_function, backward_function) = forward_backward(doutputs.Length); | |||
if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) | |||
{ | |||
return backward_function.FlatStructuredOutputs; | |||
} | |||
forward_function.AddToGraph(op.graph); | |||
op._set_func_attr("f", forward_function.Name); | |||
op._set_type_list_attr("Tout", forward_function.OutputTypes); | |||
op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). | |||
Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() | |||
); | |||
for(int i = 0; i < op.outputs.Length; i++) | |||
{ | |||
var func_graph_output = forward_function._func_graph_outputs[i]; | |||
handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); | |||
} | |||
var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). | |||
ToDictionary(x => x.Item1, x => x.Item2); | |||
var remapped_captures = backward_function.CapturedInputs.Select( | |||
x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) | |||
); | |||
List<Tensor> cleaned_doutputs = new(); | |||
foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) | |||
{ | |||
if (backprop_util.IsTrainable(placeholder)) | |||
{ | |||
if(doutput is IndexedSlices) | |||
{ | |||
cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); | |||
} | |||
else if(doutput is null) | |||
{ | |||
cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); | |||
} | |||
else if(doutput is Tensor tensor) | |||
{ | |||
cleaned_doutputs.Add(tensor); | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); | |||
} | |||
} | |||
} | |||
return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); | |||
} | |||
private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) | |||
{ | |||
var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); | |||
List<TensorSpec> signature = new(); | |||
foreach(var t in trainable_outputs) | |||
{ | |||
var (shape, dtype) = default_gradient.shape_and_dtype(t); | |||
signature.Add(new TensorSpec(shape, dtype)); | |||
} | |||
Tensor[] _backprop_function(Tensor[] grad_ys) | |||
{ | |||
return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, | |||
grad_ys, src_graph: _func_graph); | |||
} | |||
_func_graph.as_default(); | |||
FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => | |||
{ | |||
Debug.Assert(y is Tensor); | |||
return (Tensor)y; | |||
}).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph); | |||
var backwards_graph_captures = backwards_graph.external_captures; | |||
var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); | |||
HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs); | |||
foreach(var capture in captures_from_forward) | |||
{ | |||
if (!existing_outputs.Contains(capture)) | |||
{ | |||
existing_outputs.Add(capture); | |||
_func_graph.Outputs.Add(capture); | |||
} | |||
} | |||
var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( | |||
_attrs, _func_graph, backwards_graph); | |||
_func_graph.Exit(); | |||
return (forward_function, backward_function); | |||
} | |||
} | |||
} |
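
`_parse_func_attrs` above is the single point where CLR attribute values become `AttrValue` protos, and the forward/backward attribute names link the two generated functions together. A hedged sketch of the conversion; the attribute keys other than `forward_function_name` are illustrative only:

// Sketch: mixed attribute values converted into AttrValue protos.
var attrs = monomorphic_function_utils._parse_func_attrs(new Dictionary<string, object>
{
    { "forward_function_name", "__forward_my_fn_42" }, // string -> AttrValue.S
    { "_noinline", true },                             // bool   -> AttrValue.B
    { "_priority", 1 },                                // int    -> AttrValue.I
});
// A value of any other type (e.g. double) falls through to the ValueError branch.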
@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients | |||
/// Map from tensor ID to how many references still exist for this tensor in | |||
/// the tape. | |||
/// </summary> | |||
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; } | |||
public UnorderedMap<long, long> tensor_usage_counts { get; set; } | |||
/// <summary> | |||
/// Maps from op ID to how many output tensors of this op still need to have | |||
/// their gradients computed. | |||
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients | |||
public BackpropInitialState() | |||
{ | |||
op_tape = new OpTape(); | |||
tensor_usage_counts = new UnorderedMap<Tensor, long>(); | |||
tensor_usage_counts = new UnorderedMap<long, long>(); | |||
op_missing_tensor = new UnorderedMap<long, long>(); | |||
} | |||
} | |||
@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients | |||
/// <param name="target"></param> | |||
/// <param name="source"></param> | |||
/// <returns></returns> | |||
public Tensor gradient(Tensor target, Tensor source) | |||
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
if(_tape is null) | |||
{ | |||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
"compute one set of gradients (or jacobians)."); | |||
} | |||
ITape tape = stop_recording(); | |||
var results = tf.Runner.TFE_TapeGradient(tape, | |||
new[] { target }, | |||
new[] { source }, | |||
null); | |||
output_gradients, | |||
new[] { source }, | |||
unconnected_gradients); | |||
return results[0]; | |||
} | |||
public Tensor gradient(Tensor target, ResourceVariable source) | |||
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
var results = gradient(target, new List<IVariableV1> { source }); | |||
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients); | |||
return results[0]; | |||
} | |||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }); | |||
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); | |||
return (results[0], results[1]); | |||
} | |||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
if (_tape is null) | |||
{ | |||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
"compute one set of gradients (or jacobians)."); | |||
} | |||
var tape = stop_recording(); | |||
var results = tf.Runner.TFE_TapeGradient(tape, | |||
new[] { target }, | |||
sources.Select(x => x.Handle).ToArray(), | |||
null); | |||
output_gradients, | |||
sources.Select(x => x.Handle).ToArray(), | |||
unconnected_gradients); | |||
if (!tape.Persistent) | |||
{ | |||
@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients | |||
public interface ITape | |||
{ | |||
void SetTapeId(int id); | |||
bool ShouldRecord(Tensor[] tensors); | |||
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); | |||
void StartRecord(); | |||
void StopRecord(); | |||
bool Persistent { get; } | |||
void RecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function); | |||
void VariableAccessed(ResourceVariable variable); | |||
void RecordOperation(string op_type, | |||
Tensor[] outputs, | |||
Tensor[] inputs, | |||
BackwardFunction backward_function); | |||
void VariableAccessed(IVariableV1 variable); | |||
void Watch(Tensor x); | |||
ResourceVariable[] WatchedVariables(); | |||
IVariableV1[] WatchedVariables(); | |||
Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
Tensor[] source_tensor_ids, | |||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
Tensor[] output_gradients); | |||
Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
long[] source_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
bool build_default_zeros_grads); | |||
} | |||
} |
@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients | |||
{ | |||
public string op_type { get; set; } | |||
public TapeTensor[] output_tensor_info { get; set; } | |||
public Tensor[] input_tensor_id { get; set; } | |||
public long[] input_tensor_id { get; set; } | |||
public BackwardFunction backward_function { get; set; } | |||
public override string ToString() | |||
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}"; | |||
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; | |||
} | |||
} |
@@ -2,235 +2,246 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public partial class Tape | |||
{ | |||
// int kMinAggregateCount = 4; | |||
// int kMinAggregateBytes = 128 * 1024 * 1024; | |||
static readonly int kMinAggregateCount = 4; | |||
static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||
public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
Tensor[] source_tensor_ids, | |||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
Tensor[] output_gradients) | |||
static Tape() | |||
{ | |||
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids); | |||
// var gradients_size = new UnorderedMap<Tensor, long>(); | |||
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); | |||
var state = PrepareBackprop( | |||
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); | |||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||
output_gradients, | |||
tensor_tape_, | |||
state.op_tape); | |||
_functionsAcceptingNoneForIndicesMap = new(); | |||
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
} | |||
while (!op_stack.empty()) | |||
public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
long[] source_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
bool build_default_zeros_grads) | |||
{ | |||
UnorderedSet<long> sources_set = new(source_tensor_ids); | |||
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||
UnorderedMap<long, long> gradients_size = new(); | |||
while(op_stack.Count > 0) | |||
{ | |||
var op = op_stack.Dequeue(); | |||
if (!state.op_tape.find(op, out var trace)) | |||
long op = op_stack.Dequeue(); | |||
if(!state.op_tape.TryGetValue(op, out var op_it)) | |||
{ | |||
continue; | |||
// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||
} | |||
var trace = op_it; | |||
state.op_tape.erase(op); | |||
var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||
var unneeded_gradients = new List<long>(); | |||
for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||
List<Tensor> out_gradients = new(); | |||
List<long> unneeded_gradients = new(); | |||
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||
{ | |||
var in_tensor_id = trace.input_tensor_id[i]; | |||
if (!tensor_tape_.find(in_tensor_id) && | |||
!sources_set.find(in_tensor_id)) | |||
long in_tensor_id = trace.input_tensor_id[i]; | |||
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||
{ | |||
unneeded_gradients.Add(i); | |||
} | |||
} | |||
bool any_gradient_nonzero = false; | |||
var zero_indices = new List<int>(); | |||
for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||
List<int> zero_indices = new(); | |||
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||
{ | |||
var id = trace.output_tensor_info[i].GetTensor(); | |||
if (!gradients.find(id, out var grad_it)) | |||
long id = trace.output_tensor_info[i].GetID(); | |||
if(!gradients.TryGetValue(id, out var grad_it)) | |||
{ | |||
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && | |||
func_name_it.find(i)) | |||
out_gradients.Add(null); | |||
if (build_default_zeros_grads) | |||
{ | |||
out_gradients.Add(null); | |||
} | |||
else | |||
{ | |||
out_gradients.Add(null); | |||
zero_indices.Add(i); | |||
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||
!func_name_it.find(i)) | |||
{ | |||
zero_indices.Add(i); | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
any_gradient_nonzero = true; | |||
var new_gradients = grad_it.Count == 1 ? | |||
grad_it[0] : | |||
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||
Tensor new_gradients; | |||
if (grad_it.Count == 1) | |||
{ | |||
new_gradients = grad_it[0]; | |||
} | |||
else | |||
{ | |||
new_gradients = AggregateGradients(grad_it); | |||
} | |||
if (!sources_set.find(id)) | |||
{ | |||
gradients.Remove(id); | |||
} | |||
else | |||
{ | |||
// grad_it.Clear(); | |||
// grad_it.Add(new_gradients); | |||
// vspace.MarkAsResult(new_gradients); | |||
grad_it.Clear(); | |||
grad_it.Add(new_gradients); | |||
// MarkAsResult | |||
} | |||
out_gradients.Add(new_gradients); | |||
} | |||
} | |||
Tensor[] in_gradients; | |||
Tensor[] in_gradients = new Tensor[0]; | |||
if (any_gradient_nonzero) | |||
{ | |||
// foreach (var i in zero_indices) | |||
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); | |||
if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) | |||
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||
if (!_persistent) | |||
foreach(var i in zero_indices) | |||
{ | |||
// trace.backward_function_deleter(trace.backward_function); | |||
trace.backward_function = null; | |||
out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
} | |||
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||
} | |||
else | |||
{ | |||
in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||
out_gradients.Clear(); | |||
} | |||
bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; | |||
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) | |||
for(int i = 0, end = in_gradients.Length; i < end; i++) | |||
{ | |||
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; | |||
var id = trace.input_tensor_id[k]; | |||
if (in_gradients[i] != null) | |||
long id = trace.input_tensor_id[i]; | |||
if (in_gradients[i] is not null) | |||
{ | |||
var unaggregated_grads = gradients[id]; | |||
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||
unaggregated_grads.Add(in_gradients[i]); | |||
/*if (unaggregated_grads.Count > kMinAggregateCount) | |||
if(unaggregated_grads.Count > kMinAggregateCount) | |||
{ | |||
if (!gradients_size.find(id, out var size)) | |||
if(!gradients_size.TryGetValue(id, out var size)) | |||
{ | |||
size = (long)unaggregated_grads[0].size; | |||
size = NumElements(unaggregated_grads[0]); | |||
gradients_size.emplace(id, size); | |||
} | |||
if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
{ | |||
throw new NotImplementedException(""); | |||
Tensor grad = AggregateGradients(unaggregated_grads); | |||
unaggregated_grads.Clear(); | |||
unaggregated_grads.Add(grad); | |||
} | |||
}*/ | |||
} | |||
} | |||
if (!state.tensor_usage_counts.find(id)) | |||
if(!state.tensor_usage_counts.find(id)) | |||
{ | |||
continue; | |||
} | |||
state.tensor_usage_counts[id]--; | |||
if (state.tensor_usage_counts[id] > 0) | |||
if(state.tensor_usage_counts[id] > 0) | |||
{ | |||
continue; | |||
if (!tensor_tape_.find(id, out var tape_it)) | |||
} | |||
if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||
{ | |||
if (gradients.find(id, out var grad_it)) | |||
if (gradients.find(id)) | |||
{ | |||
// foreach (var g in grad_it) | |||
// DeleteGradient(g); | |||
gradients.erase(id); | |||
} | |||
continue; | |||
} | |||
var op_id = tape_it; | |||
if (op_id == -1) | |||
long op_id = tape_it; | |||
if(op_id == -1) | |||
{ | |||
continue; | |||
if (state.op_missing_tensor.find(op_id, out var missing_it)) | |||
} | |||
if(state.op_missing_tensor.find(op_id)) | |||
{ | |||
state.op_missing_tensor[op_id]--; | |||
if (state.op_missing_tensor[op_id] == 0) | |||
if(state.op_missing_tensor[op_id] == 0) | |||
{ | |||
op_stack.Enqueue(op_id); | |||
} | |||
} | |||
} | |||
} | |||
if (state.op_tape.Count > 0) | |||
if(state.op_tape.Count > 0) | |||
{ | |||
throw new RuntimeError("Invalid tape state."); | |||
var result = new Tensor[source_tensor_ids.Length]; | |||
var j = 0; | |||
foreach (var id in source_tensor_ids) | |||
} | |||
Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||
for(int i = 0; i < source_tensor_ids.Length; i++) | |||
{ | |||
if (gradients.find(id, out var grad_it)) | |||
long tensor_id = source_tensor_ids[i]; | |||
if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||
{ | |||
if (grad_it.Count > 1) | |||
result[j] = gen_math_ops.add_n(grad_it.ToArray()); | |||
else | |||
result[j] = grad_it[0]; | |||
result[i] = null; | |||
} | |||
else | |||
{ | |||
if(grad_it.Count > 1) | |||
{ | |||
Tensor grad = AggregateGradients(grad_it); | |||
grad_it.Clear(); | |||
grad_it.Add(grad); | |||
} | |||
result[i] = grad_it[0]; | |||
} | |||
j++; | |||
} | |||
return result; | |||
} | |||
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | |||
{ | |||
var m = new UnorderedMap<string, UnorderedSet<int>>(); | |||
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
return m; | |||
return _functionsAcceptingNoneForIndicesMap; | |||
} | |||
UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids, | |||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
Tensor[] output_gradients, | |||
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
TensorTape tensor_tape, | |||
OpTape op_tape) | |||
{ | |||
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>(); | |||
for (int i = 0; i < target_tensor_ids.Length; ++i) | |||
var result = new UnorderedMap<long, List<Tensor>>(); | |||
for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||
{ | |||
var id = target_tensor_ids[i]; | |||
if (output_gradients.Length == 0 || output_gradients[i] == null) | |||
long id = target_tensor_ids[i]; | |||
if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) | |||
{ | |||
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) | |||
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||
{ | |||
if (!op_tape.find(tensor_tape[id], out var op_it)) | |||
if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||
{ | |||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
"failed to find operation producing a tensor"); | |||
"failed to find operation producing a tensor."); | |||
} | |||
bool found = false; | |||
for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||
for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||
{ | |||
if (op_it.output_tensor_info[j].GetTensor() == id) | |||
if (op_it.output_tensor_info[j].GetID() == id) | |||
{ | |||
found = true; | |||
var ones = op_it.output_tensor_info[j].OnesLike(); | |||
result[id].Add(ones); | |||
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
break; | |||
} | |||
} | |||
if (!found) | |||
{ | |||
throw new ValueError("Internal state of the gradient tape is invalid: " + | |||
"none of operations outputs match expected tensor"); | |||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
"none of operations outputs match expected tensor."); | |||
} | |||
} | |||
else | |||
{ | |||
if (sources_that_are_targets.find(id, out var source_tensor)) | |||
result[id].Add(source_tensor.OnesLike()); | |||
if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||
{ | |||
Tensor ones_like = BuildOnesLike(source_tensor); | |||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
result[id].Add(output_gradients[i]); | |||
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||
} | |||
} | |||
@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||
} | |||
return result; | |||
} | |||
Tensor BuildOnesLike(TapeTensor t) | |||
{ | |||
return t.OnesLike(); | |||
} | |||
Tensor AggregateGradients(List<Tensor> gradient_tensors) | |||
{ | |||
if(gradient_tensors.Count == 1) | |||
{ | |||
return gradient_tensors[0]; | |||
} | |||
return tf.add_n(gradient_tensors.ToArray()); | |||
} | |||
void DeleteGradient(Tensor gradient) | |||
{ | |||
// Intentionally a no-op: the GC will collect the gradient once nothing references it. | |||
} | |||
long NumElements(Tensor tensor) => 1; | |||
} | |||
} |
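
Gradients arriving for the same tensor ID are buffered in a list and only summed eagerly once more than kMinAggregateCount entries are pending and the estimated payload (count × elements × 4 bytes) exceeds kMinAggregateBytes; with `NumElements` stubbed to return 1, the byte condition effectively never fires and summation is deferred to the final `AggregateGradients` call. A small sketch of that heuristic with the same constants:

// Sketch of the eager-aggregation heuristic used in ComputeGradient above.
const int kMinAggregateCount = 4;
const long kMinAggregateBytes = 128L * 1024 * 1024;

// 4 bytes per element mirrors the "* 4" factor in the tape code.
static bool ShouldAggregateNow(int pendingGrads, long elementsPerGrad)
    => pendingGrads > kMinAggregateCount
       && pendingGrads * elementsPerGrad * 4 > kMinAggregateBytes;

// ShouldAggregateNow(5, 16_000_000) -> true  (~320 MB of pending gradients)
// ShouldAggregateNow(5, 1)          -> false (NumElements stub always returns 1)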
@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients | |||
{ | |||
public partial class Tape | |||
{ | |||
public BackpropInitialState PrepareBackprop(Tensor[] target, | |||
public BackpropInitialState PrepareBackprop(long[] target, | |||
TensorTape tensor_tape, | |||
OpTape op_tape, | |||
UnorderedSet<Tensor> sources_set, | |||
UnorderedSet<long> sources_set, | |||
bool persistent_tape) | |||
{ | |||
Stack<long> tensor_stack = new Stack<long>(); | |||
foreach(var t in target) | |||
{ | |||
tensor_stack.Push(t); | |||
} | |||
BackpropInitialState result = new BackpropInitialState(); | |||
var tensor_stack = new Queue<Tensor>(target); | |||
while (tensor_stack.Count > 0) | |||
while(tensor_stack.Count > 0) | |||
{ | |||
var tensor_id = tensor_stack.Dequeue(); | |||
if (!tensor_tape.find(tensor_id, out var op_id)) | |||
long tensor_id = tensor_stack.Pop(); | |||
if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) | |||
{ | |||
continue; | |||
if (op_id == -1 || | |||
!op_tape.find(op_id, out var op_it) || | |||
result.op_tape.find(op_id, out var result_op_it)) | |||
} | |||
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) | |||
|| result.op_tape.find(op_id)) | |||
{ | |||
continue; | |||
} | |||
result.op_tape.emplace(op_id, op_it); | |||
foreach (var it in op_it.input_tensor_id) | |||
foreach(var it in op_it.input_tensor_id) | |||
{ | |||
if (result.tensor_usage_counts.find(it)) | |||
if(result.tensor_usage_counts.find(it)) | |||
{ | |||
result.tensor_usage_counts[it]++; | |||
} | |||
else | |||
{ | |||
result.tensor_usage_counts[it] = 1; | |||
if (tensor_tape.find(it)) | |||
tensor_stack.Enqueue(it); | |||
{ | |||
tensor_stack.Push(it); | |||
} | |||
} | |||
} | |||
if (!persistent_tape) | |||
op_tape.Remove(op_id); | |||
{ | |||
op_tape.erase(op_id); | |||
} | |||
} | |||
foreach (var pair in result.tensor_usage_counts) | |||
foreach(var pair in result.tensor_usage_counts) | |||
{ | |||
if (tensor_tape.find(pair.Key, out var it) && it != -1) | |||
result.op_missing_tensor[it] += 1; | |||
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) | |||
{ | |||
result.op_missing_tensor[it]++; | |||
} | |||
} | |||
if (!persistent_tape) | |||
{ | |||
// Call destructors for all unneeded gradient functions and | |||
// clear the op_tape. We can clear the tape because ownership of | |||
// backward functions that will be used for gradient computation | |||
// has been transferred to `result`. | |||
/*for (const auto&op_pair : *op_tape) { | |||
op_pair.second.backward_function_deleter( | |||
op_pair.second.backward_function); | |||
}*/ | |||
op_tape.Clear(); | |||
} | |||
return result; | |||
} | |||
} | |||
@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients | |||
public partial class Tape | |||
{ | |||
long next_op_id_ = 0; | |||
UnorderedMap<Tensor, long> tensor_usage_; | |||
UnorderedMap<long, long> tensor_usage_; | |||
public void RecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function) | |||
{ | |||
if (!ShouldRecord(input_tensors)) | |||
if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
return; | |||
var op_id = next_op_id_++; | |||
foreach (var i in input_tensors) | |||
foreach (var i in input_tensor_id) | |||
{ | |||
tensor_usage_[i]++; | |||
} | |||
long op_id = next_op_id_++; | |||
foreach (var o in output_tensors) | |||
{ | |||
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | |||
tensor_tape_[o.GetTensor()] = op_id; | |||
tensor_usage_[o.GetTensor()] = 1; | |||
tensor_tape_[o.GetID()] = op_id; | |||
tensor_usage_[o.GetID()] = 1; | |||
} | |||
op_tape_[op_id] = new OpTapeEntry | |||
{ | |||
op_type = op_type, | |||
output_tensor_info = output_tensors, | |||
input_tensor_id = input_tensors, | |||
output_tensor_info = output_tensors.ToArray(), | |||
input_tensor_id = input_tensor_id.ToArray(), | |||
backward_function = backward_function | |||
}; | |||
} | |||
public void RecordOperation(string op_type, | |||
Tensor[] outputs, | |||
Tensor[] inputs, | |||
BackwardFunction backward_function) | |||
{ | |||
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); | |||
} | |||
} | |||
} |
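
The low-level `RecordOperation` overload above stores tensor IDs, dtypes and a `BackwardFunction`; judging from `DelayedRewriteGradientFunctions._backward` earlier in this diff, that delegate takes the upstream output gradients plus the indices of unneeded inputs and returns one gradient per input. A hedged sketch of such a function for an identity-like op:

// Sketch: a backward function with the delegate shape used in this diff.
// For an identity-like op the incoming output gradients pass straight through.
Tensor[] IdentityBackward(Tensor[] outputGrads, long[] unneededGradients)
{
    return outputGrads;
}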
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients | |||
_created_eagerly = tf.Context.executing_eagerly(); | |||
tensor_tape_ = new TensorTape(); | |||
op_tape_ = new OpTape(); | |||
tensor_usage_ = new UnorderedMap<Tensor, long>(); | |||
tensor_usage_ = new UnorderedMap<long, long>(); | |||
if(_created_eagerly) | |||
tf.Context.start_step(); | |||
// nesting_id = ++tape_nesting_id_counter; | |||
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients | |||
public void Watch(Tensor x) | |||
{ | |||
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | |||
tensor_tape_.emplace(x, -1); | |||
tensor_tape_.emplace(x.Id, -1); | |||
} | |||
public bool ShouldRecord(Tensor[] tensors) | |||
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) | |||
{ | |||
var dtypes = tensors.Select(x => x.dtype).ToArray(); | |||
for (int i = 0; i < tensors.Length; ++i) | |||
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); | |||
for (int i = 0; i < tensor_ids.Length; ++i) | |||
{ | |||
if (tensor_tape_.find(tensors[i])) | |||
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) | |||
{ | |||
if (IsDtypeTrainable(dtypes[i])) | |||
return true; | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
public void VariableAccessed(ResourceVariable variable) | |||
public void VariableAccessed(IVariableV1 variable) | |||
{ | |||
Watch(variable.Handle); | |||
} | |||
public ResourceVariable[] WatchedVariables() | |||
public IVariableV1[] WatchedVariables() | |||
{ | |||
return null; | |||
} | |||
@@ -1,27 +1,63 @@ | |||
using static Tensorflow.Binding; | |||
using OneOf; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public class TapeTensor | |||
{ | |||
Tensor tensor; | |||
long id => tensor.Id; | |||
TF_DataType dtype => tensor.dtype; | |||
Shape shape => tensor.shape; | |||
internal Tensor tensor; | |||
internal long id; | |||
internal TF_DataType dtype; | |||
internal OneOf<Shape, Tensor> shape; | |||
public TapeTensor(long id, TF_DataType dtype, Shape shape) | |||
{ | |||
this.id = id; | |||
this.dtype = dtype; | |||
this.shape = shape; | |||
} | |||
public TapeTensor(long id, TF_DataType dtype, Tensor shape) | |||
{ | |||
this.id = id; | |||
this.dtype = dtype; | |||
this.shape = shape; | |||
} | |||
public TapeTensor(Tensor tensor) | |||
{ | |||
this.id = tensor.Id; | |||
this.dtype = tensor.dtype; | |||
this.shape = tensor.shape; | |||
this.tensor = tensor; | |||
} | |||
public long GetID() => tensor.Id; | |||
public Tensor GetTensor() => tensor; | |||
public long GetID() => id; | |||
public Tensor ZerosLike() | |||
=> tf.zeros(shape: shape, dtype: dtype); | |||
{ | |||
if(dtype == dtypes.resource) | |||
{ | |||
return null; | |||
} | |||
if(shape.Index == 1) | |||
{ | |||
return tf.zeros_like(shape.AsT1); | |||
} | |||
return tf.zeros(shape.AsT0, dtype); | |||
} | |||
public Tensor OnesLike() | |||
=> tf.ones(shape: shape, dtype: dtype); | |||
{ | |||
if (shape.Index == 1) | |||
{ | |||
return tf.ones_like(shape.AsT1); | |||
} | |||
return tf.ones(shape.AsT0, dtype); | |||
} | |||
//public Tensor OnesLike() | |||
// => tf.ones(shape: shape, dtype: dtype); | |||
public override string ToString() | |||
=> $"{id}, {shape}, {dtype.as_numpy_name()}"; | |||
@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients | |||
/// produced this tensor. A value of -1 means that the tensor was directly | |||
/// watched and not the result of any operation in the tape. | |||
/// </summary> | |||
public class TensorTape : UnorderedMap<Tensor, long> | |||
public class TensorTape : UnorderedMap<long, long> | |||
{ | |||
} | |||
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public class custom_gradient | |||
{ | |||
public static string generate_name() | |||
{ | |||
return $"CustomGradient-{ops.uid()}"; | |||
} | |||
} | |||
} |
@@ -0,0 +1,52 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Gradients | |||
{ | |||
internal static class default_gradient | |||
{ | |||
public static (Shape, TF_DataType) shape_and_dtype(Tensor t) | |||
{ | |||
if(t.dtype == dtypes.resource) | |||
{ | |||
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
{ | |||
throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
$"of a variable without handle data:\n{t}"); | |||
} | |||
return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); | |||
} | |||
return (t.shape, t.dtype); | |||
} | |||
public static Tensor zeros_like(Tensor t) | |||
{ | |||
if(t.dtype == dtypes.resource) | |||
{ | |||
var (shape, dtype) = shape_and_dtype(t); | |||
return array_ops.zeros(shape, dtype); | |||
} | |||
else | |||
{ | |||
return array_ops.zeros_like(t); | |||
} | |||
} | |||
public static TF_DataType get_zeros_dtype(Tensor t) | |||
{ | |||
if(t.dtype == dtypes.resource) | |||
{ | |||
var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
{ | |||
throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
$"of a variable without handle data:\n{t}"); | |||
} | |||
return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); | |||
} | |||
return t.dtype; | |||
} | |||
} | |||
} |
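
`zeros_like` above dispatches on the resource dtype: for a resource handle it recovers the variable's shape and dtype from its handle data, otherwise it defers to `array_ops.zeros_like`. A short sketch, assuming `denseTensor` is an ordinary float tensor and `resourceHandle` a variable handle with valid handle data:

// Sketch of the two code paths in default_gradient.zeros_like.
var zerosDense = default_gradient.zeros_like(denseTensor);     // array_ops.zeros_like path
var zerosHandle = default_gradient.zeros_like(resourceHandle); // shape/dtype taken from handle data
// A resource tensor without handle data raises the ValueError above.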
@@ -14,10 +14,15 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.ControlFlows; | |||
using static Tensorflow.Binding; | |||
@@ -25,6 +30,11 @@ namespace Tensorflow | |||
{ | |||
public class gradients_util | |||
{ | |||
// Represents the output of TFE_Py_TapeSetPossibleGradientTypes in the Python | |||
// implementation; plain ints are used here to mirror that API. | |||
public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; | |||
public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; | |||
public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; | |||
public static Tensor[] _GradientsHelper(Tensor[] ys, | |||
Tensor[] xs, | |||
Tensor[] grad_ys = null, | |||
@@ -143,7 +153,7 @@ namespace Tensorflow | |||
Tensor[] in_grads = null; | |||
Func<Operation, Tensor[], Tensor[]> grad_fn = null; | |||
var is_partitioned_call = _IsPartitionedCall(op); | |||
var is_func_call = false; | |||
var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; | |||
var has_out_grads = out_grads.Exists(x => x != null); | |||
if (has_out_grads && !stop_ops.Contains(op)) | |||
{ | |||
@@ -157,14 +167,41 @@ namespace Tensorflow | |||
{ | |||
if (is_func_call) | |||
{ | |||
EagerDefinedFunction func_call = null; | |||
if (is_partitioned_call) | |||
{ | |||
var func_attr = op.get_attr("f"); | |||
Debug.Assert(func_attr is NameAttrList); | |||
var func_name = ((NameAttrList)func_attr).Name; | |||
func_call = src_graph._get_function(func_name); | |||
if(func_call is null && src_graph.OuterGraph is not null) | |||
{ | |||
var graph = src_graph.OuterGraph; | |||
while(graph is not null) | |||
{ | |||
func_call = graph._get_function(func_name); | |||
if(func_call is not null) | |||
{ | |||
break; | |||
} | |||
if(graph.OuterGraph is not null) | |||
{ | |||
graph = graph.OuterGraph; | |||
} | |||
else | |||
{ | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
func_call = src_graph._get_function(op.type); | |||
} | |||
// skip the following code: | |||
// `func_call = getattr(op, "__defun", func_call)` | |||
grad_fn = func_call.csharp_grad_func; | |||
} | |||
else | |||
{ | |||
@@ -208,6 +245,8 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), | |||
null, (x, y) => _SymGrad(x, y)); | |||
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); | |||
} | |||
_VerifyGeneratedGradients(in_grads, op); | |||
@@ -663,6 +702,11 @@ namespace Tensorflow | |||
dtypes.resource, dtypes.variant}.Contains(dtype); | |||
} | |||
public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||
{ | |||
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); | |||
} | |||
/// <summary> | |||
/// Return true if op has real gradient. | |||
/// </summary> | |||
@@ -683,7 +727,7 @@ namespace Tensorflow | |||
private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | |||
{ | |||
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | |||
// scope = scope.TrimEnd('/').Replace('/', '_'); | |||
return grad_fn(op, out_grads); | |||
} | |||
@@ -696,5 +740,28 @@ namespace Tensorflow | |||
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | |||
$"inputs {op.inputs._inputs.Count()}"); | |||
} | |||
private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) | |||
{ | |||
var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); | |||
var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); | |||
NameAttrList f = new(); | |||
if (_IsPartitionedCall(op)) | |||
{ | |||
var func_attr = op.get_attr("f"); | |||
Debug.Assert(func_attr is NameAttrList); | |||
f.Name = ((NameAttrList)func_attr).Name; | |||
} | |||
else | |||
{ | |||
f.Name = op.type; | |||
} | |||
foreach(var k in op.node_def.Attr.Keys) | |||
{ | |||
f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); | |||
} | |||
var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); | |||
return in_grads; | |||
} | |||
} | |||
} |
@@ -98,12 +98,23 @@ namespace Tensorflow | |||
{ | |||
if (op.inputs == null) return null; | |||
RegisterFromAssembly(); | |||
var gradient_function = op._gradient_function; | |||
if(gradient_function is null) | |||
{ | |||
RegisterFromAssembly(); | |||
if (!gradientFunctions.ContainsKey(op.type)) | |||
throw new LookupError($"can't get gradient function through get_gradient_function {op.type}"); | |||
if (!gradientFunctions.ContainsKey(op.type)) | |||
throw new LookupError($"can't get gradient function through get_gradient_function {op.type}"); | |||
return gradientFunctions[op.type]; | |||
} | |||
return gradientFunctions[op.type]; | |||
Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) | |||
{ | |||
return gradient_function(operation, args); | |||
} | |||
// TODO(Rinne): check if this needs to be registered. | |||
return wrapped_gradient_function; | |||
} | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using MethodBoundaryAspect.Fody.Attributes; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Functions; | |||
@@ -22,7 +23,7 @@ namespace Tensorflow.Graphs | |||
public override void OnEntry(MethodExecutionArgs args) | |||
{ | |||
// TODO: func_name can be cache in FullName + Args | |||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||
if (functions.ContainsKey(func_name)) | |||
{ | |||
@@ -91,6 +92,7 @@ namespace Tensorflow.Graphs | |||
// cache function. | |||
function.ReturnType = args.ReturnValue.GetType(); | |||
function._set_infer_function(); | |||
functions[func_name] = function; | |||
// run function | |||
@@ -1,6 +1,15 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Buffers; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Exceptions; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs; | |||
@@ -10,12 +19,66 @@ namespace Tensorflow.Graphs; | |||
/// </summary> | |||
public class FuncGraph : Graph, IDisposable | |||
{ | |||
SafeFuncGraphHandle _func_graph_handle; | |||
internal SafeFuncGraphHandle _func_graph_handle; | |||
internal HashSet<Tensor> _resource_tensor_inputs; | |||
internal HashSet<WeakReference<IVariableV1>> _watched_variables; | |||
internal IEnumerable<WeakReference<IVariableV1>> _weak_variables; | |||
internal object[] _structured_outputs; | |||
internal Dictionary<long, string> _output_names; | |||
public string FuncName => _graph_key; | |||
public Tensors Inputs { get; set; } = new Tensors(); | |||
public Tensors Outputs { get; set; } = new Tensors(); | |||
public Dictionary<string, string> Attrs { get; set; } | |||
public Tensors FlatStructuredOutputs | |||
{ | |||
get | |||
{ | |||
List<Tensor> res = new(); | |||
foreach(var obj in _structured_outputs) | |||
{ | |||
if(obj is Tensor tensor) | |||
{ | |||
res.Add(tensor); | |||
} | |||
else if(obj is IEnumerable<Tensor> tensors) | |||
{ | |||
res.AddRange(tensors); | |||
} | |||
else | |||
{ | |||
throw new TypeError("Each structured output must be a tensor or a collection of tensors."); | |||
} | |||
} | |||
return res; | |||
} | |||
} | |||
public string Name { get; set; } | |||
public IEnumerable<IVariableV1> Variables | |||
{ | |||
get | |||
{ | |||
return _weak_variables.Select(v => | |||
{ | |||
if (v.TryGetTarget(out var target)) | |||
{ | |||
return target; | |||
} | |||
else | |||
{ | |||
throw new AssertionError("Called a function referencing variables which have been deleted. " + | |||
"This likely means that function-local variables were created and " + | |||
"not referenced elsewhere in the program. This is generally a " + | |||
"mistake; consider storing variables in an object attribute on first call."); | |||
} | |||
}); | |||
} | |||
internal set | |||
{ | |||
_weak_variables = value.Select(x => new WeakReference<IVariableV1>(x)); | |||
} | |||
} | |||
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
public Dictionary<string, AttrValue> Attrs { get; set; } | |||
Dictionary<long, (Tensor, Tensor)> _captures | |||
= new Dictionary<long, (Tensor, Tensor)>(); | |||
@@ -39,31 +102,42 @@ public class FuncGraph : Graph, IDisposable | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
_graph_key = Name = name; | |||
building_function = true; | |||
_weak_variables = new List<WeakReference<IVariableV1>>(); | |||
_resource_tensor_inputs = new HashSet<Tensor>(); | |||
_watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
} | |||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base() | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
_graph_key = Name = name; | |||
building_function = true; | |||
Attrs = attrs; | |||
// TODO: test whether FuncGraph has a memory leak | |||
// c_api.TF_DeleteGraph(_handle); | |||
_handle = handle; | |||
_weak_variables = new List<WeakReference<IVariableV1>>(); | |||
_resource_tensor_inputs = new HashSet<Tensor>(); | |||
_watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
} | |||
public void ToGraph(Operation[] opers, | |||
public void replace_capture(Tensor tensor, Tensor placeholder) | |||
{ | |||
_captures[tensor.Id] = (tensor, placeholder); | |||
} | |||
public unsafe void ToGraph(Operation[] opers, | |||
Tensor[] inputs, Tensor[] outputs, | |||
string[] output_names) | |||
{ | |||
var status = new Status(); | |||
if (output_names != null && output_names.Length == 0) | |||
if (output_names is null) | |||
{ | |||
output_names = null; | |||
output_names = new string[0]; | |||
}; | |||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
@@ -75,7 +149,7 @@ public class FuncGraph : Graph, IDisposable | |||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
outputs.Length, | |||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
output_names, | |||
output_names.Length != outputs.Length ? null : output_names, | |||
IntPtr.Zero, | |||
null, | |||
status); | |||
@@ -141,6 +215,16 @@ public class FuncGraph : Graph, IDisposable | |||
return tensor; | |||
} | |||
public void watch_variable(IVariableV1 v) | |||
{ | |||
if (_resource_tensor_inputs.Contains(v.Handle)) | |||
{ | |||
return; | |||
} | |||
_watched_variables.Add(new WeakReference<IVariableV1>(v)); | |||
// TODO(Rinne): also record the watch on outer graphs (walk the `outer_graph` chain). | |||
} | |||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||
{ | |||
Tensor graph_const = null; | |||
@@ -205,6 +289,19 @@ public class FuncGraph : Graph, IDisposable | |||
Inputs.Add(placeholder); | |||
} | |||
Tensor pop_capture(Tensor tensor) | |||
{ | |||
if(_captures.TryGetValue(tensor.Id, out var capture)) | |||
{ | |||
_captures.Remove(tensor.Id); | |||
return capture.Item2; | |||
} | |||
else | |||
{ | |||
return null; | |||
} | |||
} | |||
Tensor _create_substitute_placeholder(Tensor value, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -228,10 +325,7 @@ public class FuncGraph : Graph, IDisposable | |||
foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
{ | |||
var serialized = new AttrValue | |||
{ | |||
S = ByteString.CopyFromUtf8(attr_value) | |||
}.ToByteArray(); | |||
var serialized = attr_value.ToByteArray(); | |||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||
tf.Status.Check(true); | |||
} | |||
@@ -254,4 +348,261 @@ public class FuncGraph : Graph, IDisposable | |||
{ | |||
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||
} | |||
public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func, | |||
object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null, | |||
FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, | |||
bool add_control_dependencies = true, string[] arg_names = null, | |||
Tensor op_return_value = null, bool capture_by_value = false, | |||
bool acd_record_initial_resource_uses = false) | |||
{ | |||
if(func_graph is null) | |||
{ | |||
func_graph = new FuncGraph(name); | |||
} | |||
// TODO(Rinne): deal with control dependencies. | |||
func_graph.as_default(); | |||
var current_scope = variable_scope.get_variable_scope(); | |||
var default_use_resource = current_scope.use_resource; | |||
current_scope.use_resource = true; | |||
if(signature is not null) | |||
{ | |||
args = signature; | |||
kwargs = new Dictionary<string, object>(); | |||
} | |||
var func_args = _get_defun_inputs_from_args(args, arg_names); | |||
var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); | |||
if(func_kwargs is not null && func_kwargs.Count > 0) | |||
{ | |||
throw new NotImplementedException("Keyword args are not yet supported in `func_graph_from_func`."); | |||
} | |||
foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs })) | |||
{ | |||
if(arg is Tensor tensor && tensor.dtype == dtypes.resource) | |||
{ | |||
func_graph._resource_tensor_inputs.Add(tensor); | |||
} | |||
else if (arg is ResourceVariable variable) | |||
{ | |||
func_graph._resource_tensor_inputs.Add(variable.Handle); | |||
} | |||
} | |||
// skip the assignment of `func_graph.structured_input_signature`. | |||
var flat_func_args = nest.flatten(func_args as object); | |||
var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
.Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
Tensor convert(object x) | |||
{ | |||
if (x is null) return null; | |||
Tensor res = null; | |||
if(op_return_value is not null && x is Operation) | |||
{ | |||
tf_with(ops.control_dependencies(new object[] { x }), _ => | |||
{ | |||
res = array_ops.identity(op_return_value); | |||
}); | |||
} | |||
else if(x is not TensorArray) | |||
{ | |||
Debug.Assert(x is Tensor); | |||
res = ops.convert_to_tensor_or_composite(x as Tensor); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException($"The `TensorArray` is not supported here currently."); | |||
} | |||
if (add_control_dependencies) | |||
{ | |||
// TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. | |||
} | |||
return res; | |||
} | |||
if (autograph) | |||
{ | |||
throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); | |||
} | |||
var func_outputs = func(func_args); | |||
func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); | |||
func_outputs = func_outputs.Select(x => convert(x)).ToArray(); | |||
// TODO(Rinne): `check_func_mutation`. | |||
current_scope.use_resource = default_use_resource; | |||
var graph_variables = func_graph._watched_variables.ToList(); | |||
HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>(); | |||
List<Tensor> inputs = new(); | |||
foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) | |||
{ | |||
if(arg is BaseResourceVariable variable) | |||
{ | |||
var resource_placeholder = func_graph.pop_capture(variable.Handle); | |||
if(resource_placeholder is null) | |||
{ | |||
continue; | |||
} | |||
Debug.Assert(variable is IVariableV1); | |||
arg_variables.Add(variable as IVariableV1); | |||
inputs.Add(resource_placeholder); | |||
} | |||
else if(arg is Tensor tensor) | |||
{ | |||
inputs.Add(tensor); | |||
} | |||
} | |||
var variables = graph_variables.Select(v => | |||
{ | |||
if (v.TryGetTarget(out var target)) | |||
{ | |||
return target; | |||
} | |||
else | |||
{ | |||
return null; | |||
} | |||
}).Where(v => v is not null && !arg_variables.Contains(v)); | |||
func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); | |||
func_graph._structured_outputs = func_outputs; | |||
func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) | |||
.Select(x => func_graph.capture(x))); | |||
func_graph.Variables = variables; | |||
func_graph.Exit(); | |||
if (add_control_dependencies) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
return func_graph; | |||
} | |||
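// A minimal usage sketch of `func_graph_from_func` (illustrative only; the traced | |||
// delegate and the graph name below are hypothetical, not part of this change): | |||
// | |||
//   var fg = FuncGraph.func_graph_from_func( | |||
//       "traced_add", | |||
//       inputs => new object[] { tf.add((Tensor)inputs[0], (Tensor)inputs[1]) }, | |||
//       args: new object[] { tf.constant(1.0f), tf.constant(2.0f) }, | |||
//       kwargs: new Dictionary<string, object>()); | |||
// | |||
// The delegate is traced into `fg`: its placeholders become `fg.Inputs` and the | |||
// converted results become `fg.Outputs`, ready to be wrapped as a function definition. | |||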
private static object[] _get_defun_inputs_from_args(object[] args, string[] names) | |||
{ | |||
return _get_defun_inputs(args, names, args) as object[]; | |||
} | |||
private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs) | |||
{ | |||
// TODO(Rinne): implement it. | |||
Debug.Assert(kwargs is null || kwargs.Count == 0); | |||
return kwargs; | |||
//string[] names; | |||
//object[] args; | |||
//if(kwargs is not null && kwargs.Count > 0) | |||
//{ | |||
// var sorted_kwargs = kwargs.OrderBy(x => x.Key); | |||
// names = sorted_kwargs.Select(x => x.Key).ToArray(); | |||
// args = sorted_kwargs.Select(x => x.Value).ToArray(); | |||
//} | |||
//else | |||
//{ | |||
// names = new string[0]; | |||
// args = new object[0]; | |||
//} | |||
//return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>; | |||
} | |||
private static object _get_defun_inputs(object[] args, string[] names, object structured_args) | |||
{ | |||
List<object> function_inputs = new(); | |||
if(names is null) | |||
{ | |||
names = new string[args.Length]; | |||
} | |||
foreach(var (arg_value, name) in zip(args, names)) | |||
{ | |||
foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) | |||
{ | |||
function_inputs.Add(_get_defun_input(val, name)); | |||
} | |||
} | |||
return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true); | |||
} | |||
private static object _get_defun_input(object arg, string name) | |||
{ | |||
var func_graph = ops.get_default_graph() as FuncGraph; | |||
Debug.Assert(func_graph is not null); | |||
if (arg is Tensor tensor) | |||
{ | |||
Tensor placeholder; | |||
try | |||
{ | |||
placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||
} | |||
catch (ValueError) | |||
{ | |||
// TODO(Rinne): Add warning here. | |||
placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||
} | |||
handle_data_util.copy_handle_data(tensor, placeholder); | |||
if (name is not null) | |||
{ | |||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
{ | |||
S = tf.compat.as_bytes(name) | |||
}); | |||
} | |||
return placeholder; | |||
} | |||
else if (arg is TensorSpec spec) | |||
{ | |||
string requested_name; | |||
if (!string.IsNullOrEmpty(spec.name)) | |||
{ | |||
requested_name = spec.name; | |||
} | |||
else | |||
{ | |||
requested_name = name; | |||
} | |||
Tensor placeholder; | |||
try | |||
{ | |||
placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||
} | |||
catch (ValueError) | |||
{ | |||
// TODO(Rinne): Add warning here. | |||
placeholder = tf.placeholder(spec.dtype, spec.shape); | |||
} | |||
if (name is not null) | |||
{ | |||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
{ | |||
S = tf.compat.as_bytes(requested_name) | |||
}); | |||
} | |||
return placeholder; | |||
} | |||
else if (arg is BaseResourceVariable variable) | |||
{ | |||
var placeholder = func_graph.capture(variable.Handle, name); | |||
placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
{ | |||
S = tf.compat.as_bytes(name) | |||
}); | |||
return arg; | |||
} | |||
// TODO(Rinne): deal with `VariableSpec`. | |||
else | |||
{ | |||
return arg; | |||
} | |||
} | |||
} |
@@ -1,4 +1,6 @@ | |||
namespace Tensorflow | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
@@ -6,5 +8,10 @@ | |||
{ | |||
} | |||
internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map) | |||
{ | |||
return new GraphOverrideGradientContext(this, gradient_function_map); | |||
} | |||
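// Usage sketch for the override context (the gradient delegate below is hypothetical): | |||
// | |||
//   tf_with(graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>> | |||
//   { | |||
//       ["PartitionedCall"] = (op, grads) => custom_grad(op, grads) | |||
//   }), _ => | |||
//   { | |||
//       // ops created inside this scope pick up the overridden gradient | |||
//       // function through `_gradient_function_map` in `_create_op_helper`. | |||
//   }); | |||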
} | |||
} |
@@ -118,7 +118,7 @@ namespace Tensorflow | |||
/// <param name="compute_device">(Optional.) If True, device functions will be executed | |||
/// to compute the device property of the Operation.</param> | |||
/// <returns>An `Operation` object.</returns> | |||
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | |||
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) | |||
{ | |||
var ret = new Operation(c_op, this); | |||
_add_op(ret); | |||
@@ -19,6 +19,9 @@ using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Collections.Specialized; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Common.Extensions; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
@@ -86,6 +89,13 @@ namespace Tensorflow | |||
private int _next_id_counter; | |||
private List<Operation> _unfetchable_ops = new List<Operation>(); | |||
private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||
internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new(); | |||
private VersionDef _graph_def_versions = new VersionDef() | |||
{ | |||
Producer = versions.GRAPH_DEF_VERSION, | |||
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
}; | |||
public string _name_stack = ""; | |||
protected string _graph_key; | |||
@@ -121,6 +131,8 @@ namespace Tensorflow | |||
protected Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
public SafeGraphHandle c_graph => _handle; | |||
public Graph() | |||
{ | |||
@@ -147,6 +159,44 @@ namespace Tensorflow | |||
return ops.set_default_graph(this); | |||
} | |||
public bool IsFunction(string name) | |||
{ | |||
return _functions.ContainsKey(tf.compat.as_str(name)); | |||
} | |||
internal void AddFunction(EagerDefinedFunction function) | |||
{ | |||
_check_not_finalized(); | |||
var name = function.Name; | |||
if(function._grad_func_name is not null && function.csharp_grad_func is not null) | |||
{ | |||
throw new ValueError($"Gradient defined twice for function {name}"); | |||
} | |||
var c_graph = this.c_graph; | |||
var func = function._c_func.Get(); | |||
Status status = new(); | |||
if (function._grad_func is not null) | |||
{ | |||
var gradient = function._grad_func._c_func.Get(); | |||
c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); | |||
status.Check(true); | |||
} | |||
else | |||
{ | |||
c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); | |||
status.Check(true); | |||
} | |||
_functions[tf.compat.as_str(name)] = function; | |||
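// Presumably mirrors the Python implementation: GraphDef consumers older than | |||
// version 12 cannot execute graphs that contain functions, so the minimum is bumped. | |||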
if(_graph_def_versions.MinConsumer < 12) | |||
{ | |||
_graph_def_versions.MinConsumer = 12; | |||
} | |||
} | |||
private Tensor _as_graph_element(object obj) | |||
{ | |||
if (obj is RefVariable var) | |||
@@ -308,6 +358,9 @@ namespace Tensorflow | |||
private void _create_op_helper(Operation op, bool compute_device = true) | |||
{ | |||
// high priority | |||
// TODO(Rinne): complete the implementation. | |||
op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); | |||
_record_op_seen_by_control_dependencies(op); | |||
} | |||
@@ -524,6 +577,11 @@ namespace Tensorflow | |||
ops.pop_graph(); | |||
} | |||
internal EagerDefinedFunction _get_function(string name) | |||
{ | |||
return _functions.GetOrDefault(name, null); | |||
} | |||
string debugString = string.Empty; | |||
public override string ToString() | |||
{ | |||
@@ -0,0 +1,37 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
namespace Tensorflow.Graphs | |||
{ | |||
internal class GraphOverrideGradientContext: ITensorFlowObject | |||
{ | |||
Graph _graph; | |||
Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map; | |||
public GraphOverrideGradientContext(Graph graph, | |||
Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map) | |||
{ | |||
_graph = graph; | |||
_new_gradient_function_map = new_gradient_function_map; | |||
} | |||
[DebuggerStepThrough] | |||
public void __enter__() | |||
{ | |||
Debug.Assert(_graph._gradient_function_map.Count == 0); | |||
_graph._gradient_function_map = _new_gradient_function_map; | |||
} | |||
[DebuggerStepThrough] | |||
public void __exit__() | |||
{ | |||
_graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>(); | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||
} | |||
public SafeImportGraphDefOptionsHandle Options => _handle; | |||
public void AddReturnOutput(string name, int index) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
@@ -185,6 +185,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||
/// <summary> | |||
/// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
@@ -246,7 +249,7 @@ namespace Tensorflow | |||
/// <param name="ops">TF_ImportGraphDefOptions*</param> | |||
/// <param name="uniquify_prefix">unsigned char</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||
/// <summary> | |||
/// Fetches the return operations requested via | |||
@@ -308,7 +311,7 @@ namespace Tensorflow | |||
/// <param name="types">const TF_DataType*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | |||
SafeStatusHandle status); | |||
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
/// This class has nothing but the attributes different from `LayerArgs`. | |||
/// It's used to serialize the model to `tf` format. | |||
/// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, | |||
/// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`. | |||
/// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. | |||
/// </summary> | |||
public class AutoSerializeLayerArgs: LayerArgs | |||
{ | |||
@@ -7,6 +7,11 @@ using System.Text; | |||
namespace Tensorflow.Keras.Common | |||
{ | |||
class ShapeInfoFromPython | |||
{ | |||
public string class_name { get; set; } | |||
public long?[] items { get; set; } | |||
} | |||
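// Expected JSON form emitted/consumed by this converter (mirroring the Keras | |||
// Python side), e.g. for a shape with an unknown leading dimension: | |||
// | |||
//   { "class_name": "__tuple__", "items": [ null, 28, 28, 1 ] } | |||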
public class CustomizedShapeJsonConverter: JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common | |||
dims[i] = shape.dims[i]; | |||
} | |||
} | |||
var token = JToken.FromObject(dims); | |||
var token = JToken.FromObject(new ShapeInfoFromPython() | |||
{ | |||
class_name = "__tuple__", | |||
items = dims | |||
}); | |||
token.WriteTo(writer); | |||
} | |||
} | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
long?[] dims; | |||
try | |||
{ | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
catch (JsonSerializationException ex) | |||
{ | |||
if (reader.Value.Equals("class_name")) | |||
{ | |||
reader.Read(); | |||
reader.Read(); | |||
reader.Read(); | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
else | |||
{ | |||
throw ex; | |||
} | |||
} | |||
if (dims is null) | |||
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | |||
if (shape_info_from_python is null) | |||
{ | |||
return null; | |||
} | |||
long?[] dims = shape_info_from_python.items; | |||
long[] convertedDims = new long[dims.Length]; | |||
for(int i = 0; i < dims.Length; i++) | |||
{ | |||
@@ -4,10 +4,10 @@ public interface IOptimizer | |||
{ | |||
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | |||
Tensor[] clip_gradients(Tensor[] grads); | |||
void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||
string name = null, | |||
bool experimental_aggregate_gradients = true); | |||
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
string name = null, | |||
bool experimental_aggregate_gradients = true); | |||
} |
@@ -20,6 +20,9 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
using Google.Protobuf; | |||
using Google.Protobuf.WellKnownTypes; | |||
using System.Diagnostics; | |||
namespace Tensorflow | |||
{ | |||
@@ -47,6 +50,8 @@ namespace Tensorflow | |||
private readonly Graph _graph; | |||
internal Func<Operation, object[], Tensor[]> _gradient_function; | |||
public string type => OpType; | |||
public Graph graph => _graph; | |||
@@ -61,7 +66,7 @@ namespace Tensorflow | |||
public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
// OperationDescription _opDesc; | |||
//private OperationDescription _op_desc; | |||
public NodeDef node_def => GetNodeDef(); | |||
@@ -216,21 +221,19 @@ namespace Tensorflow | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
string oneof_value = x.ValueCase.ToString(); | |||
if (string.IsNullOrEmpty(oneof_value)) | |||
return null; | |||
var oneof_value = x.ValueCase; | |||
if (oneof_value == AttrValue.ValueOneofCase.None) | |||
return new object[0]; | |||
switch (oneof_value.ToLower()) | |||
if(oneof_value == AttrValue.ValueOneofCase.List) | |||
{ | |||
case "list": | |||
throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
case "type": | |||
return x.Type; | |||
case "s": | |||
return x.S.ToStringUtf8(); | |||
default: | |||
return x.GetType().GetProperty(oneof_value).GetValue(x); | |||
throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
} | |||
if(oneof_value == AttrValue.ValueOneofCase.Type) | |||
{ | |||
return dtypes.as_tf_dtype(x.Type); | |||
} | |||
return ProtoUtils.GetSingleAttrValue(x, oneof_value); | |||
} | |||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||
@@ -238,6 +241,19 @@ namespace Tensorflow | |||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||
} | |||
[Obsolete("The implementation is not complete.")] | |||
internal void _set_device_from_string(string device_str) | |||
{ | |||
// TODO(Rinne): complete it with new C API `SetRequestedDevice`. | |||
//c_api.TF_SetDevice(_handle, device_str); | |||
} | |||
[Obsolete("The implementation is not complete.")] | |||
internal void _set_device(string device) | |||
{ | |||
_set_device_from_string(device); | |||
} | |||
private NodeDef GetNodeDef() | |||
{ | |||
var buffer = new Buffer(); | |||
@@ -296,5 +312,60 @@ namespace Tensorflow | |||
} | |||
public NDArray numpy() => throw new NotImplementedException(""); | |||
internal void _add_outputs(TF_DataType[] types, Shape[] shapes) | |||
{ | |||
Debug.Assert(types.Length == shapes.Length); | |||
int orig_num_outputs = this.outputs.Length; | |||
var new_outputs = new List<Tensor>(_outputs); | |||
// Since `_outputs` is an array, adding new outputs requires allocating a | |||
// new array, which has some performance cost. The type of `outputs` may be | |||
// worth reconsidering in the future. | |||
for(int i = 0; i < types.Length; i++) | |||
{ | |||
var t = new Tensor(this, orig_num_outputs + i, types[i]); | |||
t.shape = shapes[i]; | |||
new_outputs.Add(t); | |||
} | |||
_outputs = new_outputs.ToArray(); | |||
} | |||
internal void _set_func_attr(string attr_name, string func_name) | |||
{ | |||
var func = new NameAttrList() { Name = func_name }; | |||
_set_attr(attr_name, new AttrValue() { Func = func }); | |||
} | |||
internal void _set_type_list_attr(string attr_name, DataType[] types) | |||
{ | |||
if(types is null || types.Length == 0) | |||
{ | |||
return; | |||
} | |||
var type_list = new AttrValue.Types.ListValue(); | |||
type_list.Type.AddRange(types); | |||
_set_attr(attr_name, new AttrValue() { List = type_list }); | |||
} | |||
internal void _set_attr(string attr_name, AttrValue attr_value) | |||
{ | |||
var buffer = new Buffer(attr_value.ToByteArray()); | |||
try | |||
{ | |||
_set_attr_with_buf(attr_name, buffer); | |||
} | |||
finally | |||
{ | |||
buffer.Release(); | |||
} | |||
} | |||
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) | |||
{ | |||
Status status = new(); | |||
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status); | |||
status.Check(true); | |||
} | |||
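// Usage sketch for the new attr setters (illustrative): | |||
// | |||
//   // tag an op with a user-specified name, as done for function placeholders | |||
//   op._set_attr("_user_specified_name", new AttrValue { S = tf.compat.as_bytes("x") }); | |||
// | |||
// `_set_attr` serializes the AttrValue into a Buffer and forwards it to the | |||
// TFC_SetAttr C API, so attributes can be set on ops already added to the graph. | |||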
} | |||
} |
@@ -14,10 +14,14 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using Google.Protobuf.WellKnownTypes; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -25,6 +29,74 @@ namespace Tensorflow | |||
{ | |||
public class functional_ops | |||
{ | |||
public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, | |||
bool executing_eagerly, string config, string executor_type) | |||
{ | |||
if (tout is null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
if (config is null) | |||
{ | |||
config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
} | |||
if (executor_type is null) | |||
{ | |||
executor_type = ""; | |||
} | |||
if (executing_eagerly) | |||
{ | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException(); | |||
} | |||
var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); | |||
AttrValue tin_attr = new() | |||
{ | |||
List = new AttrValue.Types.ListValue() | |||
}; | |||
tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); | |||
AttrValue tout_attr = new() | |||
{ | |||
List = new AttrValue.Types.ListValue() | |||
}; | |||
tout_attr.List.Type.AddRange(tout); | |||
AttrValue func_attr = new() | |||
{ | |||
Func = new NameAttrList() | |||
}; | |||
func_attr.Func.Name = f.Name; | |||
AttrValue executor_type_attr = new AttrValue() | |||
{ | |||
S = tf.compat.as_bytes(executor_type) | |||
}; | |||
AttrValue config_proto = new AttrValue() | |||
{ | |||
S = ByteString.CopyFromUtf8(config) | |||
}; | |||
var graph = ops.get_default_graph(); | |||
f.AddToGraph(graph); | |||
// TODO(Rinne): complete it with `f.stateful` | |||
var op_name = "PartitionedCall"; | |||
string xla_compile_attr = "_XlaMustCompile"; | |||
Dictionary<string, AttrValue> op_attrs = new(); | |||
op_attrs["Tin"] = tin_attr; | |||
op_attrs["Tout"] = tout_attr; | |||
op_attrs["f"] = func_attr; | |||
op_attrs["config_proto"] = config_proto; | |||
op_attrs["executor_type"] = executor_type_attr; | |||
// TODO(Rinne): deal with `f.definition`. | |||
var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), | |||
name: op_name, attrs: op_attrs); | |||
var outputs = op.outputs; | |||
// TODO(Rinne): deal with `f.graph`. | |||
return outputs; | |||
} | |||
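// Rough call-path sketch (graph mode only; the eager branch above still throws): | |||
// given an already-defined function `f`, a single `PartitionedCall` op is created | |||
// whose Tin/Tout/f attrs come from the arguments, e.g. (names illustrative): | |||
// | |||
//   var outs = functional_ops.partitioned_call(new Tensors(x), f, | |||
//       tout: new[] { DataType.DtFloat }, executing_eagerly: false, | |||
//       config: null, executor_type: null); | |||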
public static Tensor scan( | |||
Func<Tensor, Tensor, Tensor> fn, | |||
Tensor elems, | |||
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -210,7 +211,51 @@ namespace Tensorflow | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A `Tensor`. Has the same type as `value`.</returns> | |||
public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||
{ | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
try | |||
{ | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value)); | |||
return _result[0]; | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
try | |||
{ | |||
return fill_eager_fallback(dims, value as Tensor, name, ctx); | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
} | |||
Dictionary<string, object> attrs = new Dictionary<string, object>(); | |||
attrs["dims"] = dims; | |||
attrs["value"] = value; | |||
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs); | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return result.output; | |||
} | |||
public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() }; | |||
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return _result[0]; | |||
} | |||
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||
/// <summary> | |||
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast. | |||
@@ -0,0 +1,128 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Xml.Linq; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class gen_functional_ops | |||
{ | |||
public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
string config = "", string config_proto = "", string executor_type = "", string name = null) | |||
{ | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
try | |||
{ | |||
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||
args, tout, f, config, config_proto, executor_type)); | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
} | |||
if (config is null) | |||
{ | |||
config = ""; | |||
} | |||
if (config_proto is null) | |||
{ | |||
config_proto = ""; | |||
} | |||
if (executor_type is null) | |||
{ | |||
executor_type = ""; | |||
} | |||
Dictionary<string, object> kwargs = new(); | |||
kwargs["args"] = args; | |||
kwargs["Tout"] = tout; | |||
kwargs["f"] = f; | |||
kwargs["config"] = config; | |||
kwargs["config_proto"] = config_proto; | |||
kwargs["executor_type"] = executor_type; | |||
var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||
name, kwargs); | |||
var result = output.outputs; | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return result; | |||
} | |||
public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
string config, string config_proto, string executor_type, string name, Context ctx) | |||
{ | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException(); | |||
if(config is null) | |||
{ | |||
config = ""; | |||
} | |||
if(config_proto is null) | |||
{ | |||
config_proto = ""; | |||
} | |||
if(executor_type is null) | |||
{ | |||
executor_type = ""; | |||
} | |||
object[] attrs = new object[] | |||
{ | |||
}; | |||
} | |||
public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) | |||
{ | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
try | |||
{ | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"SymbolicGradient", name, input, Tout, f)); | |||
return _result; | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
try | |||
{ | |||
return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
} | |||
var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); | |||
var result = op.outputs; | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return result; | |||
} | |||
public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; | |||
var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return result; | |||
} | |||
} | |||
} |
@@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") | |||
{ | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
try | |||
{ | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); | |||
return _result[0]; | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
try | |||
{ | |||
return ensure_shape_eager_fallback(input, shape, name, ctx); | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
} | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
dict["shape"] = shape; | |||
var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return op.output; | |||
} | |||
public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; | |||
var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, | |||
attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return _result[0]; | |||
} | |||
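// Example (EnsureShape semantics): the op returns `input` unchanged but fails at | |||
// runtime if the actual shape is incompatible with the requested one, e.g. for | |||
// an arbitrary tensor `x`: | |||
// | |||
//   var y = ensure_shape(x, new Shape(-1, 28, 28, 1));  // y passes x through when the check succeeds | |||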
/// <summary> | |||
/// Creates or finds a child frame, and makes <c>data</c> available to the child frame. | |||
/// </summary> | |||
@@ -0,0 +1,60 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
namespace Tensorflow.Operations | |||
{ | |||
public static class handle_data_util | |||
{ | |||
public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||
{ | |||
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||
{ | |||
HandleData handle_data; | |||
if(source_t is EagerTensor) | |||
{ | |||
handle_data = source_t.HandleData; | |||
} | |||
else | |||
{ | |||
handle_data = ops.get_resource_handle_data(source_t); | |||
} | |||
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null | |||
&& handle_data.ShapeAndType.Count > 0) | |||
{ | |||
set_handle_data(target_t, handle_data); | |||
} | |||
} | |||
} | |||
public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||
{ | |||
HandleData handle_data = new(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
{ | |||
Shape = shape.as_proto(), | |||
Dtype = dtype.as_datatype_enum() | |||
}); | |||
return handle_data; | |||
} | |||
public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||
{ | |||
if(target_t is EagerTensor) | |||
{ | |||
target_t.HandleData = handle_data; | |||
return; | |||
} | |||
Status status = new(); | |||
var proto = handle_data.ToByteArray(); | |||
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status); | |||
status.Check(true); | |||
} | |||
public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); | |||
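// Usage sketch: propagate handle metadata onto a freshly created graph tensor | |||
// (`placeholder` below is hypothetical; `copy_handle_data` above does this for | |||
// function placeholders): | |||
// | |||
//   var hd = handle_data_util.create_handle_data(new Shape(2, 2), TF_DataType.TF_FLOAT); | |||
//   handle_data_util.set_handle_data(placeholder, hd); | |||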
} | |||
} |
@@ -21,6 +21,11 @@ using Tensorflow.Train; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Variables; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
using System.Buffers; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow | |||
{ | |||
@@ -31,18 +36,14 @@ namespace Tensorflow | |||
{ | |||
public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | |||
{ | |||
// TODO(Rinne): deal with `_handle_graph`. | |||
var value_tensor = ops.convert_to_tensor(value); | |||
return gen_resource_variable_ops.assign_variable_op(handle, | |||
value_tensor, | |||
name: name); | |||
} | |||
public static bool is_resource_variable(IVariableV1 var) | |||
{ | |||
return var is ResourceVariable; | |||
} | |||
public static bool is_resource_variable(Trackable var) | |||
public static bool is_resource_variable(object var) | |||
{ | |||
return var is BaseResourceVariable; | |||
} | |||
@@ -78,6 +79,18 @@ namespace Tensorflow | |||
string shared_name, string name, bool graph_mode, Tensor initial_value = null) | |||
{ | |||
var container = ops.get_default_graph().Container; | |||
if(container is null) | |||
{ | |||
container = ""; | |||
} | |||
if (!graph_mode) | |||
{ | |||
if(shared_name is not null) | |||
{ | |||
throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||
} | |||
shared_name = tf.Context.anonymous_name(); | |||
} | |||
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | |||
dtype: dtype, | |||
shared_name: shared_name, | |||
@@ -95,26 +108,20 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
// We do not want two distinct ResourceVariable objects for the same | |||
// underlying resource in the runtime. | |||
// When in eager mode, explicitly ensure so here. When in graph mode, it's | |||
// ensured by always generating different variable names. | |||
var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
// We create an assert Op instead of checking right away in order to be | |||
// compatible with ASYNC execution mode. Further, since not all devices | |||
// support string tensors, we encode the assertion string in the Op name | |||
/*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||
new[] { exists }, | |||
name: "EagerVariableNameReuse");*/ | |||
var handle_data = new HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType | |||
var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||
if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||
{ | |||
Dtype = dtype.as_datatype_enum(), | |||
Shape = shape.as_proto() | |||
}); | |||
var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||
if (extra_handle_data is not null && extra_handle_data.IsSet) | |||
{ | |||
if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
{ | |||
throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||
$"but saw: '{handle_data}'"); | |||
} | |||
handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||
} | |||
} | |||
_set_handle_shapes_and_types(handle, handle_data, graph_mode); | |||
return handle; | |||
} | |||
@@ -126,7 +133,7 @@ namespace Tensorflow | |||
/// <param name="handle"></param> | |||
/// <param name="handle_data"></param> | |||
/// <param name="graph_mode"></param> | |||
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
{ | |||
if (!graph_mode) | |||
return; | |||
@@ -144,6 +151,47 @@ namespace Tensorflow | |||
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||
} | |||
//tensor.HandleData = handle_data; | |||
//if (!graph_mode) | |||
// return; | |||
//var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||
//var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||
//{ | |||
// if (!s.UnknownRank) | |||
// { | |||
// return s.Dim.Select(d => (int)d.Size).ToArray(); | |||
// } | |||
// else | |||
// { | |||
// return Memory<int>.Empty; | |||
// } | |||
//}).ToArray(); | |||
//List<MemoryHandle> handles = new(); | |||
//IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||
//foreach(var (i, m) in enumerate(converted_shapes)) | |||
//{ | |||
// if(m.IsEmpty) | |||
// { | |||
// shapes_with_ptr[i] = IntPtr.Zero; | |||
// } | |||
// else | |||
// { | |||
// var handle = m.Pin(); | |||
// handles.Add(handle); | |||
// shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||
// } | |||
//} | |||
//Status status = new(); | |||
//// TODO(Rinne): enable it. | |||
//c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||
// shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||
//handles = null; | |||
} | |||
/// <summary> | |||
@@ -162,24 +210,6 @@ namespace Tensorflow | |||
throw new NotImplementedException(""); | |||
} | |||
private static HandleData get_eager_safe_handle_data(Tensor handle) | |||
{ | |||
if (handle.Handle == null) | |||
{ | |||
var data = new HandleData(); | |||
data.ShapeAndType.Add(new HandleShapeAndType | |||
{ | |||
Shape = handle.shape.as_shape_proto(), | |||
Dtype = handle.dtype.as_datatype_enum() | |||
}); | |||
return data; | |||
} | |||
else | |||
{ | |||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
} | |||
} | |||
/// <summary> | |||
/// Copies an existing variable to a new graph, with no initializer. | |||
/// </summary> | |||
@@ -231,5 +261,60 @@ namespace Tensorflow | |||
} | |||
} | |||
} | |||
public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) | |||
{ | |||
if(dtype == dtypes.variant) | |||
{ | |||
var handle_data = get_eager_safe_handle_data(handle); | |||
if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) | |||
{ | |||
tensor.HandleData = new HandleData() | |||
{ | |||
IsSet = true | |||
}; | |||
tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); | |||
} | |||
} | |||
} | |||
public static HandleData get_eager_safe_handle_data(Tensor handle) | |||
{ | |||
if (handle.Handle == null) | |||
{ | |||
var data = new HandleData(); | |||
data.ShapeAndType.Add(new HandleShapeAndType | |||
{ | |||
Shape = handle.shape.as_shape_proto(), | |||
Dtype = handle.dtype.as_datatype_enum() | |||
}); | |||
return data; | |||
} | |||
else | |||
{ | |||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
} | |||
//if(handle is EagerTensor) | |||
//{ | |||
// return handle.HandleData; | |||
//} | |||
//else | |||
//{ | |||
// return handle_data_util.get_resource_handle_data(handle); | |||
//} | |||
} | |||
public static void variable_accessed(IVariableV1 variable) | |||
{ | |||
if (ops.get_default_graph() is FuncGraph func_graph) | |||
{ | |||
func_graph.watch_variable(variable); | |||
} | |||
if (variable.Trainable) | |||
{ | |||
foreach (var tape in tf.GetTapeSet()) | |||
tape.VariableAccessed(variable); | |||
} | |||
} | |||
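// Note: code that reads or assigns a resource variable should call | |||
// `variable_accessed` first so the enclosing FuncGraph watches the variable | |||
// and any active gradient tapes record the access. | |||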
} | |||
} |
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/framework/allocation_description.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -43,23 +43,31 @@ namespace Tensorflow { | |||
} | |||
#region Messages | |||
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> { | |||
public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<AllocationDescription> _parser = new pb::MessageParser<AllocationDescription>(() => new AllocationDescription()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<AllocationDescription> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AllocationDescription() { | |||
OnConstruction(); | |||
} | |||
@@ -67,6 +75,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AllocationDescription(AllocationDescription other) : this() { | |||
requestedBytes_ = other.requestedBytes_; | |||
allocatedBytes_ = other.allocatedBytes_; | |||
@@ -78,6 +87,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AllocationDescription Clone() { | |||
return new AllocationDescription(this); | |||
} | |||
@@ -89,6 +99,7 @@ namespace Tensorflow { | |||
/// Total number of bytes requested | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long RequestedBytes { | |||
get { return requestedBytes_; } | |||
set { | |||
@@ -103,6 +114,7 @@ namespace Tensorflow { | |||
/// Total number of bytes allocated if known | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long AllocatedBytes { | |||
get { return allocatedBytes_; } | |||
set { | |||
@@ -117,6 +129,7 @@ namespace Tensorflow { | |||
/// Name of the allocator used | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string AllocatorName { | |||
get { return allocatorName_; } | |||
set { | |||
@@ -131,6 +144,7 @@ namespace Tensorflow { | |||
/// Identifier of the allocated buffer if known | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long AllocationId { | |||
get { return allocationId_; } | |||
set { | |||
@@ -145,6 +159,7 @@ namespace Tensorflow { | |||
/// Set if this tensor only has one remaining reference | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool HasSingleReference { | |||
get { return hasSingleReference_; } | |||
set { | |||
@@ -159,6 +174,7 @@ namespace Tensorflow { | |||
/// Address of the allocation. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ulong Ptr { | |||
get { return ptr_; } | |||
set { | |||
@@ -167,11 +183,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as AllocationDescription); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(AllocationDescription other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -189,6 +207,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode(); | |||
@@ -204,12 +223,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (RequestedBytes != 0L) { | |||
output.WriteRawTag(8); | |||
output.WriteInt64(RequestedBytes); | |||
@@ -237,9 +261,45 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (RequestedBytes != 0L) { | |||
output.WriteRawTag(8); | |||
output.WriteInt64(RequestedBytes); | |||
} | |||
if (AllocatedBytes != 0L) { | |||
output.WriteRawTag(16); | |||
output.WriteInt64(AllocatedBytes); | |||
} | |||
if (AllocatorName.Length != 0) { | |||
output.WriteRawTag(26); | |||
output.WriteString(AllocatorName); | |||
} | |||
if (AllocationId != 0L) { | |||
output.WriteRawTag(32); | |||
output.WriteInt64(AllocationId); | |||
} | |||
if (HasSingleReference != false) { | |||
output.WriteRawTag(40); | |||
output.WriteBool(HasSingleReference); | |||
} | |||
if (Ptr != 0UL) { | |||
output.WriteRawTag(48); | |||
output.WriteUInt64(Ptr); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (RequestedBytes != 0L) { | |||
@@ -267,6 +327,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(AllocationDescription other) { | |||
if (other == null) { | |||
return; | |||
@@ -293,7 +354,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -326,7 +391,47 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 8: { | |||
RequestedBytes = input.ReadInt64(); | |||
break; | |||
} | |||
case 16: { | |||
AllocatedBytes = input.ReadInt64(); | |||
break; | |||
} | |||
case 26: { | |||
AllocatorName = input.ReadString(); | |||
break; | |||
} | |||
case 32: { | |||
AllocationId = input.ReadInt64(); | |||
break; | |||
} | |||
case 40: { | |||
HasSingleReference = input.ReadBool(); | |||
break; | |||
} | |||
case 48: { | |||
Ptr = input.ReadUInt64(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/framework/attr_value.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -63,23 +63,31 @@ namespace Tensorflow { | |||
/// Comment indicates the corresponding attr type. Only the field matching the | |||
/// attr type may be filled. | |||
/// </summary> | |||
public sealed partial class AttrValue : pb::IMessage<AttrValue> { | |||
public sealed partial class AttrValue : pb::IMessage<AttrValue> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<AttrValue> _parser = new pb::MessageParser<AttrValue>(() => new AttrValue()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<AttrValue> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AttrValue() { | |||
OnConstruction(); | |||
} | |||
@@ -87,6 +95,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AttrValue(AttrValue other) : this() { | |||
switch (other.ValueCase) { | |||
case ValueOneofCase.S: | |||
@@ -125,6 +134,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public AttrValue Clone() { | |||
return new AttrValue(this); | |||
} | |||
@@ -135,6 +145,7 @@ namespace Tensorflow { | |||
/// "string" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pb::ByteString S { | |||
get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } | |||
set { | |||
@@ -149,6 +160,7 @@ namespace Tensorflow { | |||
/// "int" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long I { | |||
get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } | |||
set { | |||
@@ -163,6 +175,7 @@ namespace Tensorflow { | |||
/// "float" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public float F { | |||
get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } | |||
set { | |||
@@ -177,6 +190,7 @@ namespace Tensorflow { | |||
/// "bool" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool B { | |||
get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } | |||
set { | |||
@@ -191,6 +205,7 @@ namespace Tensorflow { | |||
/// "type" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.DataType Type { | |||
get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; } | |||
set { | |||
@@ -205,6 +220,7 @@ namespace Tensorflow { | |||
/// "shape" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.TensorShapeProto Shape { | |||
get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } | |||
set { | |||
@@ -219,6 +235,7 @@ namespace Tensorflow { | |||
/// "tensor" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.TensorProto Tensor { | |||
get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } | |||
set { | |||
@@ -233,6 +250,7 @@ namespace Tensorflow { | |||
/// any "list(...)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.AttrValue.Types.ListValue List { | |||
get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } | |||
set { | |||
@@ -250,6 +268,7 @@ namespace Tensorflow { | |||
/// that attr in the instantiation. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.NameAttrList Func { | |||
get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } | |||
set { | |||
@@ -270,6 +289,7 @@ namespace Tensorflow { | |||
/// given the value "bar". | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Placeholder { | |||
get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } | |||
set { | |||
@@ -295,22 +315,26 @@ namespace Tensorflow { | |||
} | |||
private ValueOneofCase valueCase_ = ValueOneofCase.None; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ValueOneofCase ValueCase { | |||
get { return valueCase_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void ClearValue() { | |||
valueCase_ = ValueOneofCase.None; | |||
value_ = null; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as AttrValue); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(AttrValue other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -333,6 +357,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); | |||
@@ -353,12 +378,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (valueCase_ == ValueOneofCase.List) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(List); | |||
@@ -402,9 +432,61 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (valueCase_ == ValueOneofCase.List) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(List); | |||
} | |||
if (valueCase_ == ValueOneofCase.S) { | |||
output.WriteRawTag(18); | |||
output.WriteBytes(S); | |||
} | |||
if (valueCase_ == ValueOneofCase.I) { | |||
output.WriteRawTag(24); | |||
output.WriteInt64(I); | |||
} | |||
if (valueCase_ == ValueOneofCase.F) { | |||
output.WriteRawTag(37); | |||
output.WriteFloat(F); | |||
} | |||
if (valueCase_ == ValueOneofCase.B) { | |||
output.WriteRawTag(40); | |||
output.WriteBool(B); | |||
} | |||
if (valueCase_ == ValueOneofCase.Type) { | |||
output.WriteRawTag(48); | |||
output.WriteEnum((int) Type); | |||
} | |||
if (valueCase_ == ValueOneofCase.Shape) { | |||
output.WriteRawTag(58); | |||
output.WriteMessage(Shape); | |||
} | |||
if (valueCase_ == ValueOneofCase.Tensor) { | |||
output.WriteRawTag(66); | |||
output.WriteMessage(Tensor); | |||
} | |||
if (valueCase_ == ValueOneofCase.Placeholder) { | |||
output.WriteRawTag(74); | |||
output.WriteString(Placeholder); | |||
} | |||
if (valueCase_ == ValueOneofCase.Func) { | |||
output.WriteRawTag(82); | |||
output.WriteMessage(Func); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (valueCase_ == ValueOneofCase.S) { | |||
@@ -444,6 +526,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(AttrValue other) { | |||
if (other == null) { | |||
return; | |||
@@ -497,7 +580,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -567,32 +654,118 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue(); | |||
if (valueCase_ == ValueOneofCase.List) { | |||
subBuilder.MergeFrom(List); | |||
} | |||
input.ReadMessage(subBuilder); | |||
List = subBuilder; | |||
break; | |||
} | |||
case 18: { | |||
S = input.ReadBytes(); | |||
break; | |||
} | |||
case 24: { | |||
I = input.ReadInt64(); | |||
break; | |||
} | |||
case 37: { | |||
F = input.ReadFloat(); | |||
break; | |||
} | |||
case 40: { | |||
B = input.ReadBool(); | |||
break; | |||
} | |||
case 48: { | |||
value_ = input.ReadEnum(); | |||
valueCase_ = ValueOneofCase.Type; | |||
break; | |||
} | |||
case 58: { | |||
global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); | |||
if (valueCase_ == ValueOneofCase.Shape) { | |||
subBuilder.MergeFrom(Shape); | |||
} | |||
input.ReadMessage(subBuilder); | |||
Shape = subBuilder; | |||
break; | |||
} | |||
case 66: { | |||
global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto(); | |||
if (valueCase_ == ValueOneofCase.Tensor) { | |||
subBuilder.MergeFrom(Tensor); | |||
} | |||
input.ReadMessage(subBuilder); | |||
Tensor = subBuilder; | |||
break; | |||
} | |||
case 74: { | |||
Placeholder = input.ReadString(); | |||
break; | |||
} | |||
case 82: { | |||
global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList(); | |||
if (valueCase_ == ValueOneofCase.Func) { | |||
subBuilder.MergeFrom(Func); | |||
} | |||
input.ReadMessage(subBuilder); | |||
Func = subBuilder; | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
#region Nested types | |||
/// <summary>Container for nested types declared in the AttrValue message type.</summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static partial class Types { | |||
/// <summary> | |||
/// LINT.IfChange | |||
/// </summary> | |||
public sealed partial class ListValue : pb::IMessage<ListValue> { | |||
public sealed partial class ListValue : pb::IMessage<ListValue> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<ListValue> _parser = new pb::MessageParser<ListValue>(() => new ListValue()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<ListValue> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ListValue() { | |||
OnConstruction(); | |||
} | |||
@@ -600,6 +773,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ListValue(ListValue other) : this() { | |||
s_ = other.s_.Clone(); | |||
i_ = other.i_.Clone(); | |||
@@ -613,6 +787,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ListValue Clone() { | |||
return new ListValue(this); | |||
} | |||
@@ -626,6 +801,7 @@ namespace Tensorflow { | |||
/// "list(string)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<pb::ByteString> S { | |||
get { return s_; } | |||
} | |||
@@ -639,6 +815,7 @@ namespace Tensorflow { | |||
/// "list(int)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<long> I { | |||
get { return i_; } | |||
} | |||
@@ -652,6 +829,7 @@ namespace Tensorflow { | |||
/// "list(float)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<float> F { | |||
get { return f_; } | |||
} | |||
@@ -665,6 +843,7 @@ namespace Tensorflow { | |||
/// "list(bool)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<bool> B { | |||
get { return b_; } | |||
} | |||
@@ -678,6 +857,7 @@ namespace Tensorflow { | |||
/// "list(type)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.DataType> Type { | |||
get { return type_; } | |||
} | |||
@@ -691,6 +871,7 @@ namespace Tensorflow { | |||
/// "list(shape)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.TensorShapeProto> Shape { | |||
get { return shape_; } | |||
} | |||
@@ -704,6 +885,7 @@ namespace Tensorflow { | |||
/// "list(tensor)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.TensorProto> Tensor { | |||
get { return tensor_; } | |||
} | |||
@@ -717,16 +899,19 @@ namespace Tensorflow { | |||
/// "list(attr)" | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.NameAttrList> Func { | |||
get { return func_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as ListValue); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(ListValue other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -746,6 +931,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= s_.GetHashCode(); | |||
@@ -763,12 +949,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
s_.WriteTo(output, _repeated_s_codec); | |||
i_.WriteTo(output, _repeated_i_codec); | |||
f_.WriteTo(output, _repeated_f_codec); | |||
@@ -780,9 +971,29 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
s_.WriteTo(ref output, _repeated_s_codec); | |||
i_.WriteTo(ref output, _repeated_i_codec); | |||
f_.WriteTo(ref output, _repeated_f_codec); | |||
b_.WriteTo(ref output, _repeated_b_codec); | |||
type_.WriteTo(ref output, _repeated_type_codec); | |||
shape_.WriteTo(ref output, _repeated_shape_codec); | |||
tensor_.WriteTo(ref output, _repeated_tensor_codec); | |||
func_.WriteTo(ref output, _repeated_func_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += s_.CalculateSize(_repeated_s_codec); | |||
@@ -800,6 +1011,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(ListValue other) { | |||
if (other == null) { | |||
return; | |||
@@ -816,7 +1028,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -861,7 +1077,59 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 18: { | |||
s_.AddEntriesFrom(ref input, _repeated_s_codec); | |||
break; | |||
} | |||
case 26: | |||
case 24: { | |||
i_.AddEntriesFrom(ref input, _repeated_i_codec); | |||
break; | |||
} | |||
case 34: | |||
case 37: { | |||
f_.AddEntriesFrom(ref input, _repeated_f_codec); | |||
break; | |||
} | |||
case 42: | |||
case 40: { | |||
b_.AddEntriesFrom(ref input, _repeated_b_codec); | |||
break; | |||
} | |||
case 50: | |||
case 48: { | |||
type_.AddEntriesFrom(ref input, _repeated_type_codec); | |||
break; | |||
} | |||
case 58: { | |||
shape_.AddEntriesFrom(ref input, _repeated_shape_codec); | |||
break; | |||
} | |||
case 66: { | |||
tensor_.AddEntriesFrom(ref input, _repeated_tensor_codec); | |||
break; | |||
} | |||
case 74: { | |||
func_.AddEntriesFrom(ref input, _repeated_func_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
@@ -874,23 +1142,31 @@ namespace Tensorflow { | |||
/// A list of attr names and their values. The whole list is attached | |||
/// with a string name. E.g., MatMul[T=float]. | |||
/// </summary> | |||
public sealed partial class NameAttrList : pb::IMessage<NameAttrList> { | |||
public sealed partial class NameAttrList : pb::IMessage<NameAttrList> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<NameAttrList> _parser = new pb::MessageParser<NameAttrList>(() => new NameAttrList()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<NameAttrList> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public NameAttrList() { | |||
OnConstruction(); | |||
} | |||
@@ -898,6 +1174,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public NameAttrList(NameAttrList other) : this() { | |||
name_ = other.name_; | |||
attr_ = other.attr_.Clone(); | |||
@@ -905,6 +1182,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public NameAttrList Clone() { | |||
return new NameAttrList(this); | |||
} | |||
@@ -913,6 +1191,7 @@ namespace Tensorflow { | |||
public const int NameFieldNumber = 1; | |||
private string name_ = ""; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Name { | |||
get { return name_; } | |||
set { | |||
@@ -926,16 +1205,19 @@ namespace Tensorflow { | |||
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); | |||
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | |||
get { return attr_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as NameAttrList); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(NameAttrList other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -949,6 +1231,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
@@ -960,12 +1243,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
@@ -974,9 +1262,26 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
} | |||
attr_.WriteTo(ref output, _map_attr_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Name.Length != 0) { | |||
@@ -990,6 +1295,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(NameAttrList other) { | |||
if (other == null) { | |||
return; | |||
@@ -1002,7 +1308,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -1019,7 +1329,31 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
Name = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
attr_.AddEntriesFrom(ref input, _map_attr_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
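The hunks above extend `AttrValue`, its nested `ListValue`, and `NameAttrList` with the `IBufferMessage` fast path (`InternalWriteTo`/`InternalMergeFrom`) while leaving the public surface unchanged. A minimal usage sketch, relying only on members visible in this diff plus the standard `Google.Protobuf` helpers (`ToByteArray`, `Parser.ParseFrom`); the `DataType.DtFloat` enum name is an assumption about the generated type, not something shown here:

```csharp
using Google.Protobuf;
using Tensorflow;

// Set one member of the value oneof; assigning another member replaces it.
var attr = new AttrValue { I = 64 };              // ValueCase == ValueOneofCase.I
attr.List = new AttrValue.Types.ListValue();      // ValueCase switches to List
attr.List.Type.Add(DataType.DtFloat);             // repeated "list(type)" field (enum name assumed)

// Round-trip through the unchanged public serialization surface.
byte[] bytes = attr.ToByteArray();
AttrValue parsed = AttrValue.Parser.ParseFrom(bytes);
```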
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/python/training/checkpoint_state.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -43,23 +43,31 @@ namespace Tensorflow { | |||
/// <summary> | |||
/// Protocol buffer representing the checkpoint state. | |||
/// </summary> | |||
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> { | |||
public sealed partial class CheckpointState : pb::IMessage<CheckpointState> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CheckpointState() { | |||
OnConstruction(); | |||
} | |||
@@ -67,6 +75,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CheckpointState(CheckpointState other) : this() { | |||
modelCheckpointPath_ = other.modelCheckpointPath_; | |||
allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | |||
@@ -76,6 +85,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CheckpointState Clone() { | |||
return new CheckpointState(this); | |||
} | |||
@@ -87,6 +97,7 @@ namespace Tensorflow { | |||
/// Path to the most-recent model checkpoint. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string ModelCheckpointPath { | |||
get { return modelCheckpointPath_; } | |||
set { | |||
@@ -106,6 +117,7 @@ namespace Tensorflow { | |||
/// this list. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> AllModelCheckpointPaths { | |||
get { return allModelCheckpointPaths_; } | |||
} | |||
@@ -120,6 +132,7 @@ namespace Tensorflow { | |||
/// when each checkpoint was created. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | |||
get { return allModelCheckpointTimestamps_; } | |||
} | |||
@@ -132,6 +145,7 @@ namespace Tensorflow { | |||
/// checkpoint. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public double LastPreservedTimestamp { | |||
get { return lastPreservedTimestamp_; } | |||
set { | |||
@@ -140,11 +154,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CheckpointState); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CheckpointState other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -160,6 +176,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | |||
@@ -173,12 +190,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (ModelCheckpointPath.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ModelCheckpointPath); | |||
@@ -192,9 +214,31 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (ModelCheckpointPath.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ModelCheckpointPath); | |||
} | |||
allModelCheckpointPaths_.WriteTo(ref output, _repeated_allModelCheckpointPaths_codec); | |||
allModelCheckpointTimestamps_.WriteTo(ref output, _repeated_allModelCheckpointTimestamps_codec); | |||
if (LastPreservedTimestamp != 0D) { | |||
output.WriteRawTag(33); | |||
output.WriteDouble(LastPreservedTimestamp); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (ModelCheckpointPath.Length != 0) { | |||
@@ -212,6 +256,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CheckpointState other) { | |||
if (other == null) { | |||
return; | |||
@@ -228,7 +273,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -254,7 +303,40 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
ModelCheckpointPath = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
allModelCheckpointPaths_.AddEntriesFrom(ref input, _repeated_allModelCheckpointPaths_codec); | |||
break; | |||
} | |||
case 26: | |||
case 25: { | |||
allModelCheckpointTimestamps_.AddEntriesFrom(ref input, _repeated_allModelCheckpointTimestamps_codec); | |||
break; | |||
} | |||
case 33: { | |||
LastPreservedTimestamp = input.ReadDouble(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
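`CheckpointState` picks up the same dual write/merge paths. A sketch of populating it and serializing through the public `WriteTo` surface shown above; the checkpoint names and output file are purely illustrative:

```csharp
using Google.Protobuf;
using Tensorflow;

var state = new CheckpointState
{
    ModelCheckpointPath = "model.ckpt-100",       // illustrative checkpoint name
    LastPreservedTimestamp = 0D
};
state.AllModelCheckpointPaths.Add("model.ckpt-50");
state.AllModelCheckpointPaths.Add("model.ckpt-100");

// MessageExtensions.WriteTo(Stream) wraps the CodedOutputStream path above.
using (var file = System.IO.File.Create("checkpoint_state.pb"))   // illustrative file name
{
    state.WriteTo(file);
}
```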
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/protobuf/cluster.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -47,23 +47,31 @@ namespace Tensorflow { | |||
/// <summary> | |||
/// Defines a single job in a TensorFlow cluster. | |||
/// </summary> | |||
public sealed partial class JobDef : pb::IMessage<JobDef> { | |||
public sealed partial class JobDef : pb::IMessage<JobDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<JobDef> _parser = new pb::MessageParser<JobDef>(() => new JobDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<JobDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public JobDef() { | |||
OnConstruction(); | |||
} | |||
@@ -71,6 +79,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public JobDef(JobDef other) : this() { | |||
name_ = other.name_; | |||
tasks_ = other.tasks_.Clone(); | |||
@@ -78,6 +87,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public JobDef Clone() { | |||
return new JobDef(this); | |||
} | |||
@@ -89,6 +99,7 @@ namespace Tensorflow { | |||
/// The name of this job. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Name { | |||
get { return name_; } | |||
set { | |||
@@ -109,16 +120,19 @@ namespace Tensorflow { | |||
/// "/job:worker/task:7" will be assigned to "example.org:2222". | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::MapField<int, string> Tasks { | |||
get { return tasks_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as JobDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(JobDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -132,6 +146,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
@@ -143,12 +158,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
@@ -157,9 +177,26 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
} | |||
tasks_.WriteTo(ref output, _map_tasks_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Name.Length != 0) { | |||
@@ -173,6 +210,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(JobDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -185,7 +223,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -202,30 +244,62 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
Name = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
tasks_.AddEntriesFrom(ref input, _map_tasks_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
/// <summary> | |||
/// Defines a TensorFlow cluster as a set of jobs. | |||
/// </summary> | |||
public sealed partial class ClusterDef : pb::IMessage<ClusterDef> { | |||
public sealed partial class ClusterDef : pb::IMessage<ClusterDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<ClusterDef> _parser = new pb::MessageParser<ClusterDef>(() => new ClusterDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<ClusterDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ClusterDef() { | |||
OnConstruction(); | |||
} | |||
@@ -233,12 +307,14 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ClusterDef(ClusterDef other) : this() { | |||
job_ = other.job_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ClusterDef Clone() { | |||
return new ClusterDef(this); | |||
} | |||
@@ -252,16 +328,19 @@ namespace Tensorflow { | |||
/// The jobs that comprise the cluster. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.JobDef> Job { | |||
get { return job_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as ClusterDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(ClusterDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -274,6 +353,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= job_.GetHashCode(); | |||
@@ -284,19 +364,37 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
job_.WriteTo(output, _repeated_job_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
job_.WriteTo(ref output, _repeated_job_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += job_.CalculateSize(_repeated_job_codec); | |||
@@ -307,6 +405,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(ClusterDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -316,7 +415,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -329,7 +432,27 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
job_.AddEntriesFrom(ref input, _repeated_job_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
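`JobDef` and `ClusterDef` follow the same pattern; their only data members are the ones visible in this diff (`Name`, the `Tasks` map, and the repeated `Job` field). A minimal sketch of composing them (addresses are illustrative):

```csharp
using Tensorflow;

// A single "worker" job with two tasks, wrapped in a ClusterDef.
var worker = new JobDef { Name = "worker" };
worker.Tasks[0] = "localhost:2222";   // task index -> network address
worker.Tasks[1] = "localhost:2223";

var cluster = new ClusterDef();
cluster.Job.Add(worker);
```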
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/protobuf/control_flow.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -64,23 +64,31 @@ namespace Tensorflow { | |||
/// <summary> | |||
/// Protocol buffer representing the values in ControlFlowContext. | |||
/// </summary> | |||
public sealed partial class ValuesDef : pb::IMessage<ValuesDef> { | |||
public sealed partial class ValuesDef : pb::IMessage<ValuesDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<ValuesDef> _parser = new pb::MessageParser<ValuesDef>(() => new ValuesDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<ValuesDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ValuesDef() { | |||
OnConstruction(); | |||
} | |||
@@ -88,6 +96,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ValuesDef(ValuesDef other) : this() { | |||
values_ = other.values_.Clone(); | |||
externalValues_ = other.externalValues_.Clone(); | |||
@@ -95,6 +104,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ValuesDef Clone() { | |||
return new ValuesDef(this); | |||
} | |||
@@ -108,6 +118,7 @@ namespace Tensorflow { | |||
/// Value names that have been seen in this context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> Values { | |||
get { return values_; } | |||
} | |||
@@ -121,16 +132,19 @@ namespace Tensorflow { | |||
/// Value names referenced by but external to this context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::MapField<string, string> ExternalValues { | |||
get { return externalValues_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as ValuesDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(ValuesDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -144,6 +158,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= values_.GetHashCode(); | |||
@@ -155,20 +170,39 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
values_.WriteTo(output, _repeated_values_codec); | |||
externalValues_.WriteTo(output, _map_externalValues_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
values_.WriteTo(ref output, _repeated_values_codec); | |||
externalValues_.WriteTo(ref output, _map_externalValues_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += values_.CalculateSize(_repeated_values_codec); | |||
@@ -180,6 +214,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(ValuesDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -190,7 +225,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -207,7 +246,31 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
values_.AddEntriesFrom(ref input, _repeated_values_codec); | |||
break; | |||
} | |||
case 18: { | |||
externalValues_.AddEntriesFrom(ref input, _map_externalValues_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
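`ValuesDef` exposes the repeated `Values` field and the `ExternalValues` string map seen above. A short sketch with made-up tensor names:

```csharp
using Tensorflow;

var values = new ValuesDef();
values.Values.Add("while/Merge:0");                  // name seen inside the context (illustrative)
values.ExternalValues["while/Const:0"] = "Const:0";  // external name -> internal name (illustrative)
```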
@@ -215,23 +278,31 @@ namespace Tensorflow { | |||
/// Container for any kind of control flow context. Any other control flow | |||
/// contexts that are added below should also be added here. | |||
/// </summary> | |||
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> { | |||
public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<ControlFlowContextDef> _parser = new pb::MessageParser<ControlFlowContextDef>(() => new ControlFlowContextDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<ControlFlowContextDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ControlFlowContextDef() { | |||
OnConstruction(); | |||
} | |||
@@ -239,6 +310,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ControlFlowContextDef(ControlFlowContextDef other) : this() { | |||
switch (other.CtxtCase) { | |||
case CtxtOneofCase.CondCtxt: | |||
@@ -253,6 +325,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ControlFlowContextDef Clone() { | |||
return new ControlFlowContextDef(this); | |||
} | |||
@@ -260,6 +333,7 @@ namespace Tensorflow { | |||
/// <summary>Field number for the "cond_ctxt" field.</summary> | |||
public const int CondCtxtFieldNumber = 1; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.CondContextDef CondCtxt { | |||
get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } | |||
set { | |||
@@ -271,6 +345,7 @@ namespace Tensorflow { | |||
/// <summary>Field number for the "while_ctxt" field.</summary> | |||
public const int WhileCtxtFieldNumber = 2; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.WhileContextDef WhileCtxt { | |||
get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } | |||
set { | |||
@@ -288,22 +363,26 @@ namespace Tensorflow { | |||
} | |||
private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CtxtOneofCase CtxtCase { | |||
get { return ctxtCase_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void ClearCtxt() { | |||
ctxtCase_ = CtxtOneofCase.None; | |||
ctxt_ = null; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as ControlFlowContextDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(ControlFlowContextDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -318,6 +397,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); | |||
@@ -330,12 +410,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(CondCtxt); | |||
@@ -347,9 +432,29 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(CondCtxt); | |||
} | |||
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||
output.WriteRawTag(18); | |||
output.WriteMessage(WhileCtxt); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
@@ -365,6 +470,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(ControlFlowContextDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -388,7 +494,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -415,30 +525,72 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); | |||
if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
subBuilder.MergeFrom(CondCtxt); | |||
} | |||
input.ReadMessage(subBuilder); | |||
CondCtxt = subBuilder; | |||
break; | |||
} | |||
case 18: { | |||
global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); | |||
if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||
subBuilder.MergeFrom(WhileCtxt); | |||
} | |||
input.ReadMessage(subBuilder); | |||
WhileCtxt = subBuilder; | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
/// <summary> | |||
/// Protocol buffer representing a CondContext object. | |||
/// </summary> | |||
public sealed partial class CondContextDef : pb::IMessage<CondContextDef> { | |||
public sealed partial class CondContextDef : pb::IMessage<CondContextDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CondContextDef> _parser = new pb::MessageParser<CondContextDef>(() => new CondContextDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CondContextDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CondContextDef() { | |||
OnConstruction(); | |||
} | |||
@@ -446,6 +598,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CondContextDef(CondContextDef other) : this() { | |||
contextName_ = other.contextName_; | |||
predName_ = other.predName_; | |||
@@ -457,6 +610,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CondContextDef Clone() { | |||
return new CondContextDef(this); | |||
} | |||
@@ -468,6 +622,7 @@ namespace Tensorflow { | |||
/// Name of the context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string ContextName { | |||
get { return contextName_; } | |||
set { | |||
@@ -482,6 +637,7 @@ namespace Tensorflow { | |||
/// Name of the pred tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PredName { | |||
get { return predName_; } | |||
set { | |||
@@ -496,6 +652,7 @@ namespace Tensorflow { | |||
/// Name of the pivot tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PivotName { | |||
get { return pivotName_; } | |||
set { | |||
@@ -510,6 +667,7 @@ namespace Tensorflow { | |||
/// Branch prediction. 0 or 1. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int Branch { | |||
get { return branch_; } | |||
set { | |||
@@ -524,6 +682,7 @@ namespace Tensorflow { | |||
/// Values and external values in control flow context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.ValuesDef ValuesDef { | |||
get { return valuesDef_; } | |||
set { | |||
@@ -540,16 +699,19 @@ namespace Tensorflow { | |||
/// Contexts contained inside this context (e.g. nested conds). | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | |||
get { return nestedContexts_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CondContextDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CondContextDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -567,6 +729,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | |||
@@ -582,12 +745,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (ContextName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ContextName); | |||
@@ -612,9 +780,42 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (ContextName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ContextName); | |||
} | |||
if (PredName.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(PredName); | |||
} | |||
if (PivotName.Length != 0) { | |||
output.WriteRawTag(26); | |||
output.WriteString(PivotName); | |||
} | |||
if (Branch != 0) { | |||
output.WriteRawTag(32); | |||
output.WriteInt32(Branch); | |||
} | |||
if (valuesDef_ != null) { | |||
output.WriteRawTag(42); | |||
output.WriteMessage(ValuesDef); | |||
} | |||
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (ContextName.Length != 0) { | |||
@@ -640,6 +841,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CondContextDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -667,7 +869,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -703,30 +909,81 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
ContextName = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
PredName = input.ReadString(); | |||
break; | |||
} | |||
case 26: { | |||
PivotName = input.ReadString(); | |||
break; | |||
} | |||
case 32: { | |||
Branch = input.ReadInt32(); | |||
break; | |||
} | |||
case 42: { | |||
if (valuesDef_ == null) { | |||
ValuesDef = new global::Tensorflow.ValuesDef(); | |||
} | |||
input.ReadMessage(ValuesDef); | |||
break; | |||
} | |||
case 50: { | |||
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
/// <summary> | |||
/// Protocol buffer representing a WhileContext object. | |||
/// </summary> | |||
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> { | |||
public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<WhileContextDef> _parser = new pb::MessageParser<WhileContextDef>(() => new WhileContextDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<WhileContextDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public WhileContextDef() { | |||
OnConstruction(); | |||
} | |||
@@ -734,6 +991,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public WhileContextDef(WhileContextDef other) : this() { | |||
contextName_ = other.contextName_; | |||
parallelIterations_ = other.parallelIterations_; | |||
@@ -751,6 +1009,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public WhileContextDef Clone() { | |||
return new WhileContextDef(this); | |||
} | |||
@@ -762,6 +1021,7 @@ namespace Tensorflow { | |||
/// Name of the context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string ContextName { | |||
get { return contextName_; } | |||
set { | |||
@@ -776,6 +1036,7 @@ namespace Tensorflow { | |||
/// The number of iterations allowed to run in parallel. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int ParallelIterations { | |||
get { return parallelIterations_; } | |||
set { | |||
@@ -790,6 +1051,7 @@ namespace Tensorflow { | |||
/// Whether backprop is enabled for this while loop. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool BackProp { | |||
get { return backProp_; } | |||
set { | |||
@@ -804,6 +1066,7 @@ namespace Tensorflow { | |||
/// Whether GPU-CPU memory swap is enabled for this loop. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool SwapMemory { | |||
get { return swapMemory_; } | |||
set { | |||
@@ -818,6 +1081,7 @@ namespace Tensorflow { | |||
/// Name of the pivot tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PivotName { | |||
get { return pivotName_; } | |||
set { | |||
@@ -832,6 +1096,7 @@ namespace Tensorflow { | |||
/// Name of the pivot_for_pred tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PivotForPredName { | |||
get { return pivotForPredName_; } | |||
set { | |||
@@ -846,6 +1111,7 @@ namespace Tensorflow { | |||
/// Name of the pivot_for_body tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PivotForBodyName { | |||
get { return pivotForBodyName_; } | |||
set { | |||
@@ -862,6 +1128,7 @@ namespace Tensorflow { | |||
/// List of names for exit tensors. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> LoopExitNames { | |||
get { return loopExitNames_; } | |||
} | |||
@@ -875,6 +1142,7 @@ namespace Tensorflow { | |||
/// List of names for enter tensors. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> LoopEnterNames { | |||
get { return loopEnterNames_; } | |||
} | |||
@@ -886,6 +1154,7 @@ namespace Tensorflow { | |||
/// Values and external values in control flow context. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.ValuesDef ValuesDef { | |||
get { return valuesDef_; } | |||
set { | |||
@@ -900,6 +1169,7 @@ namespace Tensorflow { | |||
/// Optional name of the maximum_iterations tensor. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string MaximumIterationsName { | |||
get { return maximumIterationsName_; } | |||
set { | |||
@@ -916,16 +1186,19 @@ namespace Tensorflow { | |||
/// Contexts contained inside this context (e.g. nested whiles). | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | |||
get { return nestedContexts_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as WhileContextDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(WhileContextDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -949,6 +1222,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | |||
@@ -970,12 +1244,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (ContextName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ContextName); | |||
@@ -1018,9 +1297,60 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (ContextName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ContextName); | |||
} | |||
if (ParallelIterations != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(ParallelIterations); | |||
} | |||
if (BackProp != false) { | |||
output.WriteRawTag(24); | |||
output.WriteBool(BackProp); | |||
} | |||
if (SwapMemory != false) { | |||
output.WriteRawTag(32); | |||
output.WriteBool(SwapMemory); | |||
} | |||
if (PivotName.Length != 0) { | |||
output.WriteRawTag(42); | |||
output.WriteString(PivotName); | |||
} | |||
if (PivotForPredName.Length != 0) { | |||
output.WriteRawTag(50); | |||
output.WriteString(PivotForPredName); | |||
} | |||
if (PivotForBodyName.Length != 0) { | |||
output.WriteRawTag(58); | |||
output.WriteString(PivotForBodyName); | |||
} | |||
loopExitNames_.WriteTo(ref output, _repeated_loopExitNames_codec); | |||
if (valuesDef_ != null) { | |||
output.WriteRawTag(74); | |||
output.WriteMessage(ValuesDef); | |||
} | |||
loopEnterNames_.WriteTo(ref output, _repeated_loopEnterNames_codec); | |||
if (MaximumIterationsName.Length != 0) { | |||
output.WriteRawTag(90); | |||
output.WriteString(MaximumIterationsName); | |||
} | |||
nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (ContextName.Length != 0) { | |||
@@ -1060,6 +1390,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(WhileContextDef other) { | |||
if (other == null) { | |||
return; | |||
@@ -1101,7 +1432,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -1161,7 +1496,74 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
ContextName = input.ReadString(); | |||
break; | |||
} | |||
case 16: { | |||
ParallelIterations = input.ReadInt32(); | |||
break; | |||
} | |||
case 24: { | |||
BackProp = input.ReadBool(); | |||
break; | |||
} | |||
case 32: { | |||
SwapMemory = input.ReadBool(); | |||
break; | |||
} | |||
case 42: { | |||
PivotName = input.ReadString(); | |||
break; | |||
} | |||
case 50: { | |||
PivotForPredName = input.ReadString(); | |||
break; | |||
} | |||
case 58: { | |||
PivotForBodyName = input.ReadString(); | |||
break; | |||
} | |||
case 66: { | |||
loopExitNames_.AddEntriesFrom(ref input, _repeated_loopExitNames_codec); | |||
break; | |||
} | |||
case 74: { | |||
if (valuesDef_ == null) { | |||
ValuesDef = new global::Tensorflow.ValuesDef(); | |||
} | |||
input.ReadMessage(ValuesDef); | |||
break; | |||
} | |||
case 82: { | |||
loopEnterNames_.AddEntriesFrom(ref input, _repeated_loopEnterNames_codec); | |||
break; | |||
} | |||
case 90: { | |||
MaximumIterationsName = input.ReadString(); | |||
break; | |||
} | |||
case 98: { | |||
nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
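Illustrative note (not part of the generated file): the ControlFlowContextDef oneof above carries either a CondContextDef or a WhileContextDef. Below is a minimal, hypothetical sketch of populating and round-tripping a WhileContextDef from C#; the context name, tensor name, and iteration count are made-up placeholders, not values taken from this diff.

using Google.Protobuf;   // MessageExtensions.ToByteArray
using Tensorflow;        // generated ControlFlowContextDef / WhileContextDef

// Assigning WhileCtxt switches CtxtCase to CtxtOneofCase.WhileCtxt automatically.
var ctxt = new ControlFlowContextDef
{
    WhileCtxt = new WhileContextDef
    {
        ContextName = "while/while_context",   // placeholder name
        ParallelIterations = 10,
        BackProp = true,
        SwapMemory = false,
        PivotName = "while/LoopCond:0"         // placeholder tensor name
    }
};

// Round-trip through the same WriteTo / MergeFrom machinery generated above.
byte[] bytes = ctxt.ToByteArray();
ControlFlowContextDef parsed = ControlFlowContextDef.Parser.ParseFrom(bytes);
// parsed.CtxtCase is now ControlFlowContextDef.CtxtOneofCase.WhileCtxt.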
@@ -0,0 +1,791 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/protobuf/coordination_config.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/coordination_config.proto</summary> | |||
public static partial class CoordinationConfigReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for tensorflow/core/protobuf/coordination_config.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static CoordinationConfigReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"CjJ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX2NvbmZp", | |||
"Zy5wcm90bxIKdGVuc29yZmxvdyIxCg5Db29yZGluYXRlZEpvYhIMCgRuYW1l", | |||
"GAEgASgJEhEKCW51bV90YXNrcxgCIAEoBSLdAgoZQ29vcmRpbmF0aW9uU2Vy", | |||
"dmljZUNvbmZpZxIUCgxzZXJ2aWNlX3R5cGUYASABKAkSFgoOc2VydmljZV9s", | |||
"ZWFkZXIYAiABKAkSGwoTZW5hYmxlX2hlYWx0aF9jaGVjaxgDIAEoCBImCh5j", | |||
"bHVzdGVyX3JlZ2lzdGVyX3RpbWVvdXRfaW5fbXMYBCABKAMSHwoXaGVhcnRi", | |||
"ZWF0X3RpbWVvdXRfaW5fbXMYBSABKAMSOAoUY29vcmRpbmF0ZWRfam9iX2xp", | |||
"c3QYCiADKAsyGi50ZW5zb3JmbG93LkNvb3JkaW5hdGVkSm9iEiYKHnNodXRk", | |||
"b3duX2JhcnJpZXJfdGltZW91dF9pbl9tcxgHIAEoAxIqCiJhZ2VudF9kZXN0", | |||
"cnVjdGlvbl93aXRob3V0X3NodXRkb3duGAggASgIEhgKEHJlY292ZXJhYmxl", | |||
"X2pvYnMYCSADKAlKBAgGEAdCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl", | |||
"X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { }, | |||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedJob), global::Tensorflow.CoordinatedJob.Parser, new[]{ "Name", "NumTasks" }, null, null, null, null), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceConfig), global::Tensorflow.CoordinationServiceConfig.Parser, new[]{ "ServiceType", "ServiceLeader", "EnableHealthCheck", "ClusterRegisterTimeoutInMs", "HeartbeatTimeoutInMs", "CoordinatedJobList", "ShutdownBarrierTimeoutInMs", "AgentDestructionWithoutShutdown", "RecoverableJobs" }, null, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// Represents a job type and the number of tasks under this job. | |||
/// For example, ("worker", 20) implies that there will be 20 worker tasks. | |||
/// </summary> | |||
public sealed partial class CoordinatedJob : pb::IMessage<CoordinatedJob> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CoordinatedJob> _parser = new pb::MessageParser<CoordinatedJob>(() => new CoordinatedJob()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CoordinatedJob> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinatedJob() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinatedJob(CoordinatedJob other) : this() { | |||
name_ = other.name_; | |||
numTasks_ = other.numTasks_; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinatedJob Clone() { | |||
return new CoordinatedJob(this); | |||
} | |||
/// <summary>Field number for the "name" field.</summary> | |||
public const int NameFieldNumber = 1; | |||
private string name_ = ""; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Name { | |||
get { return name_; } | |||
set { | |||
name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "num_tasks" field.</summary> | |||
public const int NumTasksFieldNumber = 2; | |||
private int numTasks_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int NumTasks { | |||
get { return numTasks_; } | |||
set { | |||
numTasks_ = value; | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CoordinatedJob); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CoordinatedJob other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (Name != other.Name) return false; | |||
if (NumTasks != other.NumTasks) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
if (NumTasks != 0) hash ^= NumTasks.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
} | |||
if (NumTasks != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(NumTasks); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
} | |||
if (NumTasks != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(NumTasks); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Name.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); | |||
} | |||
if (NumTasks != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumTasks); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CoordinatedJob other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.Name.Length != 0) { | |||
Name = other.Name; | |||
} | |||
if (other.NumTasks != 0) { | |||
NumTasks = other.NumTasks; | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
Name = input.ReadString(); | |||
break; | |||
} | |||
case 16: { | |||
NumTasks = input.ReadInt32(); | |||
break; | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
Name = input.ReadString(); | |||
break; | |||
} | |||
case 16: { | |||
NumTasks = input.ReadInt32(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
/// <summary> | |||
/// Coordination service configuration parameters. | |||
/// The system picks appropriate values for fields that are not set. | |||
/// </summary> | |||
public sealed partial class CoordinationServiceConfig : pb::IMessage<CoordinationServiceConfig> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CoordinationServiceConfig> _parser = new pb::MessageParser<CoordinationServiceConfig>(() => new CoordinationServiceConfig()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CoordinationServiceConfig> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinationServiceConfig() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinationServiceConfig(CoordinationServiceConfig other) : this() { | |||
serviceType_ = other.serviceType_; | |||
serviceLeader_ = other.serviceLeader_; | |||
enableHealthCheck_ = other.enableHealthCheck_; | |||
clusterRegisterTimeoutInMs_ = other.clusterRegisterTimeoutInMs_; | |||
heartbeatTimeoutInMs_ = other.heartbeatTimeoutInMs_; | |||
coordinatedJobList_ = other.coordinatedJobList_.Clone(); | |||
shutdownBarrierTimeoutInMs_ = other.shutdownBarrierTimeoutInMs_; | |||
agentDestructionWithoutShutdown_ = other.agentDestructionWithoutShutdown_; | |||
recoverableJobs_ = other.recoverableJobs_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CoordinationServiceConfig Clone() { | |||
return new CoordinationServiceConfig(this); | |||
} | |||
/// <summary>Field number for the "service_type" field.</summary> | |||
public const int ServiceTypeFieldNumber = 1; | |||
private string serviceType_ = ""; | |||
/// <summary> | |||
/// Type of coordination service implementation to enable. | |||
    /// For example, setting the service type to "standalone" starts a service | |||
/// instance on the leader task to provide the coordination services such as | |||
/// heartbeats and consistent key-value store. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string ServiceType { | |||
get { return serviceType_; } | |||
set { | |||
serviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "service_leader" field.</summary> | |||
public const int ServiceLeaderFieldNumber = 2; | |||
private string serviceLeader_ = ""; | |||
/// <summary> | |||
/// Address where the coordination service instance is hosted. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string ServiceLeader { | |||
get { return serviceLeader_; } | |||
set { | |||
serviceLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "enable_health_check" field.</summary> | |||
public const int EnableHealthCheckFieldNumber = 3; | |||
private bool enableHealthCheck_; | |||
/// <summary> | |||
/// Whether to enable the health check mechanism. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool EnableHealthCheck { | |||
get { return enableHealthCheck_; } | |||
set { | |||
enableHealthCheck_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "cluster_register_timeout_in_ms" field.</summary> | |||
public const int ClusterRegisterTimeoutInMsFieldNumber = 4; | |||
private long clusterRegisterTimeoutInMs_; | |||
/// <summary> | |||
/// Maximum wait time for all members in the cluster to be registered. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long ClusterRegisterTimeoutInMs { | |||
get { return clusterRegisterTimeoutInMs_; } | |||
set { | |||
clusterRegisterTimeoutInMs_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "heartbeat_timeout_in_ms" field.</summary> | |||
public const int HeartbeatTimeoutInMsFieldNumber = 5; | |||
private long heartbeatTimeoutInMs_; | |||
/// <summary> | |||
    /// Heartbeat timeout: if a task does not record a heartbeat within this time | |||
    /// window, it will be considered disconnected. | |||
/// Note: This is also used as a grace period to accept any heartbeats after | |||
/// the agent has disconnected, to account for the lag time between the service | |||
/// recording the state change and the agent stopping heartbeats. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long HeartbeatTimeoutInMs { | |||
get { return heartbeatTimeoutInMs_; } | |||
set { | |||
heartbeatTimeoutInMs_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "coordinated_job_list" field.</summary> | |||
public const int CoordinatedJobListFieldNumber = 10; | |||
private static readonly pb::FieldCodec<global::Tensorflow.CoordinatedJob> _repeated_coordinatedJobList_codec | |||
= pb::FieldCodec.ForMessage(82, global::Tensorflow.CoordinatedJob.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.CoordinatedJob> coordinatedJobList_ = new pbc::RepeatedField<global::Tensorflow.CoordinatedJob>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.CoordinatedJob> CoordinatedJobList { | |||
get { return coordinatedJobList_; } | |||
} | |||
/// <summary>Field number for the "shutdown_barrier_timeout_in_ms" field.</summary> | |||
public const int ShutdownBarrierTimeoutInMsFieldNumber = 7; | |||
private long shutdownBarrierTimeoutInMs_; | |||
/// <summary> | |||
/// Denotes how long to wait for all coordination agents to reach the barriers | |||
/// (after the first shutdown request) before disconnecting together. If | |||
/// set to 0, no barrier is imposed upon shutdown and each worker can | |||
/// disconnect individually. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long ShutdownBarrierTimeoutInMs { | |||
get { return shutdownBarrierTimeoutInMs_; } | |||
set { | |||
shutdownBarrierTimeoutInMs_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "agent_destruction_without_shutdown" field.</summary> | |||
public const int AgentDestructionWithoutShutdownFieldNumber = 8; | |||
private bool agentDestructionWithoutShutdown_; | |||
/// <summary> | |||
    /// If set, agents do not make an explicit Shutdown() call. The service will | |||
    /// only find out about the disconnected agent via stale heartbeats. Used for | |||
    /// testing. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool AgentDestructionWithoutShutdown { | |||
get { return agentDestructionWithoutShutdown_; } | |||
set { | |||
agentDestructionWithoutShutdown_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "recoverable_jobs" field.</summary> | |||
public const int RecoverableJobsFieldNumber = 9; | |||
private static readonly pb::FieldCodec<string> _repeated_recoverableJobs_codec | |||
= pb::FieldCodec.ForString(74); | |||
private readonly pbc::RepeatedField<string> recoverableJobs_ = new pbc::RepeatedField<string>(); | |||
/// <summary> | |||
    /// The list of jobs that are recoverable. If a task in this list fails, | |||
    /// it will not propagate errors to other tasks. | |||
/// If empty, no jobs will be recoverable and every task failure will cause | |||
/// error propagation to other tasks. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> RecoverableJobs { | |||
get { return recoverableJobs_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CoordinationServiceConfig); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CoordinationServiceConfig other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (ServiceType != other.ServiceType) return false; | |||
if (ServiceLeader != other.ServiceLeader) return false; | |||
if (EnableHealthCheck != other.EnableHealthCheck) return false; | |||
if (ClusterRegisterTimeoutInMs != other.ClusterRegisterTimeoutInMs) return false; | |||
if (HeartbeatTimeoutInMs != other.HeartbeatTimeoutInMs) return false; | |||
if(!coordinatedJobList_.Equals(other.coordinatedJobList_)) return false; | |||
if (ShutdownBarrierTimeoutInMs != other.ShutdownBarrierTimeoutInMs) return false; | |||
if (AgentDestructionWithoutShutdown != other.AgentDestructionWithoutShutdown) return false; | |||
if(!recoverableJobs_.Equals(other.recoverableJobs_)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (ServiceType.Length != 0) hash ^= ServiceType.GetHashCode(); | |||
if (ServiceLeader.Length != 0) hash ^= ServiceLeader.GetHashCode(); | |||
if (EnableHealthCheck != false) hash ^= EnableHealthCheck.GetHashCode(); | |||
if (ClusterRegisterTimeoutInMs != 0L) hash ^= ClusterRegisterTimeoutInMs.GetHashCode(); | |||
if (HeartbeatTimeoutInMs != 0L) hash ^= HeartbeatTimeoutInMs.GetHashCode(); | |||
hash ^= coordinatedJobList_.GetHashCode(); | |||
if (ShutdownBarrierTimeoutInMs != 0L) hash ^= ShutdownBarrierTimeoutInMs.GetHashCode(); | |||
if (AgentDestructionWithoutShutdown != false) hash ^= AgentDestructionWithoutShutdown.GetHashCode(); | |||
hash ^= recoverableJobs_.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (ServiceType.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ServiceType); | |||
} | |||
if (ServiceLeader.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(ServiceLeader); | |||
} | |||
if (EnableHealthCheck != false) { | |||
output.WriteRawTag(24); | |||
output.WriteBool(EnableHealthCheck); | |||
} | |||
if (ClusterRegisterTimeoutInMs != 0L) { | |||
output.WriteRawTag(32); | |||
output.WriteInt64(ClusterRegisterTimeoutInMs); | |||
} | |||
if (HeartbeatTimeoutInMs != 0L) { | |||
output.WriteRawTag(40); | |||
output.WriteInt64(HeartbeatTimeoutInMs); | |||
} | |||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||
output.WriteRawTag(56); | |||
output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||
} | |||
if (AgentDestructionWithoutShutdown != false) { | |||
output.WriteRawTag(64); | |||
output.WriteBool(AgentDestructionWithoutShutdown); | |||
} | |||
recoverableJobs_.WriteTo(output, _repeated_recoverableJobs_codec); | |||
coordinatedJobList_.WriteTo(output, _repeated_coordinatedJobList_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (ServiceType.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(ServiceType); | |||
} | |||
if (ServiceLeader.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(ServiceLeader); | |||
} | |||
if (EnableHealthCheck != false) { | |||
output.WriteRawTag(24); | |||
output.WriteBool(EnableHealthCheck); | |||
} | |||
if (ClusterRegisterTimeoutInMs != 0L) { | |||
output.WriteRawTag(32); | |||
output.WriteInt64(ClusterRegisterTimeoutInMs); | |||
} | |||
if (HeartbeatTimeoutInMs != 0L) { | |||
output.WriteRawTag(40); | |||
output.WriteInt64(HeartbeatTimeoutInMs); | |||
} | |||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||
output.WriteRawTag(56); | |||
output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||
} | |||
if (AgentDestructionWithoutShutdown != false) { | |||
output.WriteRawTag(64); | |||
output.WriteBool(AgentDestructionWithoutShutdown); | |||
} | |||
recoverableJobs_.WriteTo(ref output, _repeated_recoverableJobs_codec); | |||
coordinatedJobList_.WriteTo(ref output, _repeated_coordinatedJobList_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (ServiceType.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceType); | |||
} | |||
if (ServiceLeader.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceLeader); | |||
} | |||
if (EnableHealthCheck != false) { | |||
size += 1 + 1; | |||
} | |||
if (ClusterRegisterTimeoutInMs != 0L) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClusterRegisterTimeoutInMs); | |||
} | |||
if (HeartbeatTimeoutInMs != 0L) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatTimeoutInMs); | |||
} | |||
size += coordinatedJobList_.CalculateSize(_repeated_coordinatedJobList_codec); | |||
if (ShutdownBarrierTimeoutInMs != 0L) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownBarrierTimeoutInMs); | |||
} | |||
if (AgentDestructionWithoutShutdown != false) { | |||
size += 1 + 1; | |||
} | |||
size += recoverableJobs_.CalculateSize(_repeated_recoverableJobs_codec); | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CoordinationServiceConfig other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.ServiceType.Length != 0) { | |||
ServiceType = other.ServiceType; | |||
} | |||
if (other.ServiceLeader.Length != 0) { | |||
ServiceLeader = other.ServiceLeader; | |||
} | |||
if (other.EnableHealthCheck != false) { | |||
EnableHealthCheck = other.EnableHealthCheck; | |||
} | |||
if (other.ClusterRegisterTimeoutInMs != 0L) { | |||
ClusterRegisterTimeoutInMs = other.ClusterRegisterTimeoutInMs; | |||
} | |||
if (other.HeartbeatTimeoutInMs != 0L) { | |||
HeartbeatTimeoutInMs = other.HeartbeatTimeoutInMs; | |||
} | |||
coordinatedJobList_.Add(other.coordinatedJobList_); | |||
if (other.ShutdownBarrierTimeoutInMs != 0L) { | |||
ShutdownBarrierTimeoutInMs = other.ShutdownBarrierTimeoutInMs; | |||
} | |||
if (other.AgentDestructionWithoutShutdown != false) { | |||
AgentDestructionWithoutShutdown = other.AgentDestructionWithoutShutdown; | |||
} | |||
recoverableJobs_.Add(other.recoverableJobs_); | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
ServiceType = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
ServiceLeader = input.ReadString(); | |||
break; | |||
} | |||
case 24: { | |||
EnableHealthCheck = input.ReadBool(); | |||
break; | |||
} | |||
case 32: { | |||
ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 40: { | |||
HeartbeatTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 56: { | |||
ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 64: { | |||
AgentDestructionWithoutShutdown = input.ReadBool(); | |||
break; | |||
} | |||
case 74: { | |||
recoverableJobs_.AddEntriesFrom(input, _repeated_recoverableJobs_codec); | |||
break; | |||
} | |||
case 82: { | |||
coordinatedJobList_.AddEntriesFrom(input, _repeated_coordinatedJobList_codec); | |||
break; | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
ServiceType = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
ServiceLeader = input.ReadString(); | |||
break; | |||
} | |||
case 24: { | |||
EnableHealthCheck = input.ReadBool(); | |||
break; | |||
} | |||
case 32: { | |||
ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 40: { | |||
HeartbeatTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 56: { | |||
ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||
break; | |||
} | |||
case 64: { | |||
AgentDestructionWithoutShutdown = input.ReadBool(); | |||
break; | |||
} | |||
case 74: { | |||
recoverableJobs_.AddEntriesFrom(ref input, _repeated_recoverableJobs_codec); | |||
break; | |||
} | |||
case 82: { | |||
coordinatedJobList_.AddEntriesFrom(ref input, _repeated_coordinatedJobList_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
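Illustrative note (not part of the generated file): a hedged sketch of how the CoordinationServiceConfig generated above might be filled in from C#. Every concrete value below (service type, leader address, job name, task count, timeouts) is a made-up placeholder rather than something taken from this diff; the ("worker", 20) pair simply mirrors the CoordinatedJob doc comment.

using Google.Protobuf;
using Tensorflow;

var config = new CoordinationServiceConfig
{
    ServiceType = "standalone",                      // placeholder, see ServiceType comment
    ServiceLeader = "/job:worker/replica:0/task:0",  // placeholder leader address
    EnableHealthCheck = true,
    ClusterRegisterTimeoutInMs = 60_000,
    HeartbeatTimeoutInMs = 10_000,
    ShutdownBarrierTimeoutInMs = 5_000
};

// Repeated fields are exposed as RepeatedField<T> and populated via Add.
config.CoordinatedJobList.Add(new CoordinatedJob { Name = "worker", NumTasks = 20 });
config.RecoverableJobs.Add("worker");

// Serialize and parse with the generated WriteTo / MergeFrom implementations.
byte[] payload = config.ToByteArray();
CoordinationServiceConfig restored = CoordinationServiceConfig.Parser.ParseFrom(payload);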
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/python/framework/cpp_shape_inference.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -55,23 +55,31 @@ namespace Tensorflow { | |||
} | |||
#region Messages | |||
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> { | |||
public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CppShapeInferenceResult> _parser = new pb::MessageParser<CppShapeInferenceResult>(() => new CppShapeInferenceResult()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CppShapeInferenceResult> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceResult() { | |||
OnConstruction(); | |||
} | |||
@@ -79,6 +87,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceResult(CppShapeInferenceResult other) : this() { | |||
shape_ = other.shape_ != null ? other.shape_.Clone() : null; | |||
handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null; | |||
@@ -86,6 +95,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceResult Clone() { | |||
return new CppShapeInferenceResult(this); | |||
} | |||
@@ -94,6 +104,7 @@ namespace Tensorflow { | |||
public const int ShapeFieldNumber = 1; | |||
private global::Tensorflow.TensorShapeProto shape_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.TensorShapeProto Shape { | |||
get { return shape_; } | |||
set { | |||
@@ -105,6 +116,7 @@ namespace Tensorflow { | |||
public const int HandleDataFieldNumber = 4; | |||
private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { | |||
get { return handleData_; } | |||
set { | |||
@@ -113,11 +125,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CppShapeInferenceResult); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CppShapeInferenceResult other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -131,6 +145,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (shape_ != null) hash ^= Shape.GetHashCode(); | |||
@@ -142,12 +157,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (shape_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(Shape); | |||
@@ -159,9 +179,29 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (shape_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(Shape); | |||
} | |||
if (handleData_ != null) { | |||
output.WriteRawTag(34); | |||
output.WriteMessage(HandleData); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (shape_ != null) { | |||
@@ -177,6 +217,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CppShapeInferenceResult other) { | |||
if (other == null) { | |||
return; | |||
@@ -197,7 +238,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -220,29 +265,68 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
if (shape_ == null) { | |||
Shape = new global::Tensorflow.TensorShapeProto(); | |||
} | |||
input.ReadMessage(Shape); | |||
break; | |||
} | |||
case 34: { | |||
if (handleData_ == null) { | |||
HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); | |||
} | |||
input.ReadMessage(HandleData); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
#region Nested types | |||
/// <summary>Container for nested types declared in the CppShapeInferenceResult message type.</summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static partial class Types { | |||
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> { | |||
public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<HandleShapeAndType> _parser = new pb::MessageParser<HandleShapeAndType>(() => new HandleShapeAndType()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<HandleShapeAndType> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleShapeAndType() { | |||
OnConstruction(); | |||
} | |||
@@ -250,6 +334,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleShapeAndType(HandleShapeAndType other) : this() { | |||
shape_ = other.shape_ != null ? other.shape_.Clone() : null; | |||
dtype_ = other.dtype_; | |||
@@ -258,6 +343,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleShapeAndType Clone() { | |||
return new HandleShapeAndType(this); | |||
} | |||
@@ -266,6 +352,7 @@ namespace Tensorflow { | |||
public const int ShapeFieldNumber = 1; | |||
private global::Tensorflow.TensorShapeProto shape_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.TensorShapeProto Shape { | |||
get { return shape_; } | |||
set { | |||
@@ -277,6 +364,7 @@ namespace Tensorflow { | |||
public const int DtypeFieldNumber = 2; | |||
private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.DataType Dtype { | |||
get { return dtype_; } | |||
set { | |||
@@ -288,6 +376,7 @@ namespace Tensorflow { | |||
public const int TypeFieldNumber = 4; | |||
private global::Tensorflow.FullTypeDef type_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.FullTypeDef Type { | |||
get { return type_; } | |||
set { | |||
@@ -296,11 +385,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as HandleShapeAndType); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(HandleShapeAndType other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -315,6 +406,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (shape_ != null) hash ^= Shape.GetHashCode(); | |||
@@ -327,12 +419,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (shape_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(Shape); | |||
@@ -348,9 +445,33 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (shape_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(Shape); | |||
} | |||
if (Dtype != global::Tensorflow.DataType.DtInvalid) { | |||
output.WriteRawTag(16); | |||
output.WriteEnum((int) Dtype); | |||
} | |||
if (type_ != null) { | |||
output.WriteRawTag(34); | |||
output.WriteMessage(Type); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (shape_ != null) { | |||
@@ -369,6 +490,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(HandleShapeAndType other) { | |||
if (other == null) { | |||
return; | |||
@@ -392,7 +514,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -419,27 +545,69 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
if (shape_ == null) { | |||
Shape = new global::Tensorflow.TensorShapeProto(); | |||
} | |||
input.ReadMessage(Shape); | |||
break; | |||
} | |||
case 16: { | |||
Dtype = (global::Tensorflow.DataType) input.ReadEnum(); | |||
break; | |||
} | |||
case 34: { | |||
if (type_ == null) { | |||
Type = new global::Tensorflow.FullTypeDef(); | |||
} | |||
input.ReadMessage(Type); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
public sealed partial class HandleData : pb::IMessage<HandleData> { | |||
public sealed partial class HandleData : pb::IMessage<HandleData> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<HandleData> _parser = new pb::MessageParser<HandleData>(() => new HandleData()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<HandleData> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleData() { | |||
OnConstruction(); | |||
} | |||
@@ -447,6 +615,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleData(HandleData other) : this() { | |||
isSet_ = other.isSet_; | |||
shapeAndType_ = other.shapeAndType_.Clone(); | |||
@@ -454,6 +623,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public HandleData Clone() { | |||
return new HandleData(this); | |||
} | |||
@@ -462,6 +632,7 @@ namespace Tensorflow { | |||
public const int IsSetFieldNumber = 1; | |||
private bool isSet_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool IsSet { | |||
get { return isSet_; } | |||
set { | |||
@@ -478,16 +649,19 @@ namespace Tensorflow { | |||
/// Only valid if <is_set>. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | |||
get { return shapeAndType_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as HandleData); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(HandleData other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -501,6 +675,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (IsSet != false) hash ^= IsSet.GetHashCode(); | |||
@@ -512,12 +687,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (IsSet != false) { | |||
output.WriteRawTag(8); | |||
output.WriteBool(IsSet); | |||
@@ -526,9 +706,26 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (IsSet != false) { | |||
output.WriteRawTag(8); | |||
output.WriteBool(IsSet); | |||
} | |||
shapeAndType_.WriteTo(ref output, _repeated_shapeAndType_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (IsSet != false) { | |||
@@ -542,6 +739,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(HandleData other) { | |||
if (other == null) { | |||
return; | |||
@@ -554,7 +752,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -571,8 +773,32 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 8: { | |||
IsSet = input.ReadBool(); | |||
break; | |||
} | |||
case 18: { | |||
shapeAndType_.AddEntriesFrom(ref input, _repeated_shapeAndType_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
} | |||
@@ -580,23 +806,31 @@ namespace Tensorflow { | |||
} | |||
public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> { | |||
public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<CppShapeInferenceInputsNeeded> _parser = new pb::MessageParser<CppShapeInferenceInputsNeeded>(() => new CppShapeInferenceInputsNeeded()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<CppShapeInferenceInputsNeeded> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceInputsNeeded() { | |||
OnConstruction(); | |||
} | |||
@@ -604,6 +838,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() { | |||
inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone(); | |||
inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone(); | |||
@@ -611,6 +846,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public CppShapeInferenceInputsNeeded Clone() { | |||
return new CppShapeInferenceInputsNeeded(this); | |||
} | |||
@@ -621,6 +857,7 @@ namespace Tensorflow { | |||
= pb::FieldCodec.ForInt32(10); | |||
private readonly pbc::RepeatedField<int> inputTensorsNeeded_ = new pbc::RepeatedField<int>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<int> InputTensorsNeeded { | |||
get { return inputTensorsNeeded_; } | |||
} | |||
@@ -631,16 +868,19 @@ namespace Tensorflow { | |||
= pb::FieldCodec.ForInt32(18); | |||
private readonly pbc::RepeatedField<int> inputTensorsAsShapesNeeded_ = new pbc::RepeatedField<int>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<int> InputTensorsAsShapesNeeded { | |||
get { return inputTensorsAsShapesNeeded_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as CppShapeInferenceInputsNeeded); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(CppShapeInferenceInputsNeeded other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -654,6 +894,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= inputTensorsNeeded_.GetHashCode(); | |||
@@ -665,20 +906,39 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec); | |||
inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
inputTensorsNeeded_.WriteTo(ref output, _repeated_inputTensorsNeeded_codec); | |||
inputTensorsAsShapesNeeded_.WriteTo(ref output, _repeated_inputTensorsAsShapesNeeded_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec); | |||
@@ -690,6 +950,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(CppShapeInferenceInputsNeeded other) { | |||
if (other == null) { | |||
return; | |||
@@ -700,7 +961,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -719,7 +984,33 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: | |||
case 8: { | |||
inputTensorsNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsNeeded_codec); | |||
break; | |||
} | |||
case 18: | |||
case 16: { | |||
inputTensorsAsShapesNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsAsShapesNeeded_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
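Note (not part of the generated file): a minimal C# sketch of how the regenerated CppShapeInferenceResult handle-metadata types above are typically used with the standard Google.Protobuf helpers. It assumes `ToByteArray`/`Parser.ParseFrom` from Google.Protobuf, and that `TensorShapeProto.UnknownRank` and `DataType.DtFloat` exist in the sibling generated files; names marked hypothetical are illustrative only.

    // Hypothetical usage sketch, assuming the Google.Protobuf runtime and the
    // sibling generated types TensorShapeProto / DataType are available.
    using Google.Protobuf;
    using Tensorflow;

    class HandleDataRoundTrip
    {
        static void Main()
        {
            // Describe one resource handle: an unknown-rank float tensor (assumed members).
            var handleData = new CppShapeInferenceResult.Types.HandleData
            {
                IsSet = true,
                ShapeAndType =
                {
                    new CppShapeInferenceResult.Types.HandleShapeAndType
                    {
                        Shape = new TensorShapeProto { UnknownRank = true },
                        Dtype = DataType.DtFloat,
                    },
                },
            };

            // Serialize and parse back; Parser is the static property generated above.
            byte[] bytes = handleData.ToByteArray();
            var parsed = CppShapeInferenceResult.Types.HandleData.Parser.ParseFrom(bytes);

            System.Console.WriteLine($"is_set={parsed.IsSet}, entries={parsed.ShapeAndType.Count}");
        }
    }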
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/protobuf/debug.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -55,23 +55,31 @@ namespace Tensorflow { | |||
/// <summary> | |||
/// Option for watching a node in TensorFlow Debugger (tfdbg). | |||
/// </summary> | |||
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> { | |||
public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DebugTensorWatch> _parser = new pb::MessageParser<DebugTensorWatch>(() => new DebugTensorWatch()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DebugTensorWatch> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugTensorWatch() { | |||
OnConstruction(); | |||
} | |||
@@ -79,6 +87,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugTensorWatch(DebugTensorWatch other) : this() { | |||
nodeName_ = other.nodeName_; | |||
outputSlot_ = other.outputSlot_; | |||
@@ -89,6 +98,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugTensorWatch Clone() { | |||
return new DebugTensorWatch(this); | |||
} | |||
@@ -102,6 +112,7 @@ namespace Tensorflow { | |||
/// general. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string NodeName { | |||
get { return nodeName_; } | |||
set { | |||
@@ -120,6 +131,7 @@ namespace Tensorflow { | |||
/// errors currently. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int OutputSlot { | |||
get { return outputSlot_; } | |||
set { | |||
@@ -138,6 +150,7 @@ namespace Tensorflow { | |||
/// e.g., {"DebugIdentity", "DebugNanCount"} | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> DebugOps { | |||
get { return debugOps_; } | |||
} | |||
@@ -170,6 +183,7 @@ namespace Tensorflow { | |||
/// TODO(cais): More visible documentation of this in g3docs. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> DebugUrls { | |||
get { return debugUrls_; } | |||
} | |||
@@ -182,6 +196,7 @@ namespace Tensorflow { | |||
/// incompatibility). Instead, just log the failure. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool TolerateDebugOpCreationFailures { | |||
get { return tolerateDebugOpCreationFailures_; } | |||
set { | |||
@@ -190,11 +205,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DebugTensorWatch); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DebugTensorWatch other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -211,6 +228,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); | |||
@@ -225,12 +243,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (NodeName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(NodeName); | |||
@@ -248,9 +271,35 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (NodeName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(NodeName); | |||
} | |||
if (OutputSlot != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(OutputSlot); | |||
} | |||
debugOps_.WriteTo(ref output, _repeated_debugOps_codec); | |||
debugUrls_.WriteTo(ref output, _repeated_debugUrls_codec); | |||
if (TolerateDebugOpCreationFailures != false) { | |||
output.WriteRawTag(40); | |||
output.WriteBool(TolerateDebugOpCreationFailures); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (NodeName.Length != 0) { | |||
@@ -271,6 +320,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DebugTensorWatch other) { | |||
if (other == null) { | |||
return; | |||
@@ -290,7 +340,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -319,30 +373,74 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
NodeName = input.ReadString(); | |||
break; | |||
} | |||
case 16: { | |||
OutputSlot = input.ReadInt32(); | |||
break; | |||
} | |||
case 26: { | |||
debugOps_.AddEntriesFrom(ref input, _repeated_debugOps_codec); | |||
break; | |||
} | |||
case 34: { | |||
debugUrls_.AddEntriesFrom(ref input, _repeated_debugUrls_codec); | |||
break; | |||
} | |||
case 40: { | |||
TolerateDebugOpCreationFailures = input.ReadBool(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
/// <summary> | |||
/// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). | |||
/// </summary> | |||
public sealed partial class DebugOptions : pb::IMessage<DebugOptions> { | |||
public sealed partial class DebugOptions : pb::IMessage<DebugOptions> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DebugOptions> _parser = new pb::MessageParser<DebugOptions>(() => new DebugOptions()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DebugOptions> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugOptions() { | |||
OnConstruction(); | |||
} | |||
@@ -350,6 +448,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugOptions(DebugOptions other) : this() { | |||
debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone(); | |||
globalStep_ = other.globalStep_; | |||
@@ -358,6 +457,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebugOptions Clone() { | |||
return new DebugOptions(this); | |||
} | |||
@@ -371,6 +471,7 @@ namespace Tensorflow { | |||
/// Debugging options | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.DebugTensorWatch> DebugTensorWatchOpts { | |||
get { return debugTensorWatchOpts_; } | |||
} | |||
@@ -384,6 +485,7 @@ namespace Tensorflow { | |||
/// step count. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long GlobalStep { | |||
get { return globalStep_; } | |||
set { | |||
@@ -401,6 +503,7 @@ namespace Tensorflow { | |||
/// are cleaned up from the disk after each Session.run. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool ResetDiskByteUsage { | |||
get { return resetDiskByteUsage_; } | |||
set { | |||
@@ -409,11 +512,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DebugOptions); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DebugOptions other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -428,6 +533,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= debugTensorWatchOpts_.GetHashCode(); | |||
@@ -440,12 +546,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec); | |||
if (GlobalStep != 0L) { | |||
output.WriteRawTag(80); | |||
@@ -458,9 +569,30 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
debugTensorWatchOpts_.WriteTo(ref output, _repeated_debugTensorWatchOpts_codec); | |||
if (GlobalStep != 0L) { | |||
output.WriteRawTag(80); | |||
output.WriteInt64(GlobalStep); | |||
} | |||
if (ResetDiskByteUsage != false) { | |||
output.WriteRawTag(88); | |||
output.WriteBool(ResetDiskByteUsage); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec); | |||
@@ -477,6 +609,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DebugOptions other) { | |||
if (other == null) { | |||
return; | |||
@@ -492,7 +625,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -513,27 +650,63 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 34: { | |||
debugTensorWatchOpts_.AddEntriesFrom(ref input, _repeated_debugTensorWatchOpts_codec); | |||
break; | |||
} | |||
case 80: { | |||
GlobalStep = input.ReadInt64(); | |||
break; | |||
} | |||
case 88: { | |||
ResetDiskByteUsage = input.ReadBool(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> { | |||
public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DebuggedSourceFile> _parser = new pb::MessageParser<DebuggedSourceFile>(() => new DebuggedSourceFile()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DebuggedSourceFile> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFile() { | |||
OnConstruction(); | |||
} | |||
@@ -541,6 +714,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFile(DebuggedSourceFile other) : this() { | |||
host_ = other.host_; | |||
filePath_ = other.filePath_; | |||
@@ -551,6 +725,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFile Clone() { | |||
return new DebuggedSourceFile(this); | |||
} | |||
@@ -562,6 +737,7 @@ namespace Tensorflow { | |||
/// The host name on which a source code file is located. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Host { | |||
get { return host_; } | |||
set { | |||
@@ -576,6 +752,7 @@ namespace Tensorflow { | |||
/// Path to the source code file. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string FilePath { | |||
get { return filePath_; } | |||
set { | |||
@@ -590,6 +767,7 @@ namespace Tensorflow { | |||
/// The timestamp at which the source code file is last modified. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long LastModified { | |||
get { return lastModified_; } | |||
set { | |||
@@ -604,6 +782,7 @@ namespace Tensorflow { | |||
/// Byte size of the file. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long Bytes { | |||
get { return bytes_; } | |||
set { | |||
@@ -620,16 +799,19 @@ namespace Tensorflow { | |||
/// Line-by-line content of the source code file. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<string> Lines { | |||
get { return lines_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DebuggedSourceFile); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DebuggedSourceFile other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -646,6 +828,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Host.Length != 0) hash ^= Host.GetHashCode(); | |||
@@ -660,12 +843,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (Host.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Host); | |||
@@ -686,9 +874,38 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (Host.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Host); | |||
} | |||
if (FilePath.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(FilePath); | |||
} | |||
if (LastModified != 0L) { | |||
output.WriteRawTag(24); | |||
output.WriteInt64(LastModified); | |||
} | |||
if (Bytes != 0L) { | |||
output.WriteRawTag(32); | |||
output.WriteInt64(Bytes); | |||
} | |||
lines_.WriteTo(ref output, _repeated_lines_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Host.Length != 0) { | |||
@@ -711,6 +928,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DebuggedSourceFile other) { | |||
if (other == null) { | |||
return; | |||
@@ -732,7 +950,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -761,27 +983,71 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
Host = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
FilePath = input.ReadString(); | |||
break; | |||
} | |||
case 24: { | |||
LastModified = input.ReadInt64(); | |||
break; | |||
} | |||
case 32: { | |||
Bytes = input.ReadInt64(); | |||
break; | |||
} | |||
case 42: { | |||
lines_.AddEntriesFrom(ref input, _repeated_lines_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> { | |||
public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DebuggedSourceFiles> _parser = new pb::MessageParser<DebuggedSourceFiles>(() => new DebuggedSourceFiles()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DebuggedSourceFiles> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFiles() { | |||
OnConstruction(); | |||
} | |||
@@ -789,12 +1055,14 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFiles(DebuggedSourceFiles other) : this() { | |||
sourceFiles_ = other.sourceFiles_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DebuggedSourceFiles Clone() { | |||
return new DebuggedSourceFiles(this); | |||
} | |||
@@ -808,16 +1076,19 @@ namespace Tensorflow { | |||
/// A collection of source code files. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.DebuggedSourceFile> SourceFiles { | |||
get { return sourceFiles_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DebuggedSourceFiles); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DebuggedSourceFiles other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -830,6 +1101,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= sourceFiles_.GetHashCode(); | |||
@@ -840,19 +1112,37 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
sourceFiles_.WriteTo(ref output, _repeated_sourceFiles_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec); | |||
@@ -863,6 +1153,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DebuggedSourceFiles other) { | |||
if (other == null) { | |||
return; | |||
@@ -872,7 +1163,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -885,7 +1180,27 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
sourceFiles_.AddEntriesFrom(ref input, _repeated_sourceFiles_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
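Aside (not part of the generated file): the IBufferMessage split that this diff adds to every message is invisible to callers; WriteTo/MergeFrom keep their public signatures and the span-based InternalWriteTo/InternalMergeFrom are picked automatically. A minimal round-trip sketch, assuming DebuggedSourceFile exposes a FilePath field as in tensorflow's debug.proto; ToByteArray and Parser.ParseFrom are the standard Google.Protobuf helpers.
using Google.Protobuf;
using Tensorflow;
static class DebuggedSourceFilesRoundTrip
{
    public static void Run()
    {
        var files = new DebuggedSourceFiles();
        // FilePath is assumed from tensorflow's debug.proto; any populated field works here.
        files.SourceFiles.Add(new DebuggedSourceFile { FilePath = "model.py" });
        byte[] wire = files.ToByteArray();                        // routed through InternalWriteTo when available
        var parsed = DebuggedSourceFiles.Parser.ParseFrom(wire);  // routed through InternalMergeFrom when available
        System.Console.WriteLine(parsed.SourceFiles.Count);       // prints 1
    }
}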
@@ -2,7 +2,7 @@ | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/framework/device_attributes.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
@@ -30,44 +30,53 @@ namespace Tensorflow { | |||
"OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl", | |||
"cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo", | |||
"BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm", | |||
"bG93LkxvY2FsTGlua3MirAEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||
"bG93LkxvY2FsTGlua3MiwwEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||
"IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB", | |||
"KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs", | |||
"aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k", | |||
"ZXNjGAcgASgJQpEBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCFkRldmlj", | |||
"ZUF0dHJpYnV0ZXNQcm90b3NQAVpYZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay9kZXZpY2Vf", | |||
"YXR0cmlidXRlc19nb19wcm90b/gBAWIGcHJvdG8z")); | |||
"ZXNjGAcgASgJEhUKDXhsYV9nbG9iYWxfaWQYCCABKANCkQEKGG9yZy50ZW5z", | |||
"b3JmbG93LmZyYW1ld29ya0IWRGV2aWNlQXR0cmlidXRlc1Byb3Rvc1ABWlhn", | |||
"aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", | |||
"L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVzX2dvX3Byb3Rv+AEB", | |||
"YgZwcm90bzM=")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { }, | |||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc" }, null, null, null, null) | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc", "XlaGlobalId" }, null, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> { | |||
public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<InterconnectLink> _parser = new pb::MessageParser<InterconnectLink>(() => new InterconnectLink()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<InterconnectLink> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public InterconnectLink() { | |||
OnConstruction(); | |||
} | |||
@@ -75,6 +84,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public InterconnectLink(InterconnectLink other) : this() { | |||
deviceId_ = other.deviceId_; | |||
type_ = other.type_; | |||
@@ -83,6 +93,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public InterconnectLink Clone() { | |||
return new InterconnectLink(this); | |||
} | |||
@@ -91,6 +102,7 @@ namespace Tensorflow { | |||
public const int DeviceIdFieldNumber = 1; | |||
private int deviceId_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int DeviceId { | |||
get { return deviceId_; } | |||
set { | |||
@@ -102,6 +114,7 @@ namespace Tensorflow { | |||
public const int TypeFieldNumber = 2; | |||
private string type_ = ""; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Type { | |||
get { return type_; } | |||
set { | |||
@@ -113,6 +126,7 @@ namespace Tensorflow { | |||
public const int StrengthFieldNumber = 3; | |||
private int strength_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int Strength { | |||
get { return strength_; } | |||
set { | |||
@@ -121,11 +135,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as InterconnectLink); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(InterconnectLink other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -140,6 +156,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (DeviceId != 0) hash ^= DeviceId.GetHashCode(); | |||
@@ -152,12 +169,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (DeviceId != 0) { | |||
output.WriteRawTag(8); | |||
output.WriteInt32(DeviceId); | |||
@@ -173,9 +195,33 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (DeviceId != 0) { | |||
output.WriteRawTag(8); | |||
output.WriteInt32(DeviceId); | |||
} | |||
if (Type.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(Type); | |||
} | |||
if (Strength != 0) { | |||
output.WriteRawTag(24); | |||
output.WriteInt32(Strength); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (DeviceId != 0) { | |||
@@ -194,6 +240,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(InterconnectLink other) { | |||
if (other == null) { | |||
return; | |||
@@ -211,7 +258,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -232,27 +283,63 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 8: { | |||
DeviceId = input.ReadInt32(); | |||
break; | |||
} | |||
case 18: { | |||
Type = input.ReadString(); | |||
break; | |||
} | |||
case 24: { | |||
Strength = input.ReadInt32(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
public sealed partial class LocalLinks : pb::IMessage<LocalLinks> { | |||
public sealed partial class LocalLinks : pb::IMessage<LocalLinks> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<LocalLinks> _parser = new pb::MessageParser<LocalLinks>(() => new LocalLinks()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<LocalLinks> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public LocalLinks() { | |||
OnConstruction(); | |||
} | |||
@@ -260,12 +347,14 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public LocalLinks(LocalLinks other) : this() { | |||
link_ = other.link_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public LocalLinks Clone() { | |||
return new LocalLinks(this); | |||
} | |||
@@ -276,16 +365,19 @@ namespace Tensorflow { | |||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.InterconnectLink> link_ = new pbc::RepeatedField<global::Tensorflow.InterconnectLink>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public pbc::RepeatedField<global::Tensorflow.InterconnectLink> Link { | |||
get { return link_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as LocalLinks); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(LocalLinks other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -298,6 +390,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= link_.GetHashCode(); | |||
@@ -308,19 +401,37 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
link_.WriteTo(output, _repeated_link_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
link_.WriteTo(ref output, _repeated_link_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += link_.CalculateSize(_repeated_link_codec); | |||
@@ -331,6 +442,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(LocalLinks other) { | |||
if (other == null) { | |||
return; | |||
@@ -340,7 +452,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -353,27 +469,55 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
link_.AddEntriesFrom(ref input, _repeated_link_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> { | |||
public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DeviceLocality> _parser = new pb::MessageParser<DeviceLocality>(() => new DeviceLocality()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DeviceLocality> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceLocality() { | |||
OnConstruction(); | |||
} | |||
@@ -381,6 +525,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceLocality(DeviceLocality other) : this() { | |||
busId_ = other.busId_; | |||
numaNode_ = other.numaNode_; | |||
@@ -389,6 +534,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceLocality Clone() { | |||
return new DeviceLocality(this); | |||
} | |||
@@ -401,6 +547,7 @@ namespace Tensorflow { | |||
/// no specific locality. Specific localities are indexed from 1. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int BusId { | |||
get { return busId_; } | |||
set { | |||
@@ -415,6 +562,7 @@ namespace Tensorflow { | |||
/// Optional NUMA locality of device. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int NumaNode { | |||
get { return numaNode_; } | |||
set { | |||
@@ -429,6 +577,7 @@ namespace Tensorflow { | |||
/// Optional local interconnect links to other devices. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.LocalLinks Links { | |||
get { return links_; } | |||
set { | |||
@@ -437,11 +586,13 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DeviceLocality); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DeviceLocality other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -456,6 +607,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (BusId != 0) hash ^= BusId.GetHashCode(); | |||
@@ -468,12 +620,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (BusId != 0) { | |||
output.WriteRawTag(8); | |||
output.WriteInt32(BusId); | |||
@@ -489,9 +646,33 @@ namespace Tensorflow { | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (BusId != 0) { | |||
output.WriteRawTag(8); | |||
output.WriteInt32(BusId); | |||
} | |||
if (NumaNode != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(NumaNode); | |||
} | |||
if (links_ != null) { | |||
output.WriteRawTag(26); | |||
output.WriteMessage(Links); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (BusId != 0) { | |||
@@ -510,6 +691,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DeviceLocality other) { | |||
if (other == null) { | |||
return; | |||
@@ -530,7 +712,11 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -554,27 +740,66 @@ namespace Tensorflow { | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 8: { | |||
BusId = input.ReadInt32(); | |||
break; | |||
} | |||
case 16: { | |||
NumaNode = input.ReadInt32(); | |||
break; | |||
} | |||
case 26: { | |||
if (links_ == null) { | |||
Links = new global::Tensorflow.LocalLinks(); | |||
} | |||
input.ReadMessage(Links); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
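Usage sketch (not generated code) for the locality messages above; the types and fields all appear in this file, and the literal values are purely illustrative.
using Tensorflow;
static class DeviceLocalityExample
{
    public static DeviceLocality Build()
    {
        var locality = new DeviceLocality
        {
            BusId = 1,      // 0 means "no specific locality"; specific buses are indexed from 1
            NumaNode = 0,
            Links = new LocalLinks(),
        };
        // LocalLinks.Link is a repeated message field, so entries are added rather than assigned.
        locality.Links.Link.Add(new InterconnectLink { DeviceId = 2, Type = "NVLink", Strength = 100 });
        return locality;
    }
}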
public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> { | |||
public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<DeviceAttributes> _parser = new pb::MessageParser<DeviceAttributes>(() => new DeviceAttributes()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<DeviceAttributes> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceAttributes() { | |||
OnConstruction(); | |||
} | |||
@@ -582,6 +807,7 @@ namespace Tensorflow { | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceAttributes(DeviceAttributes other) : this() { | |||
name_ = other.name_; | |||
deviceType_ = other.deviceType_; | |||
@@ -589,10 +815,12 @@ namespace Tensorflow { | |||
locality_ = other.locality_ != null ? other.locality_.Clone() : null; | |||
incarnation_ = other.incarnation_; | |||
physicalDeviceDesc_ = other.physicalDeviceDesc_; | |||
xlaGlobalId_ = other.xlaGlobalId_; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public DeviceAttributes Clone() { | |||
return new DeviceAttributes(this); | |||
} | |||
@@ -604,6 +832,7 @@ namespace Tensorflow { | |||
/// Fully specified name of the device within a cluster. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string Name { | |||
get { return name_; } | |||
set { | |||
@@ -618,6 +847,7 @@ namespace Tensorflow { | |||
/// String representation of device_type. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string DeviceType { | |||
get { return deviceType_; } | |||
set { | |||
@@ -632,6 +862,7 @@ namespace Tensorflow { | |||
/// Memory capacity of device in bytes. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long MemoryLimit { | |||
get { return memoryLimit_; } | |||
set { | |||
@@ -647,6 +878,7 @@ namespace Tensorflow { | |||
/// for supporting efficient data transfers. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Tensorflow.DeviceLocality Locality { | |||
get { return locality_; } | |||
set { | |||
@@ -662,6 +894,7 @@ namespace Tensorflow { | |||
/// initialized. "incarnation" should never be 0. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public ulong Incarnation { | |||
get { return incarnation_; } | |||
set { | |||
@@ -676,6 +909,7 @@ namespace Tensorflow { | |||
/// String representation of the physical device that this device maps to. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public string PhysicalDeviceDesc { | |||
get { return physicalDeviceDesc_; } | |||
set { | |||
@@ -683,12 +917,31 @@ namespace Tensorflow { | |||
} | |||
} | |||
/// <summary>Field number for the "xla_global_id" field.</summary> | |||
public const int XlaGlobalIdFieldNumber = 8; | |||
private long xlaGlobalId_; | |||
/// <summary> | |||
/// A physical device ID for use in XLA DeviceAssignments, unique across | |||
/// clients in a multi-client setup. Set to -1 if unavailable, non-negative | |||
/// otherwise. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public long XlaGlobalId { | |||
get { return xlaGlobalId_; } | |||
set { | |||
xlaGlobalId_ = value; | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as DeviceAttributes); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(DeviceAttributes other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
@@ -702,10 +955,12 @@ namespace Tensorflow { | |||
if (!object.Equals(Locality, other.Locality)) return false; | |||
if (Incarnation != other.Incarnation) return false; | |||
if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false; | |||
if (XlaGlobalId != other.XlaGlobalId) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
@@ -714,6 +969,7 @@ namespace Tensorflow { | |||
if (locality_ != null) hash ^= Locality.GetHashCode(); | |||
if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); | |||
if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode(); | |||
if (XlaGlobalId != 0L) hash ^= XlaGlobalId.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
@@ -721,12 +977,17 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
@@ -751,12 +1012,56 @@ namespace Tensorflow { | |||
output.WriteRawTag(58); | |||
output.WriteString(PhysicalDeviceDesc); | |||
} | |||
if (XlaGlobalId != 0L) { | |||
output.WriteRawTag(64); | |||
output.WriteInt64(XlaGlobalId); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (Name.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(Name); | |||
} | |||
if (DeviceType.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(DeviceType); | |||
} | |||
if (MemoryLimit != 0L) { | |||
output.WriteRawTag(32); | |||
output.WriteInt64(MemoryLimit); | |||
} | |||
if (locality_ != null) { | |||
output.WriteRawTag(42); | |||
output.WriteMessage(Locality); | |||
} | |||
if (Incarnation != 0UL) { | |||
output.WriteRawTag(49); | |||
output.WriteFixed64(Incarnation); | |||
} | |||
if (PhysicalDeviceDesc.Length != 0) { | |||
output.WriteRawTag(58); | |||
output.WriteString(PhysicalDeviceDesc); | |||
} | |||
if (XlaGlobalId != 0L) { | |||
output.WriteRawTag(64); | |||
output.WriteInt64(XlaGlobalId); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Name.Length != 0) { | |||
@@ -777,6 +1082,9 @@ namespace Tensorflow { | |||
if (PhysicalDeviceDesc.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc); | |||
} | |||
if (XlaGlobalId != 0L) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaGlobalId); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
@@ -784,6 +1092,7 @@ namespace Tensorflow { | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(DeviceAttributes other) { | |||
if (other == null) { | |||
return; | |||
@@ -809,11 +1118,18 @@ namespace Tensorflow { | |||
if (other.PhysicalDeviceDesc.Length != 0) { | |||
PhysicalDeviceDesc = other.PhysicalDeviceDesc; | |||
} | |||
if (other.XlaGlobalId != 0L) { | |||
XlaGlobalId = other.XlaGlobalId; | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
@@ -847,9 +1163,60 @@ namespace Tensorflow { | |||
PhysicalDeviceDesc = input.ReadString(); | |||
break; | |||
} | |||
case 64: { | |||
XlaGlobalId = input.ReadInt64(); | |||
break; | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
Name = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
DeviceType = input.ReadString(); | |||
break; | |||
} | |||
case 32: { | |||
MemoryLimit = input.ReadInt64(); | |||
break; | |||
} | |||
case 42: { | |||
if (locality_ == null) { | |||
Locality = new global::Tensorflow.DeviceLocality(); | |||
} | |||
input.ReadMessage(Locality); | |||
break; | |||
} | |||
case 49: { | |||
Incarnation = input.ReadFixed64(); | |||
break; | |||
} | |||
case 58: { | |||
PhysicalDeviceDesc = input.ReadString(); | |||
break; | |||
} | |||
case 64: { | |||
XlaGlobalId = input.ReadInt64(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
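Usage sketch (not generated code) for the new xla_global_id field added in this diff: as a proto3 int64 it defaults to 0 and is omitted from the wire at that value, and the field comment reserves -1 for "unavailable"; when set it is written with raw tag 64 (field 8, varint). ToByteArray and Parser.ParseFrom are the standard Google.Protobuf helpers.
using Google.Protobuf;
using Tensorflow;
static class DeviceAttributesExample
{
    public static void RoundTrip()
    {
        var attrs = new DeviceAttributes
        {
            Name = "/job:localhost/replica:0/task:0/device:CPU:0",
            DeviceType = "CPU",
            XlaGlobalId = -1,   // serialized because it differs from the proto3 default of 0
        };
        byte[] wire = attrs.ToByteArray();
        var parsed = DeviceAttributes.Parser.ParseFrom(wire);
        System.Diagnostics.Debug.Assert(parsed.XlaGlobalId == -1);
    }
}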
@@ -0,0 +1,340 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/compiler/xla/service/cpu/executable.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021, 8981 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Xla.Cpu { | |||
/// <summary>Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||
public static partial class ExecutableReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static ExecutableReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"CjR0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS9leGVjdXRh", | |||
"YmxlLnByb3RvEgd4bGEuY3B1Gjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9z", | |||
"ZXJ2aWNlL2NwdS94bGFfZnJhbWV3b3JrLnByb3RvGil0ZW5zb3JmbG93L2Nv", | |||
"bXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90byLXAQocWGxhUnVudGltZUNw", | |||
"dUV4ZWN1dGFibGVQcm90bxI+ChZ4bGFfcnVudGltZV9leGVjdXRhYmxlGAEg", | |||
"ASgLMh4ueGxhLlhsYVJ1bnRpbWVFeGVjdXRhYmxlUHJvdG8SQAoVeGxhX2Zy", | |||
"YW1ld29ya19tYXBwaW5nGAIgASgLMiEueGxhLmNwdS5YbGFGcmFtZXdvcmtN", | |||
"YXBwaW5nUHJvdG8SNQoRYnVmZmVyX2Fzc2lnbm1lbnQYAyABKAsyGi54bGEu", | |||
"QnVmZmVyQXNzaWdubWVudFByb3Rv")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { global::Xla.Cpu.XlaFrameworkReflection.Descriptor, global::Xla.HloReflection.Descriptor, }, | |||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaRuntimeCpuExecutableProto), global::Xla.Cpu.XlaRuntimeCpuExecutableProto.Parser, new[]{ "XlaRuntimeExecutable", "XlaFrameworkMapping", "BufferAssignment" }, null, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
public sealed partial class XlaRuntimeCpuExecutableProto : pb::IMessage<XlaRuntimeCpuExecutableProto> | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
, pb::IBufferMessage | |||
#endif | |||
{ | |||
private static readonly pb::MessageParser<XlaRuntimeCpuExecutableProto> _parser = new pb::MessageParser<XlaRuntimeCpuExecutableProto>(() => new XlaRuntimeCpuExecutableProto()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pb::MessageParser<XlaRuntimeCpuExecutableProto> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Xla.Cpu.ExecutableReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public XlaRuntimeCpuExecutableProto() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public XlaRuntimeCpuExecutableProto(XlaRuntimeCpuExecutableProto other) : this() { | |||
xlaRuntimeExecutable_ = other.xlaRuntimeExecutable_ != null ? other.xlaRuntimeExecutable_.Clone() : null; | |||
xlaFrameworkMapping_ = other.xlaFrameworkMapping_ != null ? other.xlaFrameworkMapping_.Clone() : null; | |||
bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public XlaRuntimeCpuExecutableProto Clone() { | |||
return new XlaRuntimeCpuExecutableProto(this); | |||
} | |||
/// <summary>Field number for the "xla_runtime_executable" field.</summary> | |||
public const int XlaRuntimeExecutableFieldNumber = 1; | |||
private global::Xla.XlaRuntimeExecutableProto xlaRuntimeExecutable_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Xla.XlaRuntimeExecutableProto XlaRuntimeExecutable { | |||
get { return xlaRuntimeExecutable_; } | |||
set { | |||
xlaRuntimeExecutable_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "xla_framework_mapping" field.</summary> | |||
public const int XlaFrameworkMappingFieldNumber = 2; | |||
private global::Xla.Cpu.XlaFrameworkMappingProto xlaFrameworkMapping_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Xla.Cpu.XlaFrameworkMappingProto XlaFrameworkMapping { | |||
get { return xlaFrameworkMapping_; } | |||
set { | |||
xlaFrameworkMapping_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "buffer_assignment" field.</summary> | |||
public const int BufferAssignmentFieldNumber = 3; | |||
private global::Xla.BufferAssignmentProto bufferAssignment_; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public global::Xla.BufferAssignmentProto BufferAssignment { | |||
get { return bufferAssignment_; } | |||
set { | |||
bufferAssignment_ = value; | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override bool Equals(object other) { | |||
return Equals(other as XlaRuntimeCpuExecutableProto); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public bool Equals(XlaRuntimeCpuExecutableProto other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (!object.Equals(XlaRuntimeExecutable, other.XlaRuntimeExecutable)) return false; | |||
if (!object.Equals(XlaFrameworkMapping, other.XlaFrameworkMapping)) return false; | |||
if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (xlaRuntimeExecutable_ != null) hash ^= XlaRuntimeExecutable.GetHashCode(); | |||
if (xlaFrameworkMapping_ != null) hash ^= XlaFrameworkMapping.GetHashCode(); | |||
if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
output.WriteRawMessage(this); | |||
#else | |||
if (xlaRuntimeExecutable_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(XlaRuntimeExecutable); | |||
} | |||
if (xlaFrameworkMapping_ != null) { | |||
output.WriteRawTag(18); | |||
output.WriteMessage(XlaFrameworkMapping); | |||
} | |||
if (bufferAssignment_ != null) { | |||
output.WriteRawTag(26); | |||
output.WriteMessage(BufferAssignment); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
if (xlaRuntimeExecutable_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(XlaRuntimeExecutable); | |||
} | |||
if (xlaFrameworkMapping_ != null) { | |||
output.WriteRawTag(18); | |||
output.WriteMessage(XlaFrameworkMapping); | |||
} | |||
if (bufferAssignment_ != null) { | |||
output.WriteRawTag(26); | |||
output.WriteMessage(BufferAssignment); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(ref output); | |||
} | |||
} | |||
#endif | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (xlaRuntimeExecutable_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaRuntimeExecutable); | |||
} | |||
if (xlaFrameworkMapping_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaFrameworkMapping); | |||
} | |||
if (bufferAssignment_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(XlaRuntimeCpuExecutableProto other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.xlaRuntimeExecutable_ != null) { | |||
if (xlaRuntimeExecutable_ == null) { | |||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
} | |||
XlaRuntimeExecutable.MergeFrom(other.XlaRuntimeExecutable); | |||
} | |||
if (other.xlaFrameworkMapping_ != null) { | |||
if (xlaFrameworkMapping_ == null) { | |||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
} | |||
XlaFrameworkMapping.MergeFrom(other.XlaFrameworkMapping); | |||
} | |||
if (other.bufferAssignment_ != null) { | |||
if (bufferAssignment_ == null) { | |||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
} | |||
BufferAssignment.MergeFrom(other.BufferAssignment); | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
input.ReadRawMessage(this); | |||
#else | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
if (xlaRuntimeExecutable_ == null) { | |||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
} | |||
input.ReadMessage(XlaRuntimeExecutable); | |||
break; | |||
} | |||
case 18: { | |||
if (xlaFrameworkMapping_ == null) { | |||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
} | |||
input.ReadMessage(XlaFrameworkMapping); | |||
break; | |||
} | |||
case 26: { | |||
if (bufferAssignment_ == null) { | |||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
} | |||
input.ReadMessage(BufferAssignment); | |||
break; | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
[global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
break; | |||
case 10: { | |||
if (xlaRuntimeExecutable_ == null) { | |||
XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
} | |||
input.ReadMessage(XlaRuntimeExecutable); | |||
break; | |||
} | |||
case 18: { | |||
if (xlaFrameworkMapping_ == null) { | |||
XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
} | |||
input.ReadMessage(XlaFrameworkMapping); | |||
break; | |||
} | |||
case 26: { | |||
if (bufferAssignment_ == null) { | |||
BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
} | |||
input.ReadMessage(BufferAssignment); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#endif | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
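Usage sketch (not generated code): loading a serialized XlaRuntimeCpuExecutableProto from a caller-supplied path; the path and file are hypothetical, and only the message type and its parser come from the file above.
using Xla.Cpu;
static class XlaCpuExecutableLoader
{
    public static XlaRuntimeCpuExecutableProto Load(string path)
    {
        // ParseFrom(Stream) is provided by Google.Protobuf's MessageParser<T>.
        using var stream = System.IO.File.OpenRead(path);
        return XlaRuntimeCpuExecutableProto.Parser.ParseFrom(stream);
    }
}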