@@ -153,7 +153,9 @@ namespace Tensorflow
             {
                 if (in_grad != null)
                 {
-                    if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
+                    if (in_grad is Tensor &&
+                        in_grad.Tag == null && // may be an IndexedSlices
+                        t_in.dtype != TF_DataType.TF_RESOURCE)
                     {
                         in_grad.shape = t_in.shape;
                     }
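
Side note on the new guard: a non-null `Tag` on the gradient appears to mark an `IndexedSlices`-style (sparse) gradient (see the comment on `Tensor.Tag` further down), which has no dense shape to copy, so the shape assignment is now skipped for it. A minimal sketch of the same predicate as a standalone helper; the helper name is illustrative, not part of the change, and assumes the `Tensorflow` types shown in this diff:

    // Hypothetical helper expressing the intent of the new condition:
    // only copy the input's static shape onto a plain dense Tensor gradient.
    static bool ShouldCopyShape(object in_grad, Tensor t_in)
        => in_grad is Tensor grad
           && grad.Tag == null                        // non-null Tag: IndexedSlices-like gradient, skip
           && t_in.dtype != TF_DataType.TF_RESOURCE;  // resource handles carry no shape worth copying
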
@@ -43,7 +43,7 @@ namespace Tensorflow
                 ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
             }
 
-            return ret.ToArray();
+            return ret.OrderBy(x => x.op.name).ToArray();
         }
 
         /// <summary>
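
The `OrderBy` presumably makes the returned control dependencies reproducible: the operations are gathered from collections whose enumeration order is not guaranteed, so sorting by name pins the array to a stable sequence across runs (and keeps downstream node naming stable). A self-contained sketch of the idea, with made-up op names:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class OrderingSketch
    {
        static void Main()
        {
            // HashSet enumeration order is unspecified, so iterating it directly
            // can yield a different sequence from run to run.
            var ops = new HashSet<string> { "loss/grad/Reshape", "embedding/Lookup", "dense/MatMul" };

            // Ordering by name always prints:
            // dense/MatMul, embedding/Lookup, loss/grad/Reshape
            foreach (var name in ops.OrderBy(x => x, StringComparer.Ordinal))
                Console.WriteLine(name);
        }
    }
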
@@ -245,9 +245,12 @@ namespace Tensorflow
             // If a names ends with a '/' it is a "name scope" and we use it as-is,
             // after removing the trailing '/'.
             name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
             var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
 
+            if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape")
+                ;
+
             var input_ops = inputs.Select(x => x.op).ToArray();
             var control_inputs = _control_dependencies_for_inputs(input_ops);
 
             var op = new Operation(node_def,
@@ -1,4 +1,7 @@ | |||||
using System; | |||||
#if GRAPH_SERIALIZE | |||||
using Newtonsoft.Json; | |||||
#endif | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
@@ -14,7 +17,9 @@ namespace Tensorflow
         private Tensor[] _outputs;
         public Tensor[] outputs => _outputs;
 
+#if GRAPH_SERIALIZE
+        [JsonIgnore]
+#endif
         public Tensor output => _outputs.FirstOrDefault();
 
         public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
   <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
-    <DefineConstants>TRACE;DEBUG</DefineConstants>
+    <DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
   </PropertyGroup>
 
   <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -54,6 +54,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.8.0" /> | <PackageReference Include="Google.Protobuf" Version="3.8.0" /> | ||||
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" /> | <PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||||
<PackageReference Include="NumSharp" Version="0.10.3" /> | <PackageReference Include="NumSharp" Version="0.10.3" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -50,9 +50,18 @@ namespace Tensorflow | |||||
private TF_DataType _dtype = TF_DataType.DtInvalid; | private TF_DataType _dtype = TF_DataType.DtInvalid; | ||||
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | ||||
#if GRAPH_SERIALIZE | |||||
[JsonIgnore] | |||||
#endif | |||||
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
#if GRAPH_SERIALIZE | |||||
[JsonIgnore] | |||||
#endif | |||||
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
#if GRAPH_SERIALIZE | |||||
[JsonIgnore] | |||||
#endif | |||||
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
@@ -61,6 +70,9 @@ namespace Tensorflow
         /// <summary>
         /// used for keep other pointer when do implicit operating
         /// </summary>
+#if GRAPH_SERIALIZE
+        [JsonIgnore]
+#endif
         public object Tag { get; set; }
 
         public int[] shape
@@ -131,7 +143,9 @@ namespace Tensorflow
                 }
             }
         }
 
+#if GRAPH_SERIALIZE
+        [JsonIgnore]
+#endif
         public int NDims => rank;
 
         public string Device => op.Device;
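
For context on these `[JsonIgnore]` attributes: `bytesize`, `itemsize`, `buffer`, and `NDims` all go through the native handle (or the owning op), so serializing them would emit transient pointer-derived values or drag the whole graph into every tensor; with the attributes in place, only plain data members reach the JSON. A minimal self-contained sketch with a stand-in class (not the real `Tensor`):

    using Newtonsoft.Json;

    // Stand-in type for illustration only; the real Tensor lives in TensorFlow.NET.
    class TensorSketch
    {
        public string name { get; set; }
        public int[] shape { get; set; }

        [JsonIgnore]   // handle-backed value: kept out of the serialized graph
        public ulong bytesize { get; set; }
    }

    class SerializeSketch
    {
        static void Main()
        {
            var t = new TensorSketch { name = "embedding/read:0", shape = new[] { 128, 64 }, bytesize = 32768 };
            // Prints {"name":"embedding/read:0","shape":[128,64]} -- bytesize is omitted.
            System.Console.WriteLine(JsonConvert.SerializeObject(t));
        }
    }
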
@@ -62,6 +62,12 @@ namespace TensorFlowNET.Examples
             train_y = y[new Slice(stop: train_size)];
             valid_y = y[new Slice(start: train_size)];
             Console.WriteLine("\tDONE");
 
+            train_x = np.Load<int[,]>(Path.Join("word_cnn", "train_x.npy"));
+            valid_x = np.Load<int[,]>(Path.Join("word_cnn", "valid_x.npy"));
+            train_y = np.Load<int[]>(Path.Join("word_cnn", "train_y.npy"));
+            valid_y = np.Load<int[]>(Path.Join("word_cnn", "valid_y.npy"));
+
             return (train_x, valid_x, train_y, valid_y);
         }
@@ -114,7 +120,7 @@ namespace TensorFlowNET.Examples | |||||
int alphabet_size = 0; | int alphabet_size = 0; | ||||
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); | var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); | ||||
vocabulary_size = len(word_dict); | |||||
//vocabulary_size = len(word_dict); | |||||
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); | var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); | ||||
Console.WriteLine("\tDONE "); | Console.WriteLine("\tDONE "); | ||||
@@ -305,7 +311,7 @@ namespace TensorFlowNET.Examples
         public bool Train()
         {
             var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
+            string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
             return with(tf.Session(graph), sess => Train(sess, graph));
         }
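
A hedged usage sketch of the new serialization hook: with GRAPH_SERIALIZE defined (the Debug csproj change above) and the `[JsonIgnore]` attributes in place, the JSON built in `Train()` can be written out and diffed between runs. The `#if` guard, the required `using Newtonsoft.Json;`, and the output path below are illustrative additions, not part of the change:

    #if GRAPH_SERIALIZE
            // Dump the freshly built graph so two runs can be compared node by node.
            string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
            System.IO.File.WriteAllText("word_cnn_graph.json", json);   // illustrative path
    #endif
            return with(tf.Session(graph), sess => Train(sess, graph));
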