Browse Source

Set torch.no_grad() for test & inference to reduce memory usage

tags/v0.1.0
choosewhatulike 6 years ago
parent
commit
8a87807274
2 changed files with 5 additions and 6 deletions
  1. +2
    -2
      fastNLP/core/inference.py
  2. +3
    -4
      fastNLP/core/tester.py

+ 2
- 2
fastNLP/core/inference.py View File

@@ -76,8 +76,8 @@ class Inference(object):
iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

for batch_x in self.make_batch(iterator, data, use_cuda=False):
prediction = self.data_forward(network, batch_x)
with torch.no_grad():
prediction = self.data_forward(network, batch_x)

self.batch_output.append(prediction)



+ 3
- 4
fastNLP/core/tester.py View File

@@ -50,10 +50,9 @@ class BaseTester(object):
step = 0

for batch_x, batch_y in self.make_batch(iterator, dev_data):

prediction = self.data_forward(network, batch_x)

eval_results = self.evaluate(prediction, batch_y)
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)


Loading…
Cancel
Save