@@ -66,8 +66,9 @@ from ..io.model_io import ModelSaver, ModelLoader | |||||
try: | try: | ||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
tensorboardX_flag = True | |||||
except: | except: | ||||
pass | |||||
tensorboardX_flag = False | |||||
class Callback(object): | class Callback(object): | ||||
@@ -581,7 +582,8 @@ class TensorboardCallback(Callback): | |||||
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) | path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) | ||||
else: | else: | ||||
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) | path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) | ||||
self._summary_writer = SummaryWriter(path) | |||||
if tensorboardX_flag: | |||||
self._summary_writer = SummaryWriter(path) | |||||
def on_batch_begin(self, batch_x, batch_y, indices): | def on_batch_begin(self, batch_x, batch_y, indices): | ||||
if "model" in self.options and self.graph_added is False: | if "model" in self.options and self.graph_added is False: | ||||
@@ -78,7 +78,7 @@ class ENASTrainer(Trainer): | |||||
results['seconds'] = 0. | results['seconds'] = 0. | ||||
return results | return results | ||||
try: | try: | ||||
if torch.cuda.is_available() and self.use_cuda: | |||||
if torch.cuda.is_available() and "cuda" in self.device: | |||||
self.model = self.model.cuda() | self.model = self.model.cuda() | ||||
self._model_device = self.model.parameters().__next__().device | self._model_device = self.model.parameters().__next__().device | ||||
self._mode(self.model, is_test=False) | self._mode(self.model, is_test=False) | ||||
@@ -337,7 +337,7 @@ class ENASTrainer(Trainer): | |||||
# policy loss | # policy loss | ||||
loss = -log_probs*utils.get_variable(adv, | loss = -log_probs*utils.get_variable(adv, | ||||
self.use_cuda, | |||||
'cuda' in self.device, | |||||
requires_grad=False) | requires_grad=False) | ||||
loss = loss.sum() # or loss.mean() | loss = loss.sum() # or loss.mean() | ||||
@@ -5,13 +5,13 @@ import torch | |||||
warnings.filterwarnings('ignore') | warnings.filterwarnings('ignore') | ||||
import os | import os | ||||
from ..core.dataset import DataSet | |||||
from fastNLP.core.dataset import DataSet | |||||
from .utils import load_url | from .utils import load_url | ||||
from .processor import ModelProcessor | from .processor import ModelProcessor | ||||
from ..io.dataset_loader import _cut_long_sentence, ConllLoader | |||||
from ..core.instance import Instance | |||||
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader | |||||
from fastNLP.core.instance import Instance | |||||
from ..api.pipeline import Pipeline | from ..api.pipeline import Pipeline | ||||
from ..core.metrics import SpanFPreRecMetric | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from .processor import IndexerProcessor | from .processor import IndexerProcessor | ||||
# TODO add pretrain urls | # TODO add pretrain urls |
@@ -3,10 +3,10 @@ from collections import defaultdict | |||||
import torch | import torch | ||||
from ..core.batch import Batch | |||||
from ..core.dataset import DataSet | |||||
from ..core.sampler import SequentialSampler | |||||
from ..core.vocabulary import Vocabulary | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class Processor(object): | class Processor(object): | ||||
@@ -232,7 +232,7 @@ class SeqLenProcessor(Processor): | |||||
return dataset | return dataset | ||||
from ..core.utils import _build_args | |||||
from fastNLP.core.utils import _build_args | |||||
class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
@@ -257,10 +257,7 @@ class ModelProcessor(Processor): | |||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | ||||
batch_output = defaultdict(list) | batch_output = defaultdict(list) | ||||
if hasattr(self.model, "predict"): | |||||
predict_func = self.model.predict | |||||
else: | |||||
predict_func = self.model.forward | |||||
predict_func = self.model.forward | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for batch_x, _ in data_iterator: | for batch_x, _ in data_iterator: | ||||
refined_batch_x = _build_args(predict_func, **batch_x) | refined_batch_x = _build_args(predict_func, **batch_x) |
@@ -22,7 +22,7 @@ except ImportError: | |||||
try: | try: | ||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
except: | except: | ||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||||
# matches bfd8deac from resnet18-bfd8deac.pth | # matches bfd8deac from resnet18-bfd8deac.pth | ||||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') |
@@ -11,15 +11,15 @@ import torch | |||||
try: | try: | ||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
except: | except: | ||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||||
from ..core.batch import Batch | |||||
from ..core.callback import CallbackException | |||||
from ..core.dataset import DataSet | |||||
from ..core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.callback import CallbackException | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
import fastNLP | import fastNLP | ||||
from . import enas_utils as utils | from . import enas_utils as utils | ||||
from ..core.utils import _build_args | |||||
from fastNLP.core.utils import _build_args | |||||
from torch.optim import Adam | from torch.optim import Adam | ||||