Browse Source

fix shape issue for IndexedSlice

tags/v0.9
Oceania2018 6 years ago
parent
commit
89f305ca08
7 changed files with 41 additions and 10 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  3. +5
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +7
    -2
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  5. +2
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  6. +15
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +8
    -2
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 3
- 1
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -153,7 +153,9 @@ namespace Tensorflow
{
if (in_grad != null)
{
if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
if (in_grad is Tensor &&
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
{
in_grad.shape = t_in.shape;
}


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
}

return ret.ToArray();
return ret.OrderBy(x => x.op.name).ToArray();
}

/// <summary>


+ 5
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -245,9 +245,12 @@ namespace Tensorflow
// If a names ends with a '/' it is a "name scope" and we use it as-is,
// after removing the trailing '/'.
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

var input_ops = inputs.Select(x => x.op).ToArray();
if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape")
;

var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);

var op = new Operation(node_def,


+ 7
- 2
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -1,4 +1,7 @@
using System;
#if GRAPH_SERIALIZE
using Newtonsoft.Json;
#endif
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
@@ -14,7 +17,9 @@ namespace Tensorflow

private Tensor[] _outputs;
public Tensor[] outputs => _outputs;

#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public Tensor output => _outputs.FirstOrDefault();

public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);


+ 2
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -54,6 +54,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="NumSharp" Version="0.10.3" />
</ItemGroup>



+ 15
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -50,9 +50,18 @@ namespace Tensorflow

private TF_DataType _dtype = TF_DataType.DtInvalid;
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);

@@ -61,6 +70,9 @@ namespace Tensorflow
/// <summary>
/// used for keep other pointer when do implicit operating
/// </summary>
#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public object Tag { get; set; }

public int[] shape
@@ -131,7 +143,9 @@ namespace Tensorflow
}
}
}

#if GRAPH_SERIALIZE
[JsonIgnore]
#endif
public int NDims => rank;

public string Device => op.Device;


+ 8
- 2
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -62,6 +62,12 @@ namespace TensorFlowNET.Examples
train_y = y[new Slice(stop: train_size)];
valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE");

train_x = np.Load<int[,]>(Path.Join("word_cnn", "train_x.npy"));
valid_x = np.Load<int[,]>(Path.Join("word_cnn", "valid_x.npy"));
train_y = np.Load<int[]>(Path.Join("word_cnn", "train_y.npy"));
valid_y = np.Load<int[]>(Path.Join("word_cnn", "valid_y.npy"));

return (train_x, valid_x, train_y, valid_y);
}

@@ -114,7 +120,7 @@ namespace TensorFlowNET.Examples
int alphabet_size = 0;

var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
vocabulary_size = len(word_dict);
//vocabulary_size = len(word_dict);
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);

Console.WriteLine("\tDONE ");
@@ -305,7 +311,7 @@ namespace TensorFlowNET.Examples
public bool Train()
{
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
return with(tf.Session(graph), sess => Train(sess, graph));
}



Loading…
Cancel
Save