Browse Source

fix random process

tags/v0.4.10
QipengGuo 5 years ago
parent
commit
204bb06769
1 changed files with 2 additions and 3 deletions
  1. +2
    -3
      reproduction/Star_transformer/train.py

+ 2
- 3
reproduction/Star_transformer/train.py View File

@@ -1,4 +1,6 @@
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
seed = set_rng_seeds(15360)
print('RNG SEED {}'.format(seed))
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
import torch.nn as nn import torch.nn as nn
import torch import torch
@@ -109,9 +111,6 @@ class MyCallback(FN.core.callback.Callback):




def train(): def train():
seed = set_rng_seeds(28848)
#seed = set_rng_seeds(np.random.randint(65536))
print('RNG SEED {}'.format(seed))
print('loading data') print('loading data')
ds_list, word_v, tag_v = g_datasets['{}-{}'.format( ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
g_args.ds, g_args.task)]() g_args.ds, g_args.task)]()


Loading…
Cancel
Save