diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 4605e37a..36b74e5c 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -34,6 +34,14 @@ namespace Tensorflow _handle = handle; } + public Buffer(byte[] data) + { + var dst = Marshal.AllocHGlobal(data.Length); + Marshal.Copy(data, 0, dst, data.Length); + + _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); + } + public static implicit operator IntPtr(Buffer buffer) { return buffer._handle; diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index d9792f12..08c71887 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -19,5 +19,15 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); + + /// + /// Makes a copy of the input and sets an appropriate deallocator. Useful for + /// passing in read-only, input protobufs. + /// + /// const void* + /// size_t + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_NewBufferFromString(IntPtr proto, ulong proto_len); } } diff --git a/src/TensorFlowNET.Core/Exceptions/KeyError.cs b/src/TensorFlowNET.Core/Exceptions/KeyError.cs new file mode 100644 index 00000000..0dfa2f97 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/KeyError.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class KeyError : Exception + { + public KeyError() : base() + { + + } + + public KeyError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.py.cs b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs index 2ff54c0e..94c0dac0 100644 --- a/src/TensorFlowNET.Core/Framework/c_api_util.py.cs +++ b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs @@ -10,13 +10,25 @@ namespace Tensorflow public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); - public static IntPtr tf_buffer(byte[] data) + public static Buffer tf_buffer(byte[] data) => new Buffer(data); + + public static IEnumerable new_tf_operations(Graph graph) + { + foreach (var c_op in tf_operations(graph)) + { + if (graph._get_operation_by_tf_operation(c_op) == null) + yield return c_op; + } + } + + public static IEnumerable tf_operations(Graph graph) { - if (data != null) - throw new NotImplementedException(""); - // var buf = c_api.TF_NewBufferFromString(data); - else - throw new NotImplementedException(""); + uint pos = 0; + IntPtr c_op; + while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) + { + yield return c_op; + } } } } diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs index 2fdc8985..e8c971c7 100644 --- a/src/TensorFlowNET.Core/Framework/importer.py.cs +++ b/src/TensorFlowNET.Core/Framework/importer.py.cs @@ -42,11 +42,27 @@ namespace Tensorflow _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); var bytes = graph_def.ToByteString().ToArray(); + IntPtr buffer = c_api_util.tf_buffer(bytes); var status = new Status(); - c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status); + // need to create a class ImportGraphDefWithResults with IDisposal + var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); + status.Check(true); - throw new NotImplementedException("importer.import_graph_def"); + _ProcessNewOps(graph); + + if (return_elements == null) + return null; + else + throw new NotImplementedException("import_graph_def return_elements"); + } + + private static void _ProcessNewOps(Graph graph) + { + foreach(var new_op in graph._add_new_tf_operations()) + { + var original_device = new_op.Device; + } } public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index 6e8d354a..060c267a 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using static Tensorflow.CollectionDef; using static Tensorflow.MetaGraphDef.Types; namespace Tensorflow @@ -16,7 +17,7 @@ namespace Tensorflow return meta_graph_def; } - public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, + public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, bool clear_devices = false, string import_scope = "", Dictionary input_map = null, @@ -51,7 +52,7 @@ namespace Tensorflow node.Device = ""; var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false); - importer.import_graph_def(input_graph_def, + var imported_return_elements = importer.import_graph_def(input_graph_def, name: scope_to_prepend_to_names, input_map: input_map, producer_op_list: producer_op_list, @@ -59,7 +60,41 @@ namespace Tensorflow // Restores all the other collections. var variable_objects = new Dictionary(); + foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) + { + // Don't add unbound_inputs to the new graph. + if (col.Key == unbound_inputs_col_name) + continue; + + switch (col.Value.KindCase) + { + case KindOneofCase.NodeList: + foreach(var value in col.Value.NodeList.Value) + { + var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); + graph.add_to_collection(col.Key, col_op); + } + break; + case KindOneofCase.BytesList: + //var proto_type = ops.get_collection_proto_type(key) + if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) + { + foreach (var value in col.Value.BytesList.Value) + { + var proto = VariableDef.Parser.ParseFrom(value); + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + } + else + { + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + + break; + } + } + return (null, null); } /// diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index e2ff80e2..2f16b880 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -1,6 +1,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -22,5 +23,57 @@ namespace Tensorflow { return c_api.TF_NewOperation(_handle, opType, opName); } + + public ITensorOrOperation _get_operation_by_name_unsafe(string name) + { + return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; + } + + public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) + { + var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper)); + return _get_operation_by_name_unsafe(op_name); + } + + public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) + { + var ret = new Operation(c_op); + + var name_key = ret.name.ToLower(); + if (!_names_in_use.ContainsKey(name_key)) + _names_in_use[name_key] = 1; + + _create_op_helper(ret, compute_device: compute_device); + + return ret; + } + + /// + /// Creates `Operations` in this graph for any new TF_Operations. + /// + /// This is useful for when TF_Operations are indirectly created by the C API + /// outside of the Operation constructor (e.g. by TF_ImportGraphDef, + /// TF_FinishWhile). This ensures there are corresponding Operations for all + /// TF_Operations in the underlying TF_Graph. + /// + /// + /// + public IEnumerable _add_new_tf_operations(bool compute_devices = true) + { + var new_ops = c_api_util.new_tf_operations(this) + .Select(c_op => _create_op_from_tf_operation(c_op, compute_device: compute_devices)) + .ToArray(); + + foreach(var op in new_ops) + { + var new_control_inputs = _control_dependencies_for_inputs(op.inputs) + .Select(x => x as Operation) + .ToArray(); + op._add_control_inputs(new_control_inputs); + op._control_flow_post_processing(); + } + + return new_ops; + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 8fba13d5..b281debf 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -86,6 +86,16 @@ namespace Tensorflow if (_nodes_by_name.ContainsKey(op_name)) return _nodes_by_name[op_name].outputs[out_n]; } + else if(!name.Contains(":") & allow_operation) + { + if (!_nodes_by_name.ContainsKey(name)) + throw new KeyError($"The name {name} refers to an Operation not in the graph."); + return _nodes_by_name[name]; + } + else if (!name.Contains(":") & !allow_operation) + { + throw new NotImplementedException("_as_graph_element_locked"); + } } if (obj is Tensor tensor && allow_tensor) @@ -101,7 +111,7 @@ namespace Tensorflow } else if (obj is Operation op && allow_operation) { - if (op.Graph.Equals(this)) + if (op.graph.Equals(this)) { return op; } diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs index 4f387120..9abe7303 100644 --- a/src/TensorFlowNET.Core/Operations/InputList.cs +++ b/src/TensorFlowNET.Core/Operations/InputList.cs @@ -26,5 +26,10 @@ namespace Tensorflow { return input._inputs.ToList(); } + + public static implicit operator Tensor[](InputList input) + { + return input._inputs; + } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 5599ad2b..a51d1ca9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -16,5 +16,13 @@ namespace Tensorflow } } + + public void _add_control_inputs(Operation[] ops) + { + foreach(var op in ops) + { + c_api.TF_AddControlInput(graph, op); + } + } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 00a9241b..45a57286 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -11,7 +11,7 @@ namespace Tensorflow { private readonly IntPtr _handle; // _c_op in python - public Graph Graph { get; } + public Graph graph { get; } public int _id => _id_value; private int _id_value; @@ -42,15 +42,17 @@ namespace Tensorflow return; _handle = handle; - this.Graph = ops.get_default_graph(); + this.graph = ops.get_default_graph(); _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) _outputs[i] = new Tensor(this, i, OutputType(i)); + + graph._add_op(this); } public Operation(Graph g, string opType, string oper_name) { - Graph = g; + graph = g; var desc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); @@ -78,7 +80,7 @@ namespace Tensorflow /// public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { - Graph = g; + graph = g; // Build the list of control inputs. var control_input_ops = new List(); @@ -99,7 +101,7 @@ namespace Tensorflow // This will be set by self.inputs. - _id_value = Graph._next_id(); + _id_value = graph._next_id(); if(op_def == null) op_def = g.GetOpDef(node_def.Op); @@ -115,7 +117,7 @@ namespace Tensorflow for (int i = 0; i < NumOutputs; i++) _outputs[i] = new Tensor(this, i, OutputType(i)); - Graph._add_op(this); + graph._add_op(this); if (_handle != IntPtr.Zero) _control_flow_post_processing(); @@ -123,7 +125,7 @@ namespace Tensorflow public void run(FeedItem[] feed_dict = null, Session session = null) { - ops._run_using_default_session(this, feed_dict, Graph, session); + ops._run_using_default_session(this, feed_dict, graph, session); } private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs) diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index 64f6e813..662c628e 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -3,6 +3,8 @@ set SRC_DIR=D:\Projects\tensorflow set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf +cd tensorflow + protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto @@ -12,6 +14,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.pr protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\variable.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto diff --git a/src/TensorFlowNET.Core/Protobuf/Variable.cs b/src/TensorFlowNET.Core/Protobuf/Variable.cs new file mode 100644 index 00000000..18cdd6d2 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Variable.cs @@ -0,0 +1,584 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/variable.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/variable.proto + public static partial class VariableReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/variable.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VariableReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZhcmlhYmxlLnByb3RvEgp0", + "ZW5zb3JmbG93ItQBCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg", + "ASgJEhoKEmluaXRpYWxfdmFsdWVfbmFtZRgGIAEoCRIYChBpbml0aWFsaXpl", + "cl9uYW1lGAIgASgJEhUKDXNuYXBzaG90X25hbWUYAyABKAkSOQoTc2F2ZV9z", + "bGljZV9pbmZvX2RlZhgEIAEoCzIcLnRlbnNvcmZsb3cuU2F2ZVNsaWNlSW5m", + "b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgi", + "YAoQU2F2ZVNsaWNlSW5mb0RlZhIRCglmdWxsX25hbWUYASABKAkSEgoKZnVs", + "bF9zaGFwZRgCIAMoAxISCgp2YXJfb2Zmc2V0GAMgAygDEhEKCXZhcl9zaGFw", + "ZRgEIAMoA0JuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDlZhcmlhYmxl", + "UHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3Rl", + "bnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveSliceInfoDef), global::Tensorflow.SaveSliceInfoDef.Parser, new[]{ "FullName", "FullShape", "VarOffset", "VarShape" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing a Variable. + /// + public sealed partial class VariableDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VariableDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariableDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariableDef(VariableDef other) : this() { + variableName_ = other.variableName_; + initialValueName_ = other.initialValueName_; + initializerName_ = other.initializerName_; + snapshotName_ = other.snapshotName_; + saveSliceInfoDef_ = other.saveSliceInfoDef_ != null ? other.saveSliceInfoDef_.Clone() : null; + isResource_ = other.isResource_; + trainable_ = other.trainable_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariableDef Clone() { + return new VariableDef(this); + } + + /// Field number for the "variable_name" field. + public const int VariableNameFieldNumber = 1; + private string variableName_ = ""; + /// + /// Name of the variable tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string VariableName { + get { return variableName_; } + set { + variableName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "initial_value_name" field. + public const int InitialValueNameFieldNumber = 6; + private string initialValueName_ = ""; + /// + /// Name of the tensor holding the variable's initial value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string InitialValueName { + get { return initialValueName_; } + set { + initialValueName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "initializer_name" field. + public const int InitializerNameFieldNumber = 2; + private string initializerName_ = ""; + /// + /// Name of the initializer op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string InitializerName { + get { return initializerName_; } + set { + initializerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "snapshot_name" field. + public const int SnapshotNameFieldNumber = 3; + private string snapshotName_ = ""; + /// + /// Name of the snapshot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string SnapshotName { + get { return snapshotName_; } + set { + snapshotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "save_slice_info_def" field. + public const int SaveSliceInfoDefFieldNumber = 4; + private global::Tensorflow.SaveSliceInfoDef saveSliceInfoDef_; + /// + /// Support for saving variables as slices of a larger variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.SaveSliceInfoDef SaveSliceInfoDef { + get { return saveSliceInfoDef_; } + set { + saveSliceInfoDef_ = value; + } + } + + /// Field number for the "is_resource" field. + public const int IsResourceFieldNumber = 5; + private bool isResource_; + /// + /// Whether to represent this as a ResourceVariable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsResource { + get { return isResource_; } + set { + isResource_ = value; + } + } + + /// Field number for the "trainable" field. + public const int TrainableFieldNumber = 7; + private bool trainable_; + /// + /// Whether this variable should be trained. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Trainable { + get { return trainable_; } + set { + trainable_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as VariableDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(VariableDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (VariableName != other.VariableName) return false; + if (InitialValueName != other.InitialValueName) return false; + if (InitializerName != other.InitializerName) return false; + if (SnapshotName != other.SnapshotName) return false; + if (!object.Equals(SaveSliceInfoDef, other.SaveSliceInfoDef)) return false; + if (IsResource != other.IsResource) return false; + if (Trainable != other.Trainable) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (VariableName.Length != 0) hash ^= VariableName.GetHashCode(); + if (InitialValueName.Length != 0) hash ^= InitialValueName.GetHashCode(); + if (InitializerName.Length != 0) hash ^= InitializerName.GetHashCode(); + if (SnapshotName.Length != 0) hash ^= SnapshotName.GetHashCode(); + if (saveSliceInfoDef_ != null) hash ^= SaveSliceInfoDef.GetHashCode(); + if (IsResource != false) hash ^= IsResource.GetHashCode(); + if (Trainable != false) hash ^= Trainable.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (VariableName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(VariableName); + } + if (InitializerName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(InitializerName); + } + if (SnapshotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SnapshotName); + } + if (saveSliceInfoDef_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SaveSliceInfoDef); + } + if (IsResource != false) { + output.WriteRawTag(40); + output.WriteBool(IsResource); + } + if (InitialValueName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(InitialValueName); + } + if (Trainable != false) { + output.WriteRawTag(56); + output.WriteBool(Trainable); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (VariableName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(VariableName); + } + if (InitialValueName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InitialValueName); + } + if (InitializerName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InitializerName); + } + if (SnapshotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SnapshotName); + } + if (saveSliceInfoDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaveSliceInfoDef); + } + if (IsResource != false) { + size += 1 + 1; + } + if (Trainable != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(VariableDef other) { + if (other == null) { + return; + } + if (other.VariableName.Length != 0) { + VariableName = other.VariableName; + } + if (other.InitialValueName.Length != 0) { + InitialValueName = other.InitialValueName; + } + if (other.InitializerName.Length != 0) { + InitializerName = other.InitializerName; + } + if (other.SnapshotName.Length != 0) { + SnapshotName = other.SnapshotName; + } + if (other.saveSliceInfoDef_ != null) { + if (saveSliceInfoDef_ == null) { + saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef(); + } + SaveSliceInfoDef.MergeFrom(other.SaveSliceInfoDef); + } + if (other.IsResource != false) { + IsResource = other.IsResource; + } + if (other.Trainable != false) { + Trainable = other.Trainable; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + VariableName = input.ReadString(); + break; + } + case 18: { + InitializerName = input.ReadString(); + break; + } + case 26: { + SnapshotName = input.ReadString(); + break; + } + case 34: { + if (saveSliceInfoDef_ == null) { + saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef(); + } + input.ReadMessage(saveSliceInfoDef_); + break; + } + case 40: { + IsResource = input.ReadBool(); + break; + } + case 50: { + InitialValueName = input.ReadString(); + break; + } + case 56: { + Trainable = input.ReadBool(); + break; + } + } + } + } + + } + + public sealed partial class SaveSliceInfoDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaveSliceInfoDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaveSliceInfoDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaveSliceInfoDef(SaveSliceInfoDef other) : this() { + fullName_ = other.fullName_; + fullShape_ = other.fullShape_.Clone(); + varOffset_ = other.varOffset_.Clone(); + varShape_ = other.varShape_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaveSliceInfoDef Clone() { + return new SaveSliceInfoDef(this); + } + + /// Field number for the "full_name" field. + public const int FullNameFieldNumber = 1; + private string fullName_ = ""; + /// + /// Name of the full variable of which this is a slice. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string FullName { + get { return fullName_; } + set { + fullName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "full_shape" field. + public const int FullShapeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_fullShape_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField fullShape_ = new pbc::RepeatedField(); + /// + /// Shape of the full variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField FullShape { + get { return fullShape_; } + } + + /// Field number for the "var_offset" field. + public const int VarOffsetFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_varOffset_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField varOffset_ = new pbc::RepeatedField(); + /// + /// Offset of this variable into the full variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VarOffset { + get { return varOffset_; } + } + + /// Field number for the "var_shape" field. + public const int VarShapeFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_varShape_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField varShape_ = new pbc::RepeatedField(); + /// + /// Shape of this variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VarShape { + get { return varShape_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SaveSliceInfoDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SaveSliceInfoDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FullName != other.FullName) return false; + if(!fullShape_.Equals(other.fullShape_)) return false; + if(!varOffset_.Equals(other.varOffset_)) return false; + if(!varShape_.Equals(other.varShape_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (FullName.Length != 0) hash ^= FullName.GetHashCode(); + hash ^= fullShape_.GetHashCode(); + hash ^= varOffset_.GetHashCode(); + hash ^= varShape_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (FullName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FullName); + } + fullShape_.WriteTo(output, _repeated_fullShape_codec); + varOffset_.WriteTo(output, _repeated_varOffset_codec); + varShape_.WriteTo(output, _repeated_varShape_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (FullName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FullName); + } + size += fullShape_.CalculateSize(_repeated_fullShape_codec); + size += varOffset_.CalculateSize(_repeated_varOffset_codec); + size += varShape_.CalculateSize(_repeated_varShape_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SaveSliceInfoDef other) { + if (other == null) { + return; + } + if (other.FullName.Length != 0) { + FullName = other.FullName; + } + fullShape_.Add(other.fullShape_); + varOffset_.Add(other.varOffset_); + varShape_.Add(other.varShape_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FullName = input.ReadString(); + break; + } + case 18: + case 16: { + fullShape_.AddEntriesFrom(input, _repeated_fullShape_codec); + break; + } + case 26: + case 24: { + varOffset_.AddEntriesFrom(input, _repeated_varOffset_codec); + break; + } + case 34: + case 32: { + varShape_.AddEntriesFrom(input, _repeated_varShape_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 96111cd9..d834b608 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -19,7 +19,7 @@ namespace Tensorflow private int _id; public int Id => _id; - public Graph Graph => op?.Graph; + public Graph Graph => op?.graph; public Operation op { get; } public Tensor[] outputs => op.outputs; @@ -48,7 +48,7 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { - c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status); + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); status.Check(); } else @@ -84,7 +84,7 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { var output = _as_tf_output(); - return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status); + return c_api.TF_GraphGetTensorNumDims(op.graph, output, status); } else { diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 78e25bd8..4f918b2a 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -28,6 +28,7 @@ namespace Tensorflow /// public static string GLOBAL_VARIABLES = "variables"; + public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" }; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 74575cff..b7ccfe86 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -287,6 +287,25 @@ namespace Tensorflow return tf.defaultSession; } + /// + /// Prepends name scope to a name. + /// + /// + /// + /// + public static string prepend_name_scope(string name, string import_scope) + { + if (!string.IsNullOrEmpty(import_scope)) + { + if (import_scope.EndsWith("/")) + import_scope = import_scope.Substring(0, import_scope.Length - 1); + + throw new NotImplementedException("prepend_name_scope"); + } + else + return name; + } + public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) { if (session == null)