|
|
@@ -62,6 +62,12 @@ namespace TensorFlowNET.Examples |
|
|
|
train_y = y[new Slice(stop: train_size)]; |
|
|
|
valid_y = y[new Slice(start: train_size)]; |
|
|
|
Console.WriteLine("\tDONE"); |
|
|
|
|
|
|
|
train_x = np.Load<int[,]>(Path.Join("word_cnn", "train_x.npy")); |
|
|
|
valid_x = np.Load<int[,]>(Path.Join("word_cnn", "valid_x.npy")); |
|
|
|
train_y = np.Load<int[]>(Path.Join("word_cnn", "train_y.npy")); |
|
|
|
valid_y = np.Load<int[]>(Path.Join("word_cnn", "valid_y.npy")); |
|
|
|
|
|
|
|
return (train_x, valid_x, train_y, valid_y); |
|
|
|
} |
|
|
|
|
|
|
@@ -114,7 +120,7 @@ namespace TensorFlowNET.Examples |
|
|
|
int alphabet_size = 0; |
|
|
|
|
|
|
|
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); |
|
|
|
vocabulary_size = len(word_dict); |
|
|
|
//vocabulary_size = len(word_dict); |
|
|
|
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); |
|
|
|
|
|
|
|
Console.WriteLine("\tDONE "); |
|
|
@@ -305,7 +311,7 @@ namespace TensorFlowNET.Examples |
|
|
|
public bool Train() |
|
|
|
{ |
|
|
|
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); |
|
|
|
|
|
|
|
string json = JsonConvert.SerializeObject(graph, Formatting.Indented); |
|
|
|
return with(tf.Session(graph), sess => Train(sess, graph)); |
|
|
|
} |
|
|
|
|
|
|
|