diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs
index 4605e37a..36b74e5c 100644
--- a/src/TensorFlowNET.Core/Buffers/Buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs
@@ -34,6 +34,14 @@ namespace Tensorflow
_handle = handle;
}
+ public Buffer(byte[] data)
+ {
+ var dst = Marshal.AllocHGlobal(data.Length);
+ Marshal.Copy(data, 0, dst, data.Length);
+
+ _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length);
+ }
+
public static implicit operator IntPtr(Buffer buffer)
{
return buffer._handle;
diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
index d9792f12..08c71887 100644
--- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
@@ -19,5 +19,15 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
+
+ ///
+ /// Makes a copy of the input and sets an appropriate deallocator. Useful for
+ /// passing in read-only, input protobufs.
+ ///
+ /// const void*
+ /// size_t
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TF_NewBufferFromString(IntPtr proto, ulong proto_len);
}
}
diff --git a/src/TensorFlowNET.Core/Exceptions/KeyError.cs b/src/TensorFlowNET.Core/Exceptions/KeyError.cs
new file mode 100644
index 00000000..0dfa2f97
--- /dev/null
+++ b/src/TensorFlowNET.Core/Exceptions/KeyError.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class KeyError : Exception
+ {
+ public KeyError() : base()
+ {
+
+ }
+
+ public KeyError(string message) : base(message)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.py.cs b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs
index 2ff54c0e..94c0dac0 100644
--- a/src/TensorFlowNET.Core/Framework/c_api_util.py.cs
+++ b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs
@@ -10,13 +10,25 @@ namespace Tensorflow
public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();
- public static IntPtr tf_buffer(byte[] data)
+ public static Buffer tf_buffer(byte[] data) => new Buffer(data);
+
+ public static IEnumerable new_tf_operations(Graph graph)
+ {
+ foreach (var c_op in tf_operations(graph))
+ {
+ if (graph._get_operation_by_tf_operation(c_op) == null)
+ yield return c_op;
+ }
+ }
+
+ public static IEnumerable tf_operations(Graph graph)
{
- if (data != null)
- throw new NotImplementedException("");
- // var buf = c_api.TF_NewBufferFromString(data);
- else
- throw new NotImplementedException("");
+ uint pos = 0;
+ IntPtr c_op;
+ while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
+ {
+ yield return c_op;
+ }
}
}
}
diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs
index 2fdc8985..e8c971c7 100644
--- a/src/TensorFlowNET.Core/Framework/importer.py.cs
+++ b/src/TensorFlowNET.Core/Framework/importer.py.cs
@@ -42,11 +42,27 @@ namespace Tensorflow
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
var bytes = graph_def.ToByteString().ToArray();
+ IntPtr buffer = c_api_util.tf_buffer(bytes);
var status = new Status();
- c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status);
+ // need to create a class ImportGraphDefWithResults with IDisposal
+ var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
+ status.Check(true);
- throw new NotImplementedException("importer.import_graph_def");
+ _ProcessNewOps(graph);
+
+ if (return_elements == null)
+ return null;
+ else
+ throw new NotImplementedException("import_graph_def return_elements");
+ }
+
+ private static void _ProcessNewOps(Graph graph)
+ {
+ foreach(var new_op in graph._add_new_tf_operations())
+ {
+ var original_device = new_op.Device;
+ }
}
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
index 6e8d354a..060c267a 100644
--- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
+++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
+using static Tensorflow.CollectionDef;
using static Tensorflow.MetaGraphDef.Types;
namespace Tensorflow
@@ -16,7 +17,7 @@ namespace Tensorflow
return meta_graph_def;
}
- public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
+ public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
Dictionary input_map = null,
@@ -51,7 +52,7 @@ namespace Tensorflow
node.Device = "";
var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
- importer.import_graph_def(input_graph_def,
+ var imported_return_elements = importer.import_graph_def(input_graph_def,
name: scope_to_prepend_to_names,
input_map: input_map,
producer_op_list: producer_op_list,
@@ -59,7 +60,41 @@ namespace Tensorflow
// Restores all the other collections.
var variable_objects = new Dictionary();
+ foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
+ {
+ // Don't add unbound_inputs to the new graph.
+ if (col.Key == unbound_inputs_col_name)
+ continue;
+
+ switch (col.Value.KindCase)
+ {
+ case KindOneofCase.NodeList:
+ foreach(var value in col.Value.NodeList.Value)
+ {
+ var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
+ graph.add_to_collection(col.Key, col_op);
+ }
+ break;
+ case KindOneofCase.BytesList:
+ //var proto_type = ops.get_collection_proto_type(key)
+ if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
+ {
+ foreach (var value in col.Value.BytesList.Value)
+ {
+ var proto = VariableDef.Parser.ParseFrom(value);
+ throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
+ }
+ }
+ else
+ {
+ throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
+ }
+
+ break;
+ }
+ }
+ return (null, null);
}
///
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index e2ff80e2..2f16b880 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -1,6 +1,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
@@ -22,5 +23,57 @@ namespace Tensorflow
{
return c_api.TF_NewOperation(_handle, opType, opName);
}
+
+ public ITensorOrOperation _get_operation_by_name_unsafe(string name)
+ {
+ return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null;
+ }
+
+ public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper)
+ {
+ var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
+ return _get_operation_by_name_unsafe(op_name);
+ }
+
+ public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
+ {
+ var ret = new Operation(c_op);
+
+ var name_key = ret.name.ToLower();
+ if (!_names_in_use.ContainsKey(name_key))
+ _names_in_use[name_key] = 1;
+
+ _create_op_helper(ret, compute_device: compute_device);
+
+ return ret;
+ }
+
+ ///
+ /// Creates `Operations` in this graph for any new TF_Operations.
+ ///
+ /// This is useful for when TF_Operations are indirectly created by the C API
+ /// outside of the Operation constructor (e.g. by TF_ImportGraphDef,
+ /// TF_FinishWhile). This ensures there are corresponding Operations for all
+ /// TF_Operations in the underlying TF_Graph.
+ ///
+ ///
+ ///
+ public IEnumerable _add_new_tf_operations(bool compute_devices = true)
+ {
+ var new_ops = c_api_util.new_tf_operations(this)
+ .Select(c_op => _create_op_from_tf_operation(c_op, compute_device: compute_devices))
+ .ToArray();
+
+ foreach(var op in new_ops)
+ {
+ var new_control_inputs = _control_dependencies_for_inputs(op.inputs)
+ .Select(x => x as Operation)
+ .ToArray();
+ op._add_control_inputs(new_control_inputs);
+ op._control_flow_post_processing();
+ }
+
+ return new_ops;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 8fba13d5..b281debf 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -86,6 +86,16 @@ namespace Tensorflow
if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];
}
+ else if(!name.Contains(":") & allow_operation)
+ {
+ if (!_nodes_by_name.ContainsKey(name))
+ throw new KeyError($"The name {name} refers to an Operation not in the graph.");
+ return _nodes_by_name[name];
+ }
+ else if (!name.Contains(":") & !allow_operation)
+ {
+ throw new NotImplementedException("_as_graph_element_locked");
+ }
}
if (obj is Tensor tensor && allow_tensor)
@@ -101,7 +111,7 @@ namespace Tensorflow
}
else if (obj is Operation op && allow_operation)
{
- if (op.Graph.Equals(this))
+ if (op.graph.Equals(this))
{
return op;
}
diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs
index 4f387120..9abe7303 100644
--- a/src/TensorFlowNET.Core/Operations/InputList.cs
+++ b/src/TensorFlowNET.Core/Operations/InputList.cs
@@ -26,5 +26,10 @@ namespace Tensorflow
{
return input._inputs.ToList();
}
+
+ public static implicit operator Tensor[](InputList input)
+ {
+ return input._inputs;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 5599ad2b..a51d1ca9 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -16,5 +16,13 @@ namespace Tensorflow
}
}
+
+ public void _add_control_inputs(Operation[] ops)
+ {
+ foreach(var op in ops)
+ {
+ c_api.TF_AddControlInput(graph, op);
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 00a9241b..45a57286 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -11,7 +11,7 @@ namespace Tensorflow
{
private readonly IntPtr _handle; // _c_op in python
- public Graph Graph { get; }
+ public Graph graph { get; }
public int _id => _id_value;
private int _id_value;
@@ -42,15 +42,17 @@ namespace Tensorflow
return;
_handle = handle;
- this.Graph = ops.get_default_graph();
+ this.graph = ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
+
+ graph._add_op(this);
}
public Operation(Graph g, string opType, string oper_name)
{
- Graph = g;
+ graph = g;
var desc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
@@ -78,7 +80,7 @@ namespace Tensorflow
///
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
- Graph = g;
+ graph = g;
// Build the list of control inputs.
var control_input_ops = new List();
@@ -99,7 +101,7 @@ namespace Tensorflow
// This will be set by self.inputs.
- _id_value = Graph._next_id();
+ _id_value = graph._next_id();
if(op_def == null)
op_def = g.GetOpDef(node_def.Op);
@@ -115,7 +117,7 @@ namespace Tensorflow
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
- Graph._add_op(this);
+ graph._add_op(this);
if (_handle != IntPtr.Zero)
_control_flow_post_processing();
@@ -123,7 +125,7 @@ namespace Tensorflow
public void run(FeedItem[] feed_dict = null, Session session = null)
{
- ops._run_using_default_session(this, feed_dict, Graph, session);
+ ops._run_using_default_session(this, feed_dict, graph, session);
}
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs)
diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md
index 64f6e813..662c628e 100644
--- a/src/TensorFlowNET.Core/Protobuf/README.md
+++ b/src/TensorFlowNET.Core/Protobuf/README.md
@@ -3,6 +3,8 @@
set SRC_DIR=D:\Projects\tensorflow
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf
+cd tensorflow
+
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto
@@ -12,6 +14,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.pr
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\variable.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto
diff --git a/src/TensorFlowNET.Core/Protobuf/Variable.cs b/src/TensorFlowNET.Core/Protobuf/Variable.cs
new file mode 100644
index 00000000..18cdd6d2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/Variable.cs
@@ -0,0 +1,584 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: tensorflow/core/framework/variable.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from tensorflow/core/framework/variable.proto
+ public static partial class VariableReflection {
+
+ #region Descriptor
+ /// File descriptor for tensorflow/core/framework/variable.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static VariableReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZhcmlhYmxlLnByb3RvEgp0",
+ "ZW5zb3JmbG93ItQBCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg",
+ "ASgJEhoKEmluaXRpYWxfdmFsdWVfbmFtZRgGIAEoCRIYChBpbml0aWFsaXpl",
+ "cl9uYW1lGAIgASgJEhUKDXNuYXBzaG90X25hbWUYAyABKAkSOQoTc2F2ZV9z",
+ "bGljZV9pbmZvX2RlZhgEIAEoCzIcLnRlbnNvcmZsb3cuU2F2ZVNsaWNlSW5m",
+ "b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgi",
+ "YAoQU2F2ZVNsaWNlSW5mb0RlZhIRCglmdWxsX25hbWUYASABKAkSEgoKZnVs",
+ "bF9zaGFwZRgCIAMoAxISCgp2YXJfb2Zmc2V0GAMgAygDEhEKCXZhcl9zaGFw",
+ "ZRgEIAMoA0JuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDlZhcmlhYmxl",
+ "UHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3Rl",
+ "bnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw=="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable" }, null, null, null),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveSliceInfoDef), global::Tensorflow.SaveSliceInfoDef.Parser, new[]{ "FullName", "FullShape", "VarOffset", "VarShape" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// Protocol buffer representing a Variable.
+ ///
+ public sealed partial class VariableDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VariableDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VariableDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VariableDef(VariableDef other) : this() {
+ variableName_ = other.variableName_;
+ initialValueName_ = other.initialValueName_;
+ initializerName_ = other.initializerName_;
+ snapshotName_ = other.snapshotName_;
+ saveSliceInfoDef_ = other.saveSliceInfoDef_ != null ? other.saveSliceInfoDef_.Clone() : null;
+ isResource_ = other.isResource_;
+ trainable_ = other.trainable_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VariableDef Clone() {
+ return new VariableDef(this);
+ }
+
+ /// Field number for the "variable_name" field.
+ public const int VariableNameFieldNumber = 1;
+ private string variableName_ = "";
+ ///
+ /// Name of the variable tensor.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string VariableName {
+ get { return variableName_; }
+ set {
+ variableName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "initial_value_name" field.
+ public const int InitialValueNameFieldNumber = 6;
+ private string initialValueName_ = "";
+ ///
+ /// Name of the tensor holding the variable's initial value.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string InitialValueName {
+ get { return initialValueName_; }
+ set {
+ initialValueName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "initializer_name" field.
+ public const int InitializerNameFieldNumber = 2;
+ private string initializerName_ = "";
+ ///
+ /// Name of the initializer op.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string InitializerName {
+ get { return initializerName_; }
+ set {
+ initializerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "snapshot_name" field.
+ public const int SnapshotNameFieldNumber = 3;
+ private string snapshotName_ = "";
+ ///
+ /// Name of the snapshot tensor.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string SnapshotName {
+ get { return snapshotName_; }
+ set {
+ snapshotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "save_slice_info_def" field.
+ public const int SaveSliceInfoDefFieldNumber = 4;
+ private global::Tensorflow.SaveSliceInfoDef saveSliceInfoDef_;
+ ///
+ /// Support for saving variables as slices of a larger variable.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.SaveSliceInfoDef SaveSliceInfoDef {
+ get { return saveSliceInfoDef_; }
+ set {
+ saveSliceInfoDef_ = value;
+ }
+ }
+
+ /// Field number for the "is_resource" field.
+ public const int IsResourceFieldNumber = 5;
+ private bool isResource_;
+ ///
+ /// Whether to represent this as a ResourceVariable.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool IsResource {
+ get { return isResource_; }
+ set {
+ isResource_ = value;
+ }
+ }
+
+ /// Field number for the "trainable" field.
+ public const int TrainableFieldNumber = 7;
+ private bool trainable_;
+ ///
+ /// Whether this variable should be trained.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Trainable {
+ get { return trainable_; }
+ set {
+ trainable_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as VariableDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(VariableDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (VariableName != other.VariableName) return false;
+ if (InitialValueName != other.InitialValueName) return false;
+ if (InitializerName != other.InitializerName) return false;
+ if (SnapshotName != other.SnapshotName) return false;
+ if (!object.Equals(SaveSliceInfoDef, other.SaveSliceInfoDef)) return false;
+ if (IsResource != other.IsResource) return false;
+ if (Trainable != other.Trainable) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (VariableName.Length != 0) hash ^= VariableName.GetHashCode();
+ if (InitialValueName.Length != 0) hash ^= InitialValueName.GetHashCode();
+ if (InitializerName.Length != 0) hash ^= InitializerName.GetHashCode();
+ if (SnapshotName.Length != 0) hash ^= SnapshotName.GetHashCode();
+ if (saveSliceInfoDef_ != null) hash ^= SaveSliceInfoDef.GetHashCode();
+ if (IsResource != false) hash ^= IsResource.GetHashCode();
+ if (Trainable != false) hash ^= Trainable.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (VariableName.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(VariableName);
+ }
+ if (InitializerName.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(InitializerName);
+ }
+ if (SnapshotName.Length != 0) {
+ output.WriteRawTag(26);
+ output.WriteString(SnapshotName);
+ }
+ if (saveSliceInfoDef_ != null) {
+ output.WriteRawTag(34);
+ output.WriteMessage(SaveSliceInfoDef);
+ }
+ if (IsResource != false) {
+ output.WriteRawTag(40);
+ output.WriteBool(IsResource);
+ }
+ if (InitialValueName.Length != 0) {
+ output.WriteRawTag(50);
+ output.WriteString(InitialValueName);
+ }
+ if (Trainable != false) {
+ output.WriteRawTag(56);
+ output.WriteBool(Trainable);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (VariableName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(VariableName);
+ }
+ if (InitialValueName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(InitialValueName);
+ }
+ if (InitializerName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(InitializerName);
+ }
+ if (SnapshotName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(SnapshotName);
+ }
+ if (saveSliceInfoDef_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaveSliceInfoDef);
+ }
+ if (IsResource != false) {
+ size += 1 + 1;
+ }
+ if (Trainable != false) {
+ size += 1 + 1;
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(VariableDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.VariableName.Length != 0) {
+ VariableName = other.VariableName;
+ }
+ if (other.InitialValueName.Length != 0) {
+ InitialValueName = other.InitialValueName;
+ }
+ if (other.InitializerName.Length != 0) {
+ InitializerName = other.InitializerName;
+ }
+ if (other.SnapshotName.Length != 0) {
+ SnapshotName = other.SnapshotName;
+ }
+ if (other.saveSliceInfoDef_ != null) {
+ if (saveSliceInfoDef_ == null) {
+ saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef();
+ }
+ SaveSliceInfoDef.MergeFrom(other.SaveSliceInfoDef);
+ }
+ if (other.IsResource != false) {
+ IsResource = other.IsResource;
+ }
+ if (other.Trainable != false) {
+ Trainable = other.Trainable;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ VariableName = input.ReadString();
+ break;
+ }
+ case 18: {
+ InitializerName = input.ReadString();
+ break;
+ }
+ case 26: {
+ SnapshotName = input.ReadString();
+ break;
+ }
+ case 34: {
+ if (saveSliceInfoDef_ == null) {
+ saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef();
+ }
+ input.ReadMessage(saveSliceInfoDef_);
+ break;
+ }
+ case 40: {
+ IsResource = input.ReadBool();
+ break;
+ }
+ case 50: {
+ InitialValueName = input.ReadString();
+ break;
+ }
+ case 56: {
+ Trainable = input.ReadBool();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ public sealed partial class SaveSliceInfoDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaveSliceInfoDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[1]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaveSliceInfoDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaveSliceInfoDef(SaveSliceInfoDef other) : this() {
+ fullName_ = other.fullName_;
+ fullShape_ = other.fullShape_.Clone();
+ varOffset_ = other.varOffset_.Clone();
+ varShape_ = other.varShape_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SaveSliceInfoDef Clone() {
+ return new SaveSliceInfoDef(this);
+ }
+
+ /// Field number for the "full_name" field.
+ public const int FullNameFieldNumber = 1;
+ private string fullName_ = "";
+ ///
+ /// Name of the full variable of which this is a slice.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string FullName {
+ get { return fullName_; }
+ set {
+ fullName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "full_shape" field.
+ public const int FullShapeFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_fullShape_codec
+ = pb::FieldCodec.ForInt64(18);
+ private readonly pbc::RepeatedField fullShape_ = new pbc::RepeatedField();
+ ///
+ /// Shape of the full variable.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField FullShape {
+ get { return fullShape_; }
+ }
+
+ /// Field number for the "var_offset" field.
+ public const int VarOffsetFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_varOffset_codec
+ = pb::FieldCodec.ForInt64(26);
+ private readonly pbc::RepeatedField varOffset_ = new pbc::RepeatedField();
+ ///
+ /// Offset of this variable into the full variable.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VarOffset {
+ get { return varOffset_; }
+ }
+
+ /// Field number for the "var_shape" field.
+ public const int VarShapeFieldNumber = 4;
+ private static readonly pb::FieldCodec _repeated_varShape_codec
+ = pb::FieldCodec.ForInt64(34);
+ private readonly pbc::RepeatedField varShape_ = new pbc::RepeatedField();
+ ///
+ /// Shape of this variable.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField VarShape {
+ get { return varShape_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as SaveSliceInfoDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(SaveSliceInfoDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (FullName != other.FullName) return false;
+ if(!fullShape_.Equals(other.fullShape_)) return false;
+ if(!varOffset_.Equals(other.varOffset_)) return false;
+ if(!varShape_.Equals(other.varShape_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (FullName.Length != 0) hash ^= FullName.GetHashCode();
+ hash ^= fullShape_.GetHashCode();
+ hash ^= varOffset_.GetHashCode();
+ hash ^= varShape_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (FullName.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(FullName);
+ }
+ fullShape_.WriteTo(output, _repeated_fullShape_codec);
+ varOffset_.WriteTo(output, _repeated_varOffset_codec);
+ varShape_.WriteTo(output, _repeated_varShape_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (FullName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(FullName);
+ }
+ size += fullShape_.CalculateSize(_repeated_fullShape_codec);
+ size += varOffset_.CalculateSize(_repeated_varOffset_codec);
+ size += varShape_.CalculateSize(_repeated_varShape_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(SaveSliceInfoDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.FullName.Length != 0) {
+ FullName = other.FullName;
+ }
+ fullShape_.Add(other.fullShape_);
+ varOffset_.Add(other.varOffset_);
+ varShape_.Add(other.varShape_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ FullName = input.ReadString();
+ break;
+ }
+ case 18:
+ case 16: {
+ fullShape_.AddEntriesFrom(input, _repeated_fullShape_codec);
+ break;
+ }
+ case 26:
+ case 24: {
+ varOffset_.AddEntriesFrom(input, _repeated_varOffset_codec);
+ break;
+ }
+ case 34:
+ case 32: {
+ varShape_.AddEntriesFrom(input, _repeated_varShape_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 96111cd9..d834b608 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -19,7 +19,7 @@ namespace Tensorflow
private int _id;
public int Id => _id;
- public Graph Graph => op?.Graph;
+ public Graph Graph => op?.graph;
public Operation op { get; }
public Tensor[] outputs => op.outputs;
@@ -48,7 +48,7 @@ namespace Tensorflow
if (_handle == IntPtr.Zero)
{
- c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status);
+ c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
else
@@ -84,7 +84,7 @@ namespace Tensorflow
if (_handle == IntPtr.Zero)
{
var output = _as_tf_output();
- return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status);
+ return c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
}
else
{
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 78e25bd8..4f918b2a 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -28,6 +28,7 @@ namespace Tensorflow
///
public static string GLOBAL_VARIABLES = "variables";
+ public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" };
///
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
///
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 74575cff..b7ccfe86 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -287,6 +287,25 @@ namespace Tensorflow
return tf.defaultSession;
}
+ ///
+ /// Prepends name scope to a name.
+ ///
+ ///
+ ///
+ ///
+ public static string prepend_name_scope(string name, string import_scope)
+ {
+ if (!string.IsNullOrEmpty(import_scope))
+ {
+ if (import_scope.EndsWith("/"))
+ import_scope = import_scope.Substring(0, import_scope.Length - 1);
+
+ throw new NotImplementedException("prepend_name_scope");
+ }
+ else
+ return name;
+ }
+
public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)
{
if (session == null)