@@ -77,14 +77,17 @@ class FullSpaceToHalfSpaceProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
+
+        def inner_proc(ins):
             sentence = ins[self.field_name]
-            new_sentence = [None] * len(sentence)
+            new_sentence = [""] * len(sentence)
             for idx, char in enumerate(sentence):
                 if char in self.convert_map:
                     char = self.convert_map[char]
                 new_sentence[idx] = char
-            ins[self.field_name] = ''.join(new_sentence)
+            return "".join(new_sentence)
+
+        dataset.apply(inner_proc, new_field_name=self.field_name)
         return dataset
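
The hunks in this file all follow one refactor: instead of mutating instances while iterating over the DataSet, the per-instance logic is wrapped in a function and handed to `dataset.apply`, which writes the return value into `new_field_name`. A minimal sketch of the pattern, with a tiny made-up convert_map standing in for the processor's real full-width-to-half-width table:

```python
from fastNLP.core.dataset import DataSet

# Hypothetical two-entry table; the real processor maps all full-width
# characters to their half-width equivalents.
convert_map = {"Ａ": "A", "１": "1"}

ds = DataSet({"word": ["Ａ１"]})

def inner_proc(ins):
    # Rebuild the sentence character by character, as in the hunk above.
    return "".join(convert_map.get(char, char) for char in ins["word"])

# apply() stores each returned value back into the named field.
ds.apply(inner_proc, new_field_name="word")
```
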
@@ -94,9 +97,7 @@ class PreAppendProcessor(Processor):
         self.data = data

     def process(self, dataset):
-        for ins in dataset:
-            sent = ins[self.field_name]
-            ins[self.new_added_field_name] = [self.data] + sent
+        dataset.apply(lambda ins: [self.data] + ins[self.field_name], new_field_name=self.new_added_field_name)
         return dataset
@@ -108,9 +109,7 @@ class SliceProcessor(Processor):
         self.slice = slice(start, end, step)

     def process(self, dataset):
-        for ins in dataset:
-            sent = ins[self.field_name]
-            ins[self.new_added_field_name] = sent[self.slice]
+        dataset.apply(lambda ins: ins[self.field_name][self.slice], new_field_name=self.new_added_field_name)
         return dataset
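
For reference, `slice(10, 20, 2)` keeps indices 10, 12, 14, 16 and 18 — five elements — which is exactly what the SliceProcessor test added further down asserts. A standalone check:

```python
# slice(start, end, step): on a 30-element list this keeps 5 elements.
s = slice(10, 20, 2)
assert list(range(30))[s] == [10, 12, 14, 16, 18]
```
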
@@ -121,14 +120,17 @@ class Num2TagProcessor(Processor):
         self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'

     def process(self, dataset):
-        for ins in dataset:
+
+        def inner_proc(ins):
             s = ins[self.field_name]
             new_s = [None] * len(s)
             for i, w in enumerate(s):
                 if re.search(self.pattern, w) is not None:
                     w = self.tag
                 new_s[i] = w
-            ins[self.new_added_field_name] = new_s
+            return new_s
+
+        dataset.apply(inner_proc, new_field_name=self.new_added_field_name)
         return dataset
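
The pattern above is meant to catch number-like tokens. A quick sanity check using the tokens from the Num2TagProcessor test added at the end of this diff, plus one negative case:

```python
import re

pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'

# All of these match and would be replaced by self.tag.
for w in ["99.9982", "2134.0", "0.002", "234"]:
    assert re.search(pattern, w) is not None

# Purely alphabetic tokens do not match and are left unchanged.
assert re.search(pattern, "hello") is None
```
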
@@ -149,11 +151,8 @@ class IndexerProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
-            tokens = ins[self.field_name]
-            index = [self.vocab.to_index(token) for token in tokens]
-            ins[self.new_added_field_name] = index
+        dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
+                      new_field_name=self.new_added_field_name)

         if self.is_input:
             dataset.set_input(self.new_added_field_name)
@@ -167,6 +166,7 @@ class VocabProcessor(Processor):
     """Build vocabulary with a field in the data set.

     """
+
     def __init__(self, field_name):
         super(VocabProcessor, self).__init__(field_name, None)
         self.vocab = Vocabulary()
@@ -175,8 +175,7 @@ class VocabProcessor(Processor):
         for dataset in datasets:
             assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
             for ins in dataset:
-                tokens = ins[self.field_name]
-                self.vocab.update(tokens)
+                self.vocab.update(ins[self.field_name])

     def get_vocab(self):
         self.vocab.build_vocab()
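
Together, VocabProcessor and IndexerProcessor give the usual build-then-index flow. A short sketch mirroring the test added at the end of this diff, with made-up toy tokens:

```python
from fastNLP import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import IndexerProcessor, VocabProcessor

ds = DataSet({"xx": [["a", "b", "a", "c"]] * 4})

# First pass: collect token counts over the field.
vocab_proc = VocabProcessor("xx")
vocab_proc(ds)

# Second pass: map each token to its integer index, into a new field.
IndexerProcessor(vocab_proc.vocab, "xx", "xx_index")(ds)
```
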
@@ -190,9 +189,7 @@ class SeqLenProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
-            length = len(ins[self.field_name])
-            ins[self.new_added_field_name] = length
+        dataset.apply(lambda ins: len(ins[self.field_name]), new_field_name=self.new_added_field_name)
         if self.is_input:
             dataset.set_input(self.new_added_field_name)
         return dataset
@@ -225,7 +222,7 @@ class ModelProcessor(Processor):
                 for key, value in prediction.items():
                     tmp_batch = []
                     value = value.cpu().numpy()
-                    if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1):
+                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                         batch_output[key].extend(value.tolist())
                     else:
                         for idx, seq_len in enumerate(seq_lens):
@@ -236,7 +233,7 @@ class ModelProcessor(Processor):
         # TODO: with the current implementation, downstream processors need to know the keys of the model's output
         for field_name, fields in batch_output.items():
-            dataset.add_field(field_name, fields, need_tensor=False, is_target=False)
+            dataset.add_field(field_name, fields, is_input=True, is_target=False)

         return dataset
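
The shape check a few lines up separates per-instance outputs from per-token ones: a (batch,) or (batch, 1) array holds one value per instance, while anything wider is cut to each instance's true seq_len before being stored. A self-contained sketch of that branching over NumPy arrays (the kind produced by `value.cpu().numpy()`):

```python
import numpy as np

def collect(value, seq_lens, out):
    # One prediction per instance: extend the flat output list.
    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
        out.extend(value.tolist())
    else:
        # One prediction per token: strip the padding at each true length.
        for idx, seq_len in enumerate(seq_lens):
            out.append(value[idx, :seq_len].tolist())

out = []
collect(np.zeros(2), [3, 2], out)       # sentence-level predictions
collect(np.zeros((2, 5)), [3, 2], out)  # token-level, padded to length 5
assert len(out) == 4 and len(out[2]) == 3 and len(out[3]) == 2
```
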
@@ -254,23 +251,8 @@ class Index2WordProcessor(Processor):
         self.vocab = vocab

     def process(self, dataset):
-        for ins in dataset:
-            new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]]
-            ins[self.new_added_field_name] = new_sent
-        return dataset
-
-
-class SetTensorProcessor(Processor):
-    # TODO: remove it. It is strange.
-    def __init__(self, field_dict, default=False):
-        super(SetTensorProcessor, self).__init__(None, None)
-        self.field_dict = field_dict
-        self.default = default
-
-    def process(self, dataset):
-        set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
-        set_dict.update(self.field_dict)
-        dataset._set_need_tensor(**set_dict)
+        dataset.apply(lambda ins: [self.vocab.to_word(w) for w in ins[self.field_name]],
+                      new_field_name=self.new_added_field_name)
         return dataset
@@ -10,7 +10,7 @@ class Optimizer(object):

 class SGD(Optimizer):

-    def __init__(self, lr=0.01, momentum=0, model_params=None):
+    def __init__(self, lr=0.001, momentum=0, model_params=None):
         """
-        :param float lr: learning rate. Default: 0.01
+        :param float lr: learning rate. Default: 0.001
@@ -30,7 +30,7 @@ class SGD(Optimizer):

 class Adam(Optimizer):

-    def __init__(self, lr=0.01, weight_decay=0, model_params=None):
+    def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
         """
         :param float lr: learning rate
@@ -39,7 +39,8 @@ class Adam(Optimizer):
         """
         if not isinstance(lr, float):
             raise TypeError("learning rate has to be float.")
-        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
+        super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
+                                   weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
         if self.model_params is None:
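
The new keyword arguments line up one-to-one with `torch.optim.Adam`, so `construct_from_pytorch` can forward the stored settings unchanged. The underlying call it presumably ends up making:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# Same defaults as the wrapper above; every keyword here is accepted
# verbatim by torch.optim.Adam.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
                             eps=1e-8, weight_decay=0, amsgrad=False)
```
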
@@ -31,12 +31,12 @@ class Tester(object):
         self.use_cuda = use_cuda
         self.batch_size = batch_size
         self.verbose = verbose
-        self._model_device = model.parameters().__next__().device

         if torch.cuda.is_available() and self.use_cuda:
             self._model = model.cuda()
         else:
             self._model = model
+        self._model_device = model.parameters().__next__().device

         # check predict
         if hasattr(self._model, 'predict'):
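
Moving the device lookup below the `.cuda()` branch matters: `.cuda()` reallocates the parameters, so reading the device first records where the model used to live, not where the batches must be sent. A small CPU-safe illustration:

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
device_before = next(model.parameters()).device  # cpu

if torch.cuda.is_available():
    model = model.cuda()
    # The device recorded before .cuda() is now stale.
    assert device_before.type == "cpu"
    assert next(model.parameters()).device.type == "cuda"
```
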
@@ -3,6 +3,7 @@ import time
 from datetime import datetime
 from datetime import timedelta

+import numpy as np
 import torch
 from tensorboardX import SummaryWriter
 from torch import nn
@@ -97,7 +98,8 @@ class Trainer(object):
         if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
-                        metric_key=metric_key, check_level=check_code_level)
+                        metric_key=metric_key, check_level=check_code_level,
+                        batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))

         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
@@ -113,8 +115,6 @@ class Trainer(object):
         self.best_metric_indicator = None
         self.sampler = sampler

-        self._model_device = model.parameters().__next__().device
-
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
@@ -123,6 +123,7 @@ class Trainer(object):
         self.use_tqdm = use_tqdm
         if self.use_tqdm:
             tester_verbose = 0
+            self.print_every = abs(self.print_every)
         else:
             tester_verbose = 1
@@ -137,17 +138,44 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp

-    def train(self):
-        """Start Training.
+    def train(self, load_best_model=True):
+        """
+        Start the training process. The main steps are:
+
+        for epoch in range(num_epochs):
+            (1) use Batch to fetch data from the DataSet batch by batch, automatically padding the fields
+                whose dtype is float or int and converting them to Tensors; fields of other dtypes are
+                neither converted to Tensors nor padded
+            for batch_x, batch_y in Batch(DataSet):
+                # batch_x holds the fields set as input
+                # batch_y holds the fields set as target
+            (2) feed batch_x into model.forward and collect the result
+            (3) pass batch_y together with the result of model.forward to the loss to compute the loss
+            (4) back-propagate the loss and update the gradients
+            if dev_data is not None:
+                evaluate according to the metrics, and decide from save_path whether to store the model
+
+        :param load_best_model: only effective when dev_data was provided at initialization. If True, the
+            trainer reloads the parameters of the model that performed best on dev before returning.
+
+        Returns a dict with the following entries:
+            seconds: float, the training time in seconds
+            The following three entries are only present when dev_data was provided:
+            best_eval: Dict of Dict, the evaluation results
+            best_epoch: int, the epoch in which the best value was reached
+            best_step: int, the step (batch update) at which the best value was reached
+        :return dict:
         """
+        results = {}
         try:
             if torch.cuda.is_available() and self.use_cuda:
                 self.model = self.model.cuda()
+            self._model_device = self.model.parameters().__next__().device

             self._mode(self.model, is_test=False)

             self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
+            start_time = time.time()
             print("training epochs started " + self.start_time, flush=True)
             if self.save_path is None:
                 class psudoSW:
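
Since `train()` now returns a results dict, a caller can read the timing and best-dev numbers directly. A hedged usage sketch (the trainer construction itself is elided here):

```python
# `trainer` is assumed to be a Trainer built with dev_data and metrics.
results = trainer.train(load_best_model=True)

print("training took {} seconds".format(results['seconds']))
if 'best_eval' in results:  # present only when dev_data was given
    print("best dev result: {} (epoch {}, step {})".format(
        results['best_eval'], results['best_epoch'], results['best_step']))
```
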
@@ -165,26 +193,37 @@ class Trainer(object):
                 self._tqdm_train()
             else:
                 self._print_train()

+            if self.dev_data is not None:
+                print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
+                      self.tester._format_eval_results(self.best_dev_perf),)
+                results['best_eval'] = self.best_dev_perf
+                results['best_epoch'] = self.best_dev_epoch
+                results['best_step'] = self.best_dev_step
+                if load_best_model:
+                    model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
+                    # self._load_model(self.model, model_name)
+                    print("Reloaded the best model.")
         finally:
             self._summary_writer.close()
             del self._summary_writer
+        results['seconds'] = round(time.time() - start_time, 2)
+
+        return results

     def _tqdm_train(self):
         self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                               as_numpy=False)
         total_steps = data_iterator.num_batches*self.n_epochs
-        epoch = 1
         with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
-            ava_loss = 0
+            avg_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 for batch_x, batch_y in data_iterator:
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     prediction = self._data_forward(self.model, batch_x)
                     loss = self._compute_loss(prediction, batch_y)
-                    ava_loss += loss.item()
+                    avg_loss += loss.item()
                     self._grad_backward(loss)
                     self._update()
                     self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -194,18 +233,18 @@ class Trainer(object):
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                     if (self.step+1) % self.print_every == 0:
-                        pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
-                        ava_loss = 0
-                        pbar.update(1)
+                        pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
+                        avg_loss = 0
+                        pbar.update(self.print_every)
                     self.step += 1
                     if self.validate_every > 0 and self.step % self.validate_every == 0 \
                             and self.dev_data is not None:
-                        eval_res = self._do_validation()
+                        eval_res = self._do_validation(epoch=epoch, step=self.step)
                         eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                    self.tester._format_eval_results(eval_res)
                         pbar.write(eval_str)
                 if self.validate_every < 0 and self.dev_data:
-                    eval_res = self._do_validation()
+                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str)
@@ -244,25 +283,29 @@ class Trainer(object):
                 if (self.validate_every > 0 and self.step % self.validate_every == 0 and
                         self.dev_data is not None):
-                    self._do_validation()
+                    self._do_validation(epoch=epoch, step=self.step)
                 self.step += 1

             # validate_every override validation at end of epochs
             if self.dev_data and self.validate_every <= 0:
-                self._do_validation()
+                self._do_validation(epoch=epoch, step=self.step)
             epoch += 1

-    def _do_validation(self):
+    def _do_validation(self, epoch, step):
         res = self.tester.test()
         for name, metric in res.items():
             for metric_key, metric_val in metric.items():
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
-        if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else ""
-            self._save_model(self.model,
-                             "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
+        if self._better_eval_result(res):
+            if self.save_path is not None:
+                self._save_model(self.model,
+                                 "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
+            self.best_dev_perf = res
+            self.best_dev_epoch = epoch
+            self.best_dev_step = step
         return res

     def _mode(self, model, is_test=False):
@@ -317,6 +360,16 @@ class Trainer(object):
         else:
             torch.save(model, model_name)

+    def _load_model(self, model, model_name, only_param=False):
+        # TODO: is there a problem with this?
+        if self.save_path is not None:
+            model_name = os.path.join(self.save_path, model_name)
+            if only_param:
+                states = torch.load(model_name)
+            else:
+                states = torch.load(model_name).state_dict()
+            model.load_state_dict(states)
+
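
For reference, the save/load pairing that `_save_model` and `_load_model` rely on, sketched with plain torch calls and throwaway file names:

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

# only_param=True: persist just the state dict and restore it.
torch.save(model.state_dict(), "model_param.pkl")
model.load_state_dict(torch.load("model_param.pkl"))

# only_param=False: persist the whole module, then pull its state dict.
torch.save(model, "model.pkl")
model.load_state_dict(torch.load("model.pkl").state_dict())
```
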
     def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.
@@ -344,6 +397,21 @@ class Trainer(object):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

+
+def _get_value_info(_dict):
+    # given a dict value, return information about this dict's value. Return list of str
+    strs = []
+    for key, value in _dict.items():
+        _str = ''
+        if isinstance(value, torch.Tensor):
+            _str += "\t{}: (1)type:torch.Tensor (2)dtype:{}, (3)shape:{} ".format(key,
+                                                                                  value.dtype, value.size())
+        elif isinstance(value, np.ndarray):
+            _str += "\t{}: (1)type:numpy.ndarray (2)dtype:{}, (3)shape:{} ".format(key,
+                                                                                   value.dtype, value.shape)
+        else:
+            _str += "\t{}: type:{}".format(key, type(value))
+        strs.append(_str)
+    return strs
+
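
What the helper reports for a typical batch, sketched with made-up field names:

```python
import numpy as np
import torch

# Hypothetical batch dict for illustration only.
batch_x = {"word_seq": torch.zeros(2, 5, dtype=torch.int64),
           "seq_len": np.array([5, 3])}

for line in _get_value_info(batch_x):
    print(line)
# 	word_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 5])
# 	seq_len: (1)type:numpy.ndarray (2)dtype:int64, (3)shape:(2,)
```
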

 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
@@ -356,8 +424,24 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
         if batch_count==0:
+            info_str = ""
+            input_fields = _get_value_info(batch_x)
+            target_fields = _get_value_info(batch_y)
+            if len(input_fields)>0:
+                info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
+                info_str += "\n".join(input_fields)
+                info_str += '\n'
+            else:
+                raise RuntimeError("There is no input field.")
+            if len(target_fields)>0:
+                info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
+                info_str += "\n".join(target_fields)
+                info_str += '\n'
+            else:
+                info_str += 'There is no target field.'
+            print(info_str)
             _check_forward_error(forward_func=model.forward, dataset=dataset,
-                                 batch_x=batch_x, check_level=check_level)
+                                 batch_x=batch_x, check_level=check_level)

         refined_batch_x = _build_args(model.forward, **batch_x)
         pred_dict = model(**refined_batch_x)
@@ -125,7 +125,7 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
-    varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
+    varargs = [] if not spect.varargs else [spect.varargs]
     return CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,
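
The varargs change fixes a subtle bug: `inspect.getfullargspec` returns varargs as a single string (the parameter's name), so iterating over it yields its characters. A quick demonstration:

```python
import inspect

def f(a, *args, **kwargs):
    pass

spect = inspect.getfullargspec(f)
assert spect.varargs == "args"

# Old expression: iterates over the string, producing its characters.
assert [arg for arg in spect.varargs] == ["a", "r", "g", "s"]
# Fixed expression: a one-element list holding the parameter name.
assert [spect.varargs] == ["args"]
```
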
@@ -0,0 +1,6 @@
+import unittest
+
+
+class TestPipeline(unittest.TestCase):
+    def test_case(self):
+        pass
@@ -1,6 +1,9 @@
+import random
 import unittest

-from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
+from fastNLP import Vocabulary
+from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
+    IndexerProcessor, VocabProcessor, SeqLenProcessor
 from fastNLP.core.dataset import DataSet
@@ -9,4 +12,44 @@ class TestProcessor(unittest.TestCase):
         ds = DataSet({"word": ["00, u1, u), (u2, u2"]})
         proc = FullSpaceToHalfSpaceProcessor("word")
         ds = proc(ds)
-        self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])
+        self.assertEqual(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])
+
+    def test_PreAppendProcessor(self):
+        ds = DataSet({"word": [["1234", "3456"], ["8789", "3464"]]})
+        proc = PreAppendProcessor(data="abc", field_name="word")
+        ds = proc(ds)
+        self.assertEqual(ds.field_arrays["word"].content, [["abc", "1234", "3456"], ["abc", "8789", "3464"]])
+
+    def test_SliceProcessor(self):
+        ds = DataSet({"xx": [[random.randint(0, 10) for _ in range(30)]] * 40})
+        proc = SliceProcessor(10, 20, 2, "xx", new_added_field_name="yy")
+        ds = proc(ds)
+        self.assertEqual(len(ds.field_arrays["yy"].content[0]), 5)
+
+    def test_Num2TagProcessor(self):
+        ds = DataSet({"num": [["99.9982", "2134.0"], ["0.002", "234"]]})
+        proc = Num2TagProcessor("<num>", "num")
+        ds = proc(ds)
+        for data in ds.field_arrays["num"].content:
+            for d in data:
+                self.assertEqual(d, "<num>")
+
+    def test_VocabProcessor_and_IndexerProcessor(self):
+        ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]] * 40})
+        vocab_proc = VocabProcessor("xx")
+        vocab_proc(ds)
+        vocab = vocab_proc.vocab
+        self.assertTrue(isinstance(vocab, Vocabulary))
+        self.assertTrue(len(vocab) > 5)
+
+        proc = IndexerProcessor(vocab, "xx", "yy")
+        ds = proc(ds)
+        for data in ds.field_arrays["yy"].content[0]:
+            self.assertTrue(isinstance(data, int))
+
+    def test_SeqLenProcessor(self):
+        ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]] * 10})
+        proc = SeqLenProcessor("xx", "len")
+        ds = proc(ds)
+        for data in ds.field_arrays["len"].content:
+            self.assertEqual(data, 30)
@@ -52,28 +52,24 @@ class TestAccuracyMetric(unittest.TestCase):
     def test_AccuaryMetric4(self):
         # (5) check reset
         metric = AccuracyMetric()
-        pred_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3)}
-        metric(pred_dict=pred_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 1})
-
-        pred_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3) + 1}
+        pred_dict = {"pred": torch.randn(4, 3, 2)}
+        target_dict = {'target': torch.ones(4, 3)}
         metric(pred_dict=pred_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 0})
+        ans = torch.argmax(pred_dict["pred"], dim=2).to(target_dict["target"]) == target_dict["target"]
+        res = metric.get_metric()
+        self.assertTrue(isinstance(res, dict))
+        self.assertTrue("acc" in res)
+        self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3)

     def test_AccuaryMetric5(self):
         # (5) check reset
         metric = AccuracyMetric()
-        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        pred_dict = {"pred": torch.randn(4, 3, 2)}
         target_dict = {'target': torch.zeros(4, 3)}
         metric(pred_dict=pred_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})
-
-        pred_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3) + 1}
-        metric(pred_dict=pred_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 0.5})
+        res = metric.get_metric(reset=False)
+        ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean()
+        self.assertAlmostEqual(res["acc"], float(ans), places=4)

     def test_AccuaryMetric6(self):
         # (6) check numpy array is not acceptable
@@ -90,10 +86,12 @@ class TestAccuracyMetric(unittest.TestCase):
     def test_AccuaryMetric7(self):
         # (7) check map, match
         metric = AccuracyMetric(pred='predictions', target='targets')
-        pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+        pred_dict = {"predictions": torch.randn(4, 3, 2)}
         target_dict = {'targets': torch.zeros(4, 3)}
         metric(pred_dict=pred_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        res = metric.get_metric()
+        ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean()
+        self.assertAlmostEqual(res["acc"], float(ans), places=4)

     def test_AccuaryMetric8(self):
         # (8) check map, does not match. use stop_fast_param to stop fast param map
@@ -1,10 +1,10 @@
+import time
 import unittest

 import numpy as np
 import torch.nn.functional as F
 from torch import nn
-import time

+from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
@@ -83,7 +83,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

-        with self.assertRaises(NameError):
+        with self.assertRaises(RuntimeError):
             trainer = Trainer(
                 train_data=dataset,
                 model=model
@@ -19,16 +19,52 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 50,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+      "  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
+     ]
+    }
+   ],
    "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "\n",
     "from fastNLP import DataSet\n",
+    "\n",
     "# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
-    "win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n",
+    "win_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
    "ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'raw_sentence': this quiet , introspective and entertaining independent is worth seeking .,\n",
+       "'label': 4,\n",
+       "'label_seq': 4,\n",
+       "'words': ['this', 'quiet', ',', 'introspective', 'and', 'entertaining', 'independent', 'is', 'worth', 'seeking', '.']}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ds[1]"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -42,7 +78,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 52,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,65 +94,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 60,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 5,
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Train size: "
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " "
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "54"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Test size: "
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " "
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "23"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
+      "Train size: 54\n",
+      "Test size: 23\n"
      ]
     }
    ],
@@ -129,7 +115,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 61,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -177,14 +163,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "training epochs started 2018-12-07 14:03:41"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
+      "training epochs started 2018-12-07 14:03:41\n"
      ]
     },
    {
@@ -201,84 +180,10 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "\r"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\r"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Train finished!"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
+      "Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087\n",
+      "Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826\n",
+      "Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696\n",
+      "Train finished!\n"
      ]
     }
    ],
@@ -311,23 +216,23 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 2",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "python2"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 2
+    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython2",
-   "version": "2.7.6"
+   "pygments_lexer": "ipython3",
+   "version": "3.6.7"
   }
  },
 "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 1
 }