Browse Source

remove order of _control_dependencies_for_inputs.

tags/v0.9
Oceania2018 6 years ago
parent
commit
477d03db16
4 changed files with 6 additions and 12 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +0
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +1
    -2
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +4
    -6
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
}

return ret.OrderBy(x => x.op.name).ToArray();
return ret.ToArray();
}

/// <summary>


+ 0
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -248,9 +248,6 @@ namespace Tensorflow
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

var input_ops = inputs.Select(x => x.op).ToArray();
if (name == "loss/gradients/embedding/embedding_lookup_grad/Reshape")
;

var control_inputs = _control_dependencies_for_inputs(input_ops);

var op = new Operation(node_def,


+ 1
- 2
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;GRAPH_SERIALIZE</DefineConstants>
<DefineConstants>TRACE;DEBUG</DefineConstants>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
@@ -54,7 +54,6 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.13.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="NumSharp" Version="0.10.3" />
</ItemGroup>



+ 4
- 6
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -22,7 +22,7 @@ namespace TensorFlowNET.Examples
public bool Enabled { get; set; } = true;
public string Name => "CNN Text Classification";
public int? DataLimit = null;
public bool IsImportingGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = true;

private const string dataDir = "word_cnn";
private string dataFileName = "dbpedia_csv.tar.gz";
@@ -44,9 +44,7 @@ namespace TensorFlowNET.Examples
{
PrepareData();

Train();

return true;
return Train();
}

// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
@@ -305,13 +303,13 @@ namespace TensorFlowNET.Examples
}
}

return false;
return max_accuracy > 0.8;
}

public bool Train()
{
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
// string json = JsonConvert.SerializeObject(graph, Formatting.Indented);
return with(tf.Session(graph), sess => Train(sess, graph));
}



Loading…
Cancel
Save