|
|
@@ -22,7 +22,7 @@ namespace TensorFlowNET.Examples |
|
|
|
public bool Enabled { get; set; } = true; |
|
|
|
public string Name => "CNN Text Classification"; |
|
|
|
public int? DataLimit = null; |
|
|
|
public bool IsImportingGraph { get; set; } = false; |
|
|
|
public bool IsImportingGraph { get; set; } = true; |
|
|
|
|
|
|
|
private const string dataDir = "word_cnn"; |
|
|
|
private string dataFileName = "dbpedia_csv.tar.gz"; |
|
|
@@ -44,9 +44,7 @@ namespace TensorFlowNET.Examples |
|
|
|
{ |
|
|
|
PrepareData(); |
|
|
|
|
|
|
|
Train(); |
|
|
|
|
|
|
|
return true; |
|
|
|
return Train(); |
|
|
|
} |
|
|
|
|
|
|
|
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here |
|
|
@@ -305,13 +303,13 @@ namespace TensorFlowNET.Examples |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
return max_accuracy > 0.8; |
|
|
|
} |
|
|
|
|
|
|
|
public bool Train() |
|
|
|
{ |
|
|
|
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); |
|
|
|
string json = JsonConvert.SerializeObject(graph, Formatting.Indented); |
|
|
|
// string json = JsonConvert.SerializeObject(graph, Formatting.Indented); |
|
|
|
return with(tf.Session(graph), sess => Train(sess, graph)); |
|
|
|
} |
|
|
|
|
|
|
|