|
- using NumSharp.Core;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using Tensorflow;
-
- namespace TensorFlowNET.Examples.CnnTextClassification
- {
- /// <summary>
- /// Example that sets up training of a CNN text classifier.
- /// The fields below hold the hyper-parameters and data locations; in the
- /// code visible here only <c>positive_data_file</c> and
- /// <c>negative_data_file</c> are actually read (by <see cref="preprocess"/>).
- /// </summary>
- public class CnnTextTrain : Python, IExample
- {
-     // Percentage of the training data to use for validation
-     private float dev_sample_percentage = 0.1f;
-     // Data source for the positive data.
-     private string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";
-     // Data source for the negative data.
-     private string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";
-     // Dimensionality of character embedding (default: 128)
-     private int embedding_dim = 128;
-     // Comma-separated filter sizes (default: '3,4,5')
-     private string filter_sizes = "3,4,5";
-     // Number of filters per filter size (default: 128)
-     private int num_filters = 128;
-     // Dropout keep probability (default: 0.5)
-     private float dropout_keep_prob = 0.5f;
-     // L2 regularization lambda (default: 0.0)
-     private float l2_reg_lambda = 0.0f;
-     // Batch Size (default: 64)
-     private int batch_size = 64;
-     // Number of training epochs (default: 200)
-     private int num_epochs = 200;
-     // Evaluate model on dev set after this many steps (default: 100)
-     private int evaluate_every = 100;
-     // Save model after this many steps (default: 100)
-     private int checkpoint_every = 100;
-     // Number of checkpoints to store (default: 5)
-     private int num_checkpoints = 5;
-     // Allow soft device placement
-     private bool allow_soft_placement = true;
-     // Log placement of ops on devices
-     private bool log_device_placement = false;
-
-     /// <summary>
-     /// Runs the example by calling <see cref="preprocess"/> and deconstructing
-     /// its five-element result. Nothing is done with the values yet.
-     /// </summary>
-     /// <exception cref="NotImplementedException">
-     /// Propagated from <see cref="preprocess"/>, which currently always throws.
-     /// </exception>
-     public void Run()
-     {
-         var (x_train, y_train, vocab_processor, x_dev, y_dev) = preprocess();
-     }
-
-     /// <summary>
-     /// Loads the texts and labels through
-     /// <c>DataHelpers.load_data_and_labels</c>, computes the largest number of
-     /// space-separated tokens found in any text, and creates a vocabulary
-     /// processor with that length. The remaining steps are not implemented.
-     /// </summary>
-     /// <returns>
-     /// Declared as a tuple of five <c>NDArray</c> values; <see cref="Run"/>
-     /// deconstructs it as (x_train, y_train, vocab_processor, x_dev, y_dev).
-     /// No value is returned at present because the method always throws.
-     /// </returns>
-     /// <exception cref="NotImplementedException">
-     /// Always thrown once the vocabulary processor has been created.
-     /// </exception>
-     public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
-     {
-         var (x_text, y) = DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);
-
-         // Build vocabulary
-         // Document length is the token count after splitting on a single space.
-         // NOTE(review): Max() presumably throws on an empty x_text; confirm
-         // against the type returned by load_data_and_labels.
-         int max_document_length = x_text.Select(x => x.Split(' ').Length).Max();
-         // The original statement had no terminating semicolon, which stopped
-         // the file from compiling.
-         var vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length);
-         throw new NotImplementedException("");
-     }
- }
- }
|