Browse Source

Merge branch 'dev' of gitee.com:fastnlp/fastNLP into dev

tags/v1.0.0alpha
yh_cc 4 years ago
parent
commit
e9c6bf751b
2 changed files with 6 additions and 1 deletions
  1. +5
    -0
      fastNLP/core/dist_trainer.py
  2. +1
    -1
      tests/core/test_batch.py

+ 5
- 0
fastNLP/core/dist_trainer.py View File

@@ -177,8 +177,13 @@ class DistTrainer():
self.batch_size = self.world_size * self.batch_size_per_gpu
self.n_steps = self._get_n_steps()

self.dev_data = dev_data
self.metrics = metrics
self.test_use_tqdm = True
self.kwargs = kwargs
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)

# for evaluation, only run eval on master proc
if dev_data and metrics:
cb = _TesterCallback(


+ 1
- 1
tests/core/test_batch.py View File

@@ -445,7 +445,7 @@ class TestCase1(unittest.TestCase):
sample_count = 0
for batch_x, batch_y in data_iter:
sample_count += len(batch_x['seq_len'])
self.assertTrue(sum(batch_x['seq_len'])<120)
self.assertTrue(sum(batch_x['seq_len'])<=120)
self.assertEqual(sample_count, num_samples)

"""


Loading…
Cancel
Save