@@ -19,7 +19,6 @@ from collections import defaultdict
 from .dataset import DataSet
 from .sampler import SequentialSampler
-from .field import _get_ele_type_and_dim
 from ._logger import logger
@@ -34,20 +33,20 @@ def _set_python_is_exit():
 atexit.register(_set_python_is_exit)
-def may_to_tensor(data, as_numpy, fn):
-    if not as_numpy:
-        dtype, dim = _get_ele_type_and_dim(data)
-        try:
-            data, flag = _to_tensor(data, dtype)
-        except TypeError as e:
-            logger.error(f"Field {fn} cannot be converted to torch.tensor.")
-            raise e
-    return data
+def _pad(batch_dict, dataset, as_numpy):
+    result = {}
+    for n, vlist in batch_dict.items():
+        f = dataset.field_arrays[n]
+        if f.padder is None:
+            result[n] = np.array(vlist)
+        else:
+            res = f.pad(vlist)
+            if not as_numpy:
+                res, _ = _to_tensor(res, field_dtype=f.dtype)
+            result[n] = res
+    return result
-def convert_tensor(batch_dict, as_numpy):
-    for n, v in batch_dict.items():
-        batch_dict[n] = may_to_tensor(v, as_numpy, n)
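For orientation, a minimal sketch of what the new ``_pad`` helper does, using toy stand-ins rather than the real FieldArray/Padder machinery (the ``toy_pad`` name is made up)::

    import numpy as np
    import torch

    # Toy stand-in for a FieldArray padder that right-pads with zeros.
    def toy_pad(vlist, pad_val=0):
        max_len = max(len(v) for v in vlist)
        return np.array([v + [pad_val] * (max_len - len(v)) for v in vlist])

    batch = {'words': [[1, 2], [3, 4, 5]], 'label': [0, 1]}
    padded = {'words': toy_pad(batch['words']),   # field has a padder -> pad
              'label': np.array(batch['label'])}  # padder is None -> plain np.array
    # with as_numpy=False the padded arrays are additionally turned into tensors
    tensors = {k: torch.as_tensor(v) for k, v in padded.items()}
    print(tensors['words'].shape)  # torch.Size([2, 3])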
 class DataSetGetter:
     """
@@ -78,6 +77,7 @@ class DataSetGetter:
         """
         indices = []
        sin_x, sin_y = defaultdict(list), defaultdict(list)
+        # gather the data of the fields of interest
        for idx, ins in ins_list:
            indices.append(idx)
            for n, v in ins.items():
@@ -85,28 +85,16 @@ class DataSetGetter:
                    sin_x[n].append(v)
                if n in self.y_names:
                    sin_y[n].append(v)
-        def pad(batch_dict):
-            result = {}
-            for n, vlist in batch_dict.items():
-                f = self.dataset.field_arrays[n]
-                if f.padder is None:
-                    result[n] = np.array(vlist)
-                else:
-                    result[n] = f.pad(vlist)
-            return result
-        sin_x = pad(sin_x)
-        sin_y = pad(sin_y)
-        convert_tensor(sin_x, self.as_numpy)
-        convert_tensor(sin_y, self.as_numpy)
+        # pad where appropriate
+        sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy)
+        sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy)
         if not self.dataset.collector.is_empty():
             bx, by = self.dataset._collect_batch(ins_list)
             sin_x.update(bx)
             sin_y.update(by)
-        return (indices, sin_x, sin_y)
+        return indices, sin_x, sin_y
     def __getattr__(self, item):
         if hasattr(self.dataset, item):
@@ -134,8 +122,7 @@ class SamplerAdapter(torch.utils.data.Sampler):
 class BatchIter:
     """
-    The class used by Trainer to iterate over data. Subclass it and implement the get_num_batches(),
-    get_batch_indices(), dataset(), num_batches() and __iter__() methods.
+    The class used by Trainer to iterate over data. Subclass it and implement the get_num_batches(),
+    get_batch_indices() and __iter__() methods, as well as the num_batches property.
     """
     def __init__(self, dataset, batch_size=1, sampler=None,
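To make the subclass contract above concrete, a minimal sketch of an object satisfying it (a toy iterator, not a real fastNLP class; the static signature of ``get_num_batches`` is an assumption)::

    import torch

    class ToyBatchIter:
        def __init__(self, n_samples=10, batch_size=4):
            self.n_samples, self.batch_size = n_samples, batch_size

        @staticmethod
        def get_num_batches(num_samples, batch_size, drop_last):
            # ceil division unless the trailing partial batch is dropped
            return num_samples // batch_size if drop_last else -(-num_samples // batch_size)

        @property
        def num_batches(self):
            return self.get_num_batches(self.n_samples, self.batch_size, False)

        def get_batch_indices(self):
            return None  # indices of the current batch, if tracked

        def __iter__(self):
            # yield (batch_x, batch_y) dict pairs, as Trainer expects
            for start in range(0, self.n_samples, self.batch_size):
                x = torch.rand(min(self.batch_size, self.n_samples - start), 3)
                yield {'x': x}, {'y': x.sum(dim=1)}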
@@ -272,9 +259,129 @@ class DataSetIter(BatchIter):
 class TorchLoaderIter(BatchIter):
     """
-    Similar to DataSetIter, but for pytorch DataSet objects.
-    Wrap the pytorch DataSet with TorchLoaderIter, then pass it to Trainer.
+    Similar to DataSetIter, but usable with pytorch DataSet objects: wrap the pytorch DataSet
+    with TorchLoaderIter and then pass it to Trainer.
+    Alternatively, any object implementing methods like the following can be passed in.
+    Example::
+        import random
+        from fastNLP import TorchLoaderIter
+        import torch
+        class UdfDataSet:
+            def __init__(self, num_samples):
+                self.num_samples = num_samples
+            def __getitem__(self, idx):  # required; the argument is an int in [0, len(self))
+                x = [random.random() for _ in range(3)]
+                y = random.random()
+                return x, y
+            def __len__(self):  # required; the return value must be an int
+                return self.num_samples
+        # a collate_fn must be implemented to convert the data to tensors
+        def collate_fn(data_list):
+            # [(x1,y1), (x2,y2), ...]; the input is the return values of UdfDataSet.__getitem__ gathered into a list
+            xs, ys = [], []
+            for l in data_list:
+                x, y = l
+                xs.append(x)
+                ys.append(y)
+            # no need to move anything to the gpu; Trainer and Tester move the tensors to the model's device
+            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
+            return {'x': x, 'y': y}, {'y': y}
+        udf_dataset = UdfDataSet(10)
+        dataset = TorchLoaderIter(udf_dataset, collate_fn=collate_fn)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(3, 1)
+            def forward(self, x, y):
+                return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
+            def predict(self, x):
+                return {'pred': self.fc(x).squeeze(0)}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
+                          metrics=AccuracyMetric(target='y'), use_tqdm=False)
+        trainer.train(load_best_model=False)
+    Besides this, on-the-fly training can also be implemented this way, as the following code shows.
+    Example::
+        import tempfile
+        import random
+        import torch
+        tmp_file_handler, tmp_file_path = tempfile.mkstemp(text=True)
+        try:
+            num_samples, data = 10, []
+            for _ in range(num_samples):
+                x, y = [random.random() for _ in range(3)], random.random()
+                data.append(x + [y])
+            with open(tmp_file_path, 'w') as f:
+                for d in data:
+                    f.write(' '.join(map(str, d)) + '\n')
+            class FileDataSet:
+                def __init__(self, tmp_file):
+                    num_samples = 0
+                    line_pos = [0]  # line_pos[idx] is the file offset where that line starts
+                    self.tmp_file_handler = open(tmp_file, 'r', encoding='utf-8')
+                    line = self.tmp_file_handler.readline()
+                    while line:
+                        if line.strip():
+                            num_samples += 1
+                            line_pos.append(self.tmp_file_handler.tell())
+                        line = self.tmp_file_handler.readline()
+                    self.tmp_file_handler.seek(0)
+                    self.num_samples = num_samples
+                    self.line_pos = line_pos
+                def __getitem__(self, idx):
+                    line_start, line_end = self.line_pos[idx], self.line_pos[idx + 1]
+                    self.tmp_file_handler.seek(line_start)
+                    line = self.tmp_file_handler.read(line_end - line_start).strip()
+                    values = list(map(float, line.split()))
+                    x, y = values[:3], values[-1]
+                    return x, y
+                def __len__(self):
+                    return self.num_samples
+            def collate_fn(data_list):
+                # [(x1,y1), (x2,y2), ...]; the input is the return values of FileDataSet.__getitem__ gathered into a list
+                xs, ys = [], []
+                for l in data_list:
+                    x, y = l
+                    xs.append(x)
+                    ys.append(y)
+                x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
+                return {'x': x, 'y': y}, {'y': y}
+            file_data = FileDataSet(tmp_file_path)
+            dataset = TorchLoaderIter(file_data, collate_fn=collate_fn)
+            class Model(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.fc = nn.Linear(3, 1)
+                def forward(self, x, y):
+                    return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
+                def predict(self, x):
+                    return {'pred': self.fc(x).squeeze(0)}
+            model = Model()
+            trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
+                              metrics=AccuracyMetric(target='y'), use_tqdm=False, n_epochs=2)
+            trainer.train(load_best_model=False)
+        finally:
+            import os
+            if os.path.exists(tmp_file_path):
+                os.remove(tmp_file_path)
| """ | """ | ||||
| def __init__(self, dataset, batch_size=1, sampler=None, | def __init__(self, dataset, batch_size=1, sampler=None, | ||||
| num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
| @@ -291,12 +398,13 @@ class TorchLoaderIter(BatchIter): | |||||
         :param bool drop_last: if the last batch has fewer than batch_size samples, drop it
         :param timeout: timeout value for producing one batch
         :param worker_init_fn: called once per worker at startup, with the worker's index as its argument.
-        :param collate_fn: function used to combine samples into a batch"""
+        :param collate_fn: function used to combine samples into a batch.
+        """
         assert len(dataset) > 0
         ins = dataset[0]
-        assert len(ins) == 2 and \
-            isinstance(ins[0], dict) and \
-            isinstance(ins[1], dict), 'DataSet should return two dict, as X and Y'
+        if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None:
+            raise RuntimeError("If the provided dataset does not return two dicts when calling __getitem__(), the"
+                               " `collate_fn` must be provided.")
         super().__init__(
             dataset=dataset, batch_size=batch_size, sampler=sampler,
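As a complement to the docstring examples, a dataset whose ``__getitem__`` already returns two dicts passes the relaxed check above without any ``collate_fn`` (a minimal sketch; ``PairDictDataSet`` is a made-up name)::

    import torch

    class PairDictDataSet:
        # each item is already ({'x': ...}, {'y': ...}), so no collate_fn is needed
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            x = torch.rand(3)
            return {'x': x}, {'y': x.sum()}

    ds = PairDictDataSet()
    ins = ds[0]
    assert len(ins) == 2 and all(isinstance(d, dict) for d in ins)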
@@ -775,7 +775,8 @@ class DataSet(object):
     def set_ignore_type(self, *field_names, flag=True):
         """
         Set fields to the ignore-type state. When ignore_type is set on a field, no type check is performed when the
-        field is set as target or input, and by default the field is not padded either.
+        field is set as target or input, and by default the field is not padded either. If the field still needs to be
+        padded, this can be done with a custom Padder; if the field needs to be converted to a tensor, the conversion
+        has to happen inside the padder, but the padder does not need to move it to the gpu.
         :param str field_names: name of the field
         :param bool flag: set the ignore_type state of field_name to flag
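A sketch of the kind of custom Padder this refers to: it pads an ignore_type field and converts it to a tensor inside the padder, leaving device placement to Trainer/Tester (``TensorPadder`` is a made-up name; the ``__call__`` signature follows the one used in the tests below)::

    import torch
    from fastNLP.core.field import Padder

    class TensorPadder(Padder):
        def __call__(self, contents, field_name, field_ele_dtype, dim):
            # right-pad variable-length int lists with 0 and return a LongTensor
            max_len = max(len(c) for c in contents)
            batch = [list(c) + [0] * (max_len - len(c)) for c in contents]
            return torch.LongTensor(batch)  # no .to(device); Trainer handles that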
@@ -492,6 +492,7 @@ class Trainer(object):
         elif isinstance(train_data, BatchIter):
             self.data_iterator = train_data
             train_data = train_data.dataset
+            check_code_level = -1  # forcibly skip the code check
         else:
             raise TypeError("train_data type {} not support".format(type(train_data)))
@@ -873,26 +874,9 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
                 dev_data=None, metric_key=None, check_level=0):
     # check the get_loss method
     model_device = _get_model_device(model=model)
-    def _iter():
-        start_idx = 0
-        while start_idx < len(dataset):
-            batch_x = {}
-            batch_y = {}
-            for field_name, field in dataset.get_all_fields().items():
-                indices = list(range(start_idx, min(start_idx + batch_size, len(dataset))))
-                if field.is_target or field.is_input:
-                    batch = field.get(indices)
-                    if field.dtype is not None and \
-                            issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
-                        batch, _ = _to_tensor(batch, field.dtype)
-                    if field.is_target:
-                        batch_y[field_name] = batch
-                    if field.is_input:
-                        batch_x[field_name] = batch
-            yield (batch_x, batch_y)
-            start_idx += batch_size
-    for batch_count, (batch_x, batch_y) in enumerate(_iter()):
+    _iter = DataSetIter(dataset, batch_size=batch_size, sampler=None)
+    for batch_count, (batch_x, batch_y) in enumerate(_iter):
         _move_dict_value_to_device(batch_x, batch_y, device=model_device)
         # forward check
         if batch_count == 0:
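The hand-rolled batching above is replaced by DataSetIter, so the check loop now consumes the same (batch_x, batch_y) pairs that training does. Roughly (a sketch with a toy DataSet; the top-level imports are assumed to match those used in the tests below)::

    from fastNLP import DataSet, DataSetIter, SequentialSampler

    ds = DataSet({'x': [[1.0, 2.0]] * 6, 'y': [0] * 6})
    ds.set_input('x')
    ds.set_target('y')
    for batch_x, batch_y in DataSetIter(ds, batch_size=4, sampler=SequentialSampler()):
        print(batch_x['x'].shape, batch_y['y'].shape)  # batched tensors, as in training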
@@ -64,7 +64,7 @@ class TestCase1(unittest.TestCase):
         for _, _ in batch:
             cnt += 1
         self.assertEqual(cnt, 10)
     def test_dataset_batching(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.set_input("x")
@@ -151,6 +151,32 @@ class TestCase1(unittest.TestCase):
         for batch_x, batch_y in batch:
             pass
+    def test_udf_padder(self):
+        from fastNLP.core.field import Padder
+        alphas = list('abcdefghijk')
+        class UDFPadder(Padder):
+            def __init__(self):
+                super().__init__()
+            def __call__(self, contents, field_name, field_ele_dtype, dim):
+                results = [alphas[:con] for con in contents]
+                return results
+        batch_size = 32
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+        contents = np.random.randint(5, size=(num_samples,))
+        dataset.add_field('test', contents, is_input=True, padder=UDFPadder(),
+                          ignore_type=True)
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
+        for batch_x, batch_y in batch:
+            test = batch_x['test']
+            indices = batch.cur_batch_indices
+            cons = contents[indices]
+            for con, t in zip(cons, test):
+                self.assertEqual(alphas[:con], t)
     def test_collect_fn(self):
         batch_size = 32
         num_samples = 1000
@@ -13,6 +13,8 @@ from fastNLP import AccuracyMetric
 from fastNLP import SGD
 from fastNLP import Trainer
 from fastNLP.models.base_model import NaiveClassifier
+from fastNLP import TorchLoaderIter
 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -71,8 +73,11 @@ class TrainerTestGround(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
                           use_tqdm=True, check_code_level=2)
         trainer.train()
+        import os
+        if os.path.exists(save_path):
+            import shutil
+            shutil.rmtree(save_path)
     def test_trainer_suggestion1(self):
         # Check that the error message correctly reminds the user.
         # The data required by forward is not passed in here; the trainer should tell the user how to set it.
@@ -222,7 +227,224 @@ class TrainerTestGround(unittest.TestCase):
         with self.assertRaises(NameError):
             trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset,
                               metrics=AccuracyMetric(), use_tqdm=False)
+    def test_udf_dataiter(self):
+        import random
+        import torch
+        class UdfDataSet:
+            def __init__(self, num_samples):
+                self.num_samples = num_samples
+            def __getitem__(self, idx):
+                x = [random.random() for _ in range(3)]
+                y = random.random()
+                return x, y
+            def __len__(self):
+                return self.num_samples
+        def collect_fn(data_list):
+            # [(x1,y1), (x2,y2), ...]; the input is the return values of UdfDataSet.__getitem__ gathered into a list
+            xs, ys = [], []
+            for l in data_list:
+                x, y = l
+                xs.append(x)
+                ys.append(y)
+            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
+            return {'x': x, 'y': y}, {'y': y}
+        dataset = UdfDataSet(10)
+        dataset = TorchLoaderIter(dataset, collate_fn=collect_fn)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(3, 1)
+            def forward(self, x, y):
+                return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
+            def predict(self, x):
+                return {'pred': self.fc(x).squeeze(0)}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
+                          metrics=AccuracyMetric(target='y'), use_tqdm=False)
+        trainer.train(load_best_model=False)
+    def test_onthefly_iter(self):
+        import tempfile
+        import random
+        import torch
+        tmp_file_handler, tmp_file_path = tempfile.mkstemp(text=True)
+        try:
+            num_samples = 10
+            data = []
+            for _ in range(num_samples):
+                x, y = [random.random() for _ in range(3)], random.random()
+                data.append(x + [y])
+            with open(tmp_file_path, 'w') as f:
+                for d in data:
+                    f.write(' '.join(map(str, d)) + '\n')
+            class FileDataSet:
+                def __init__(self, tmp_file):
+                    num_samples = 0
+                    line_pos = [0]  # line_pos[idx] is the file offset where that line starts
+                    self.tmp_file_handler = open(tmp_file, 'r', encoding='utf-8')
+                    line = self.tmp_file_handler.readline()
+                    while line:
+                        if line.strip():
+                            num_samples += 1
+                            line_pos.append(self.tmp_file_handler.tell())
+                        line = self.tmp_file_handler.readline()
+                    self.tmp_file_handler.seek(0)
+                    self.num_samples = num_samples
+                    self.line_pos = line_pos
+                def __getitem__(self, idx):
+                    line_start, line_end = self.line_pos[idx], self.line_pos[idx + 1]
+                    self.tmp_file_handler.seek(line_start)
+                    line = self.tmp_file_handler.read(line_end - line_start).strip()
+                    values = list(map(float, line.split()))
+                    gold_d = data[idx]
+                    assert all([g == v for g, v in zip(gold_d, values)]), "Should have the same data"
+                    x, y = values[:3], values[-1]
+                    return x, y
+                def __len__(self):
+                    return self.num_samples
+            def collate_fn(data_list):
+                # [(x1,y1), (x2,y2), ...]; the input is the return values of FileDataSet.__getitem__ gathered into a list
+                xs, ys = [], []
+                for l in data_list:
+                    x, y = l
+                    xs.append(x)
+                    ys.append(y)
+                x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
+                return {'x': x, 'y': y}, {'y': y}
+            dataset = FileDataSet(tmp_file_path)
+            dataset = TorchLoaderIter(dataset, collate_fn=collate_fn)
+            class Model(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.fc = nn.Linear(3, 1)
+                def forward(self, x, y):
+                    return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
+                def predict(self, x):
+                    return {'pred': self.fc(x).squeeze(0)}
+            model = Model()
+            trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
+                              metrics=AccuracyMetric(target='y'), use_tqdm=False, n_epochs=2)
+            trainer.train(load_best_model=False)
+        finally:
+            import os
+            if os.path.exists(tmp_file_path):
+                os.remove(tmp_file_path)
+    def test_collect_fn(self):
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+        import torch
+        def fn(ins_list):
+            x = []
+            for ind, ins in ins_list:
+                x.append(ins['x1'] + ins['x2'])
+            x = torch.FloatTensor(x)
+            return {'x': x}, {}
+        dataset.add_collect_fn(fn)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, x):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = self.fc(x)
+                sum_x = x1 + x2 + x
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'pred': sum_x}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(target='y'), print_every=2,
+                          dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False)
+        trainer.train()
+    def test_collect_fn2(self):
+        """Test that the collect fn can supply both batch_x and batch_y."""
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+        import torch
+        def fn(ins_list):
+            x = []
+            for ind, ins in ins_list:
+                x.append(ins['x1'] + ins['x2'])
+            x = torch.FloatTensor(x)
+            return {'x': x}, {'target': x[:, :4].argmax(dim=-1)}
+        dataset.add_collect_fn(fn)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, x):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = self.fc(x)
+                sum_x = x1 + x2 + x
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'pred': sum_x}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2,
+                          dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False)
+        trainer.train()
+    def test_collect_fn3(self):
+        """
+        Test that the collect fn overrides existing fields.
+        :return:
+        """
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y')
+        import torch
+        def fn(ins_list):
+            x = []
+            for ind, ins in ins_list:
+                x.append(ins['x1'] + ins['x2'])
+            x = torch.FloatTensor(x)
+            return {'x1': torch.zeros_like(x)}, {'target': torch.zeros(x.size(0)).long(), 'y': x}
+        dataset.add_collect_fn(fn)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 1, bias=False)
+            def forward(self, x1):
+                x1 = self.fc(x1)
+                assert x1.sum() == 0, "x1 should have been replaced with zeros"
+                # loss = F.cross_entropy(x, y)
+                return {'pred': x1}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2,
+                          dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False, n_epochs=1)
+        best_metric = trainer.train()['best_eval']['AccuracyMetric']['acc']
+        self.assertTrue(best_metric == 1)
     """
     def test_trainer_multiprocess(self):
         dataset = prepare_fake_dataset2('x1', 'x2')