Browse Source

added KeyError exception.

added VariableDef
tf_buffer, tf_operations
tags/v0.8.0
Oceania2018 6 years ago
parent
commit
5c2818e389
16 changed files with 806 additions and 21 deletions
  1. +8
    -0
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  3. +19
    -0
      src/TensorFlowNET.Core/Exceptions/KeyError.cs
  4. +18
    -6
      src/TensorFlowNET.Core/Framework/c_api_util.py.cs
  5. +18
    -2
      src/TensorFlowNET.Core/Framework/importer.py.cs
  6. +37
    -2
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  7. +53
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  8. +11
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  9. +5
    -0
      src/TensorFlowNET.Core/Operations/InputList.cs
  10. +8
    -0
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  11. +9
    -7
      src/TensorFlowNET.Core/Operations/Operation.cs
  12. +3
    -0
      src/TensorFlowNET.Core/Protobuf/README.md
  13. +584
    -0
      src/TensorFlowNET.Core/Protobuf/Variable.cs
  14. +3
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  15. +1
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  16. +19
    -0
      src/TensorFlowNET.Core/ops.py.cs

+ 8
- 0
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -34,6 +34,14 @@ namespace Tensorflow
_handle = handle;
}

public Buffer(byte[] data)
{
var dst = Marshal.AllocHGlobal(data.Length);
Marshal.Copy(data, 0, dst, data.Length);

_handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length);
}

public static implicit operator IntPtr(Buffer buffer)
{
return buffer._handle;


+ 10
- 0
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -19,5 +19,15 @@ namespace Tensorflow

[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);

/// <summary>
/// Makes a copy of the input and sets an appropriate deallocator. Useful for
/// passing in read-only, input protobufs.
/// </summary>
/// <param name="proto">const void*</param>
/// <param name="proto_len">size_t</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewBufferFromString(IntPtr proto, ulong proto_len);
}
}

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/KeyError.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class KeyError : Exception
{
public KeyError() : base()
{

}

public KeyError(string message) : base(message)
{

}
}
}

+ 18
- 6
src/TensorFlowNET.Core/Framework/c_api_util.py.cs View File

@@ -10,13 +10,25 @@ namespace Tensorflow

public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();

public static IntPtr tf_buffer(byte[] data)
public static Buffer tf_buffer(byte[] data) => new Buffer(data);

public static IEnumerable<Operation> new_tf_operations(Graph graph)
{
foreach (var c_op in tf_operations(graph))
{
if (graph._get_operation_by_tf_operation(c_op) == null)
yield return c_op;
}
}

public static IEnumerable<Operation> tf_operations(Graph graph)
{
if (data != null)
throw new NotImplementedException("");
// var buf = c_api.TF_NewBufferFromString(data);
else
throw new NotImplementedException("");
uint pos = 0;
IntPtr c_op;
while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
{
yield return c_op;
}
}
}
}

+ 18
- 2
src/TensorFlowNET.Core/Framework/importer.py.cs View File

@@ -42,11 +42,27 @@ namespace Tensorflow
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);

var bytes = graph_def.ToByteString().ToArray();
IntPtr buffer = c_api_util.tf_buffer(bytes);

var status = new Status();
c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status);
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);

throw new NotImplementedException("importer.import_graph_def");
_ProcessNewOps(graph);

if (return_elements == null)
return null;
else
throw new NotImplementedException("import_graph_def return_elements");
}

private static void _ProcessNewOps(Graph graph)
{
foreach(var new_op in graph._add_new_tf_operations())
{
var original_device = new_op.Device;
}
}

public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,


+ 37
- 2
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using static Tensorflow.CollectionDef;
using static Tensorflow.MetaGraphDef.Types;

namespace Tensorflow
@@ -16,7 +17,7 @@ namespace Tensorflow
return meta_graph_def;
}

public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
Dictionary<string, Tensor> input_map = null,
@@ -51,7 +52,7 @@ namespace Tensorflow
node.Device = "";

var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
importer.import_graph_def(input_graph_def,
var imported_return_elements = importer.import_graph_def(input_graph_def,
name: scope_to_prepend_to_names,
input_map: input_map,
producer_op_list: producer_op_list,
@@ -59,7 +60,41 @@ namespace Tensorflow

// Restores all the other collections.
var variable_objects = new Dictionary<string, RefVariable>();
foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
{
// Don't add unbound_inputs to the new graph.
if (col.Key == unbound_inputs_col_name)
continue;

switch (col.Value.KindCase)
{
case KindOneofCase.NodeList:
foreach(var value in col.Value.NodeList.Value)
{
var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
graph.add_to_collection(col.Key, col_op);
}
break;
case KindOneofCase.BytesList:
//var proto_type = ops.get_collection_proto_type(key)
if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
{
foreach (var value in col.Value.BytesList.Value)
{
var proto = VariableDef.Parser.ParseFrom(value);
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
}
else
{
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
break;
}
}

return (null, null);
}

/// <summary>


+ 53
- 0
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -1,6 +1,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

@@ -22,5 +23,57 @@ namespace Tensorflow
{
return c_api.TF_NewOperation(_handle, opType, opName);
}

public ITensorOrOperation _get_operation_by_name_unsafe(string name)
{
return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null;
}

public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper)
{
var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
return _get_operation_by_name_unsafe(op_name);
}

public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
{
var ret = new Operation(c_op);

var name_key = ret.name.ToLower();
if (!_names_in_use.ContainsKey(name_key))
_names_in_use[name_key] = 1;

_create_op_helper(ret, compute_device: compute_device);

return ret;
}

/// <summary>
/// Creates `Operations` in this graph for any new TF_Operations.
///
/// This is useful for when TF_Operations are indirectly created by the C API
/// outside of the Operation constructor (e.g. by TF_ImportGraphDef,
/// TF_FinishWhile). This ensures there are corresponding Operations for all
/// TF_Operations in the underlying TF_Graph.
/// </summary>
/// <param name="compute_devices"></param>
/// <returns></returns>
public IEnumerable<Operation> _add_new_tf_operations(bool compute_devices = true)
{
var new_ops = c_api_util.new_tf_operations(this)
.Select(c_op => _create_op_from_tf_operation(c_op, compute_device: compute_devices))
.ToArray();

foreach(var op in new_ops)
{
var new_control_inputs = _control_dependencies_for_inputs(op.inputs)
.Select(x => x as Operation)
.ToArray();
op._add_control_inputs(new_control_inputs);
op._control_flow_post_processing();
}

return new_ops;
}
}
}

+ 11
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -86,6 +86,16 @@ namespace Tensorflow
if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];
}
else if(!name.Contains(":") & allow_operation)
{
if (!_nodes_by_name.ContainsKey(name))
throw new KeyError($"The name {name} refers to an Operation not in the graph.");
return _nodes_by_name[name];
}
else if (!name.Contains(":") & !allow_operation)
{
throw new NotImplementedException("_as_graph_element_locked");
}
}

if (obj is Tensor tensor && allow_tensor)
@@ -101,7 +111,7 @@ namespace Tensorflow
}
else if (obj is Operation op && allow_operation)
{
if (op.Graph.Equals(this))
if (op.graph.Equals(this))
{
return op;
}


+ 5
- 0
src/TensorFlowNET.Core/Operations/InputList.cs View File

@@ -26,5 +26,10 @@ namespace Tensorflow
{
return input._inputs.ToList();
}

public static implicit operator Tensor[](InputList input)
{
return input._inputs;
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -16,5 +16,13 @@ namespace Tensorflow

}
}

public void _add_control_inputs(Operation[] ops)
{
foreach(var op in ops)
{
c_api.TF_AddControlInput(graph, op);
}
}
}
}

+ 9
- 7
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow
{
private readonly IntPtr _handle; // _c_op in python

public Graph Graph { get; }
public Graph graph { get; }
public int _id => _id_value;
private int _id_value;

@@ -42,15 +42,17 @@ namespace Tensorflow
return;

_handle = handle;
this.Graph = ops.get_default_graph();
this.graph = ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));

graph._add_op(this);
}

public Operation(Graph g, string opType, string oper_name)
{
Graph = g;
graph = g;

var desc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
@@ -78,7 +80,7 @@ namespace Tensorflow
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
Graph = g;
graph = g;

// Build the list of control inputs.
var control_input_ops = new List<Operation>();
@@ -99,7 +101,7 @@ namespace Tensorflow

// This will be set by self.inputs.

_id_value = Graph._next_id();
_id_value = graph._next_id();
if(op_def == null)
op_def = g.GetOpDef(node_def.Op);

@@ -115,7 +117,7 @@ namespace Tensorflow
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));

Graph._add_op(this);
graph._add_op(this);

if (_handle != IntPtr.Zero)
_control_flow_post_processing();
@@ -123,7 +125,7 @@ namespace Tensorflow

public void run(FeedItem[] feed_dict = null, Session session = null)
{
ops._run_using_default_session(this, feed_dict, Graph, session);
ops._run_using_default_session(this, feed_dict, graph, session);
}

private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)


+ 3
- 0
src/TensorFlowNET.Core/Protobuf/README.md View File

@@ -3,6 +3,8 @@
set SRC_DIR=D:\Projects\tensorflow
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf

cd tensorflow

protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\types.proto
@@ -12,6 +14,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\node_def.pr
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\versions.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\function.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\framework\variable.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\core\protobuf\meta_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow\python\training\checkpoint_state.proto


+ 584
- 0
src/TensorFlowNET.Core/Protobuf/Variable.cs View File

@@ -0,0 +1,584 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/framework/variable.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from tensorflow/core/framework/variable.proto</summary>
public static partial class VariableReflection {

#region Descriptor
/// <summary>File descriptor for tensorflow/core/framework/variable.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static VariableReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZhcmlhYmxlLnByb3RvEgp0",
"ZW5zb3JmbG93ItQBCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg",
"ASgJEhoKEmluaXRpYWxfdmFsdWVfbmFtZRgGIAEoCRIYChBpbml0aWFsaXpl",
"cl9uYW1lGAIgASgJEhUKDXNuYXBzaG90X25hbWUYAyABKAkSOQoTc2F2ZV9z",
"bGljZV9pbmZvX2RlZhgEIAEoCzIcLnRlbnNvcmZsb3cuU2F2ZVNsaWNlSW5m",
"b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgi",
"YAoQU2F2ZVNsaWNlSW5mb0RlZhIRCglmdWxsX25hbWUYASABKAkSEgoKZnVs",
"bF9zaGFwZRgCIAMoAxISCgp2YXJfb2Zmc2V0GAMgAygDEhEKCXZhcl9zaGFw",
"ZRgEIAMoA0JuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDlZhcmlhYmxl",
"UHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3Rl",
"bnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveSliceInfoDef), global::Tensorflow.SaveSliceInfoDef.Parser, new[]{ "FullName", "FullShape", "VarOffset", "VarShape" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Protocol buffer representing a Variable.
/// </summary>
public sealed partial class VariableDef : pb::IMessage<VariableDef> {
private static readonly pb::MessageParser<VariableDef> _parser = new pb::MessageParser<VariableDef>(() => new VariableDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<VariableDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VariableDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VariableDef(VariableDef other) : this() {
variableName_ = other.variableName_;
initialValueName_ = other.initialValueName_;
initializerName_ = other.initializerName_;
snapshotName_ = other.snapshotName_;
saveSliceInfoDef_ = other.saveSliceInfoDef_ != null ? other.saveSliceInfoDef_.Clone() : null;
isResource_ = other.isResource_;
trainable_ = other.trainable_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VariableDef Clone() {
return new VariableDef(this);
}

/// <summary>Field number for the "variable_name" field.</summary>
public const int VariableNameFieldNumber = 1;
private string variableName_ = "";
/// <summary>
/// Name of the variable tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string VariableName {
get { return variableName_; }
set {
variableName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "initial_value_name" field.</summary>
public const int InitialValueNameFieldNumber = 6;
private string initialValueName_ = "";
/// <summary>
/// Name of the tensor holding the variable's initial value.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string InitialValueName {
get { return initialValueName_; }
set {
initialValueName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "initializer_name" field.</summary>
public const int InitializerNameFieldNumber = 2;
private string initializerName_ = "";
/// <summary>
/// Name of the initializer op.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string InitializerName {
get { return initializerName_; }
set {
initializerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "snapshot_name" field.</summary>
public const int SnapshotNameFieldNumber = 3;
private string snapshotName_ = "";
/// <summary>
/// Name of the snapshot tensor.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string SnapshotName {
get { return snapshotName_; }
set {
snapshotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "save_slice_info_def" field.</summary>
public const int SaveSliceInfoDefFieldNumber = 4;
private global::Tensorflow.SaveSliceInfoDef saveSliceInfoDef_;
/// <summary>
/// Support for saving variables as slices of a larger variable.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.SaveSliceInfoDef SaveSliceInfoDef {
get { return saveSliceInfoDef_; }
set {
saveSliceInfoDef_ = value;
}
}

/// <summary>Field number for the "is_resource" field.</summary>
public const int IsResourceFieldNumber = 5;
private bool isResource_;
/// <summary>
/// Whether to represent this as a ResourceVariable.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool IsResource {
get { return isResource_; }
set {
isResource_ = value;
}
}

/// <summary>Field number for the "trainable" field.</summary>
public const int TrainableFieldNumber = 7;
private bool trainable_;
/// <summary>
/// Whether this variable should be trained.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Trainable {
get { return trainable_; }
set {
trainable_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as VariableDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(VariableDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (VariableName != other.VariableName) return false;
if (InitialValueName != other.InitialValueName) return false;
if (InitializerName != other.InitializerName) return false;
if (SnapshotName != other.SnapshotName) return false;
if (!object.Equals(SaveSliceInfoDef, other.SaveSliceInfoDef)) return false;
if (IsResource != other.IsResource) return false;
if (Trainable != other.Trainable) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (VariableName.Length != 0) hash ^= VariableName.GetHashCode();
if (InitialValueName.Length != 0) hash ^= InitialValueName.GetHashCode();
if (InitializerName.Length != 0) hash ^= InitializerName.GetHashCode();
if (SnapshotName.Length != 0) hash ^= SnapshotName.GetHashCode();
if (saveSliceInfoDef_ != null) hash ^= SaveSliceInfoDef.GetHashCode();
if (IsResource != false) hash ^= IsResource.GetHashCode();
if (Trainable != false) hash ^= Trainable.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (VariableName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(VariableName);
}
if (InitializerName.Length != 0) {
output.WriteRawTag(18);
output.WriteString(InitializerName);
}
if (SnapshotName.Length != 0) {
output.WriteRawTag(26);
output.WriteString(SnapshotName);
}
if (saveSliceInfoDef_ != null) {
output.WriteRawTag(34);
output.WriteMessage(SaveSliceInfoDef);
}
if (IsResource != false) {
output.WriteRawTag(40);
output.WriteBool(IsResource);
}
if (InitialValueName.Length != 0) {
output.WriteRawTag(50);
output.WriteString(InitialValueName);
}
if (Trainable != false) {
output.WriteRawTag(56);
output.WriteBool(Trainable);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (VariableName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(VariableName);
}
if (InitialValueName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(InitialValueName);
}
if (InitializerName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(InitializerName);
}
if (SnapshotName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(SnapshotName);
}
if (saveSliceInfoDef_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaveSliceInfoDef);
}
if (IsResource != false) {
size += 1 + 1;
}
if (Trainable != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(VariableDef other) {
if (other == null) {
return;
}
if (other.VariableName.Length != 0) {
VariableName = other.VariableName;
}
if (other.InitialValueName.Length != 0) {
InitialValueName = other.InitialValueName;
}
if (other.InitializerName.Length != 0) {
InitializerName = other.InitializerName;
}
if (other.SnapshotName.Length != 0) {
SnapshotName = other.SnapshotName;
}
if (other.saveSliceInfoDef_ != null) {
if (saveSliceInfoDef_ == null) {
saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef();
}
SaveSliceInfoDef.MergeFrom(other.SaveSliceInfoDef);
}
if (other.IsResource != false) {
IsResource = other.IsResource;
}
if (other.Trainable != false) {
Trainable = other.Trainable;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
VariableName = input.ReadString();
break;
}
case 18: {
InitializerName = input.ReadString();
break;
}
case 26: {
SnapshotName = input.ReadString();
break;
}
case 34: {
if (saveSliceInfoDef_ == null) {
saveSliceInfoDef_ = new global::Tensorflow.SaveSliceInfoDef();
}
input.ReadMessage(saveSliceInfoDef_);
break;
}
case 40: {
IsResource = input.ReadBool();
break;
}
case 50: {
InitialValueName = input.ReadString();
break;
}
case 56: {
Trainable = input.ReadBool();
break;
}
}
}
}

}

public sealed partial class SaveSliceInfoDef : pb::IMessage<SaveSliceInfoDef> {
private static readonly pb::MessageParser<SaveSliceInfoDef> _parser = new pb::MessageParser<SaveSliceInfoDef>(() => new SaveSliceInfoDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<SaveSliceInfoDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaveSliceInfoDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaveSliceInfoDef(SaveSliceInfoDef other) : this() {
fullName_ = other.fullName_;
fullShape_ = other.fullShape_.Clone();
varOffset_ = other.varOffset_.Clone();
varShape_ = other.varShape_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaveSliceInfoDef Clone() {
return new SaveSliceInfoDef(this);
}

/// <summary>Field number for the "full_name" field.</summary>
public const int FullNameFieldNumber = 1;
private string fullName_ = "";
/// <summary>
/// Name of the full variable of which this is a slice.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string FullName {
get { return fullName_; }
set {
fullName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "full_shape" field.</summary>
public const int FullShapeFieldNumber = 2;
private static readonly pb::FieldCodec<long> _repeated_fullShape_codec
= pb::FieldCodec.ForInt64(18);
private readonly pbc::RepeatedField<long> fullShape_ = new pbc::RepeatedField<long>();
/// <summary>
/// Shape of the full variable.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<long> FullShape {
get { return fullShape_; }
}

/// <summary>Field number for the "var_offset" field.</summary>
public const int VarOffsetFieldNumber = 3;
private static readonly pb::FieldCodec<long> _repeated_varOffset_codec
= pb::FieldCodec.ForInt64(26);
private readonly pbc::RepeatedField<long> varOffset_ = new pbc::RepeatedField<long>();
/// <summary>
/// Offset of this variable into the full variable.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<long> VarOffset {
get { return varOffset_; }
}

/// <summary>Field number for the "var_shape" field.</summary>
public const int VarShapeFieldNumber = 4;
private static readonly pb::FieldCodec<long> _repeated_varShape_codec
= pb::FieldCodec.ForInt64(34);
private readonly pbc::RepeatedField<long> varShape_ = new pbc::RepeatedField<long>();
/// <summary>
/// Shape of this variable.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<long> VarShape {
get { return varShape_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as SaveSliceInfoDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(SaveSliceInfoDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (FullName != other.FullName) return false;
if(!fullShape_.Equals(other.fullShape_)) return false;
if(!varOffset_.Equals(other.varOffset_)) return false;
if(!varShape_.Equals(other.varShape_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (FullName.Length != 0) hash ^= FullName.GetHashCode();
hash ^= fullShape_.GetHashCode();
hash ^= varOffset_.GetHashCode();
hash ^= varShape_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (FullName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(FullName);
}
fullShape_.WriteTo(output, _repeated_fullShape_codec);
varOffset_.WriteTo(output, _repeated_varOffset_codec);
varShape_.WriteTo(output, _repeated_varShape_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (FullName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(FullName);
}
size += fullShape_.CalculateSize(_repeated_fullShape_codec);
size += varOffset_.CalculateSize(_repeated_varOffset_codec);
size += varShape_.CalculateSize(_repeated_varShape_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(SaveSliceInfoDef other) {
if (other == null) {
return;
}
if (other.FullName.Length != 0) {
FullName = other.FullName;
}
fullShape_.Add(other.fullShape_);
varOffset_.Add(other.varOffset_);
varShape_.Add(other.varShape_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
FullName = input.ReadString();
break;
}
case 18:
case 16: {
fullShape_.AddEntriesFrom(input, _repeated_fullShape_codec);
break;
}
case 26:
case 24: {
varOffset_.AddEntriesFrom(input, _repeated_varOffset_codec);
break;
}
case 34:
case 32: {
varShape_.AddEntriesFrom(input, _repeated_varShape_codec);
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 3
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow
private int _id;
public int Id => _id;

public Graph Graph => op?.Graph;
public Graph Graph => op?.graph;
public Operation op { get; }
public Tensor[] outputs => op.outputs;

@@ -48,7 +48,7 @@ namespace Tensorflow

if (_handle == IntPtr.Zero)
{
c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status);
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
else
@@ -84,7 +84,7 @@ namespace Tensorflow
if (_handle == IntPtr.Zero)
{
var output = _as_tf_output();
return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status);
return c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
}
else
{


+ 1
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -28,6 +28,7 @@ namespace Tensorflow
/// </summary>
public static string GLOBAL_VARIABLES = "variables";

public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" };
/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>


+ 19
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -287,6 +287,25 @@ namespace Tensorflow
return tf.defaultSession;
}

/// <summary>
/// Prepends name scope to a name.
/// </summary>
/// <param name="name"></param>
/// <param name="import_scope"></param>
/// <returns></returns>
public static string prepend_name_scope(string name, string import_scope)
{
if (!string.IsNullOrEmpty(import_scope))
{
if (import_scope.EndsWith("/"))
import_scope = import_scope.Substring(0, import_scope.Length - 1);

throw new NotImplementedException("prepend_name_scope");
}
else
return name;
}

public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)
{
if (session == null)


Loading…
Cancel
Save