|
|
@@ -1,4 +1,6 @@ |
|
|
|
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args |
|
|
|
seed = set_rng_seeds(15360) |
|
|
|
print('RNG SEED {}'.format(seed)) |
|
|
|
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN |
|
|
|
import torch.nn as nn |
|
|
|
import torch |
|
|
@@ -109,9 +111,6 @@ class MyCallback(FN.core.callback.Callback): |
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
seed = set_rng_seeds(28848) |
|
|
|
#seed = set_rng_seeds(np.random.randint(65536)) |
|
|
|
print('RNG SEED {}'.format(seed)) |
|
|
|
print('loading data') |
|
|
|
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( |
|
|
|
g_args.ds, g_args.task)]() |
|
|
|