@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x))); | ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x))); | ||||
} | } | ||||
return ret.OrderBy(x => x.op.name).ToArray(); | |||||
return ret.ToArray(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -248,9 +248,6 @@ namespace Tensorflow | |||||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape") | |||||
; | |||||
var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
var op = new Operation(node_def, | var op = new Operation(node_def, | ||||
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants> | |||||
<DefineConstants>TRACE;DEBUG</DefineConstants> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
@@ -54,7 +54,6 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.8.0" /> | <PackageReference Include="Google.Protobuf" Version="3.8.0" /> | ||||
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" /> | <PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||||
<PackageReference Include="NumSharp" Version="0.10.3" /> | <PackageReference Include="NumSharp" Version="0.10.3" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -22,7 +22,7 @@ namespace TensorFlowNET.Examples | |||||
public bool Enabled { get; set; } = true; | public bool Enabled { get; set; } = true; | ||||
public string Name => "CNN Text Classification"; | public string Name => "CNN Text Classification"; | ||||
public int? DataLimit = null; | public int? DataLimit = null; | ||||
public bool IsImportingGraph { get; set; } = false; | |||||
public bool IsImportingGraph { get; set; } = true; | |||||
private const string dataDir = "word_cnn"; | private const string dataDir = "word_cnn"; | ||||
private string dataFileName = "dbpedia_csv.tar.gz"; | private string dataFileName = "dbpedia_csv.tar.gz"; | ||||
@@ -44,9 +44,7 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
PrepareData(); | PrepareData(); | ||||
Train(); | |||||
return true; | |||||
return Train(); | |||||
} | } | ||||
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here | // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here | ||||
@@ -305,13 +303,13 @@ namespace TensorFlowNET.Examples | |||||
} | } | ||||
} | } | ||||
return false; | |||||
return max_accuracy > 0.8; | |||||
} | } | ||||
public bool Train() | public bool Train() | ||||
{ | { | ||||
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | ||||
string json = JsonConvert.SerializeObject(graph, Formatting.Indented); | |||||
// string json = JsonConvert.SerializeObject(graph, Formatting.Indented); | |||||
return with(tf.Session(graph), sess => Train(sess, graph)); | return with(tf.Session(graph), sess => Train(sess, graph)); | ||||
} | } | ||||