@@ -19,7 +19,6 @@ from collections import defaultdict | |||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .sampler import SequentialSampler | from .sampler import SequentialSampler | ||||
from .field import _get_ele_type_and_dim | |||||
from ._logger import logger | from ._logger import logger | ||||
@@ -34,20 +33,20 @@ def _set_python_is_exit(): | |||||
atexit.register(_set_python_is_exit) | atexit.register(_set_python_is_exit) | ||||
def may_to_tensor(data, as_numpy, fn):
    """Convert ``data`` to a torch tensor unless numpy output was requested.

    :param data: batch content of one field.
    :param as_numpy: when truthy, ``data`` is returned untouched.
    :param fn: field name; only interpolated into the error log message.
    :return: ``data`` itself when ``as_numpy`` is truthy, otherwise the first
        element of the pair returned by ``_to_tensor``.
    :raises TypeError: re-raised from ``_to_tensor`` after the failure is logged.
    """
    if as_numpy:
        return data

    # Only the element dtype is needed for the conversion; the dim is unused.
    dtype, _ = _get_ele_type_and_dim(data)
    try:
        tensor, _ = _to_tensor(data, dtype)
    except TypeError:
        logger.error(f"Field {fn} cannot be converted to torch.tensor.")
        # Re-raise the active exception unchanged.
        raise
    return tensor
def _pad(batch_dict, dataset, as_numpy):
    """Pad each field of a collected batch and optionally convert it to a tensor.

    :param batch_dict: mapping from field name to the list of per-instance values.
    :param dataset: object whose ``field_arrays[name]`` gives the field used for padding.
    :param as_numpy: when falsy, padded content is passed through ``_to_tensor``.
    :return: a new dict with the same keys. Fields without a padder are wrapped
        with ``np.array`` and are never passed to ``_to_tensor``.
    """
    padded = {}
    for name, values in batch_dict.items():
        field = dataset.field_arrays[name]
        if field.padder is not None:
            content = field.pad(values)
            if not as_numpy:
                content, _ = _to_tensor(content, field_dtype=field.dtype)
        else:
            # No padder configured: hand the raw values to numpy as-is.
            content = np.array(values)
        padded[name] = content
    return padded
def convert_tensor(batch_dict, as_numpy):
    """Replace every value of ``batch_dict`` in place with ``may_to_tensor``'s result.

    :param batch_dict: mapping from field name to batch content; mutated in place.
    :param as_numpy: forwarded to ``may_to_tensor``.
    :return: None
    """
    # Iterate over a snapshot of the keys; only values are reassigned.
    for name in list(batch_dict):
        batch_dict[name] = may_to_tensor(batch_dict[name], as_numpy, name)
class DataSetGetter: | class DataSetGetter: | ||||
""" | """ | ||||
@@ -78,6 +77,7 @@ class DataSetGetter: | |||||
""" | """ | ||||
indices = [] | indices = [] | ||||
sin_x, sin_y = defaultdict(list), defaultdict(list) | sin_x, sin_y = defaultdict(list), defaultdict(list) | ||||
# 收集需要关注的field的数据 | |||||
for idx, ins in ins_list: | for idx, ins in ins_list: | ||||
indices.append(idx) | indices.append(idx) | ||||
for n, v in ins.items(): | for n, v in ins.items(): | ||||
@@ -85,28 +85,16 @@ class DataSetGetter: | |||||
sin_x[n].append(v) | sin_x[n].append(v) | ||||
if n in self.y_names: | if n in self.y_names: | ||||
sin_y[n].append(v) | sin_y[n].append(v) | ||||
def pad(batch_dict): | |||||
result = {} | |||||
for n, vlist in batch_dict.items(): | |||||
f = self.dataset.field_arrays[n] | |||||
if f.padder is None: | |||||
result[n] = np.array(vlist) | |||||
else: | |||||
result[n] = f.pad(vlist) | |||||
return result | |||||
sin_x = pad(sin_x) | |||||
sin_y = pad(sin_y) | |||||
convert_tensor(sin_x, self.as_numpy) | |||||
convert_tensor(sin_y, self.as_numpy) | |||||
# 根据情况,进行pad | |||||
sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) | |||||
sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) | |||||
if not self.dataset.collector.is_empty(): | if not self.dataset.collector.is_empty(): | ||||
bx, by = self.dataset._collect_batch(ins_list) | bx, by = self.dataset._collect_batch(ins_list) | ||||
sin_x.update(bx) | sin_x.update(bx) | ||||
sin_y.update(by) | sin_y.update(by) | ||||
return (indices, sin_x, sin_y) | |||||
return indices, sin_x, sin_y | |||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
if hasattr(self.dataset, item): | if hasattr(self.dataset, item): | ||||
@@ -134,8 +122,7 @@ class SamplerAdapter(torch.utils.data.Sampler): | |||||
class BatchIter: | class BatchIter: | ||||
""" | """ | ||||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), dataset(), num_batches(), | |||||
__iter__()方法。 | |||||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及属性。 | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size=1, sampler=None, | def __init__(self, dataset, batch_size=1, sampler=None, | ||||
@@ -272,9 +259,129 @@ class DataSetIter(BatchIter): | |||||
class TorchLoaderIter(BatchIter): | class TorchLoaderIter(BatchIter): | ||||
""" | """ | ||||
与DataSetIter类似,但用于pytorch的DataSet对象。 | |||||
通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 | |||||
与DataSetIter类似,但可以用于pytorch的DataSet对象。可以通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 | |||||
或者也可以传入任何实现了类似以下方法的对象 | |||||
Example:: | |||||
import random | |||||
from fastNLP import TorchLoaderIter | |||||
import torch | |||||
class UdfDataSet: | |||||
def __init__(self, num_samples): | |||||
self.num_samples = num_samples | |||||
def __getitem__(self, idx): # 必须实现的方法,输入参数是一个int,范围为[0, len(self)) | |||||
x = [random.random() for _ in range(3)] | |||||
y = random.random() | |||||
return x,y | |||||
def __len__(self): # 需要实现该方法返回值需要是一个int数据 | |||||
return self.num_samples | |||||
# 需要实现collact_fn将数据转换为tensor | |||||
def collact_fn(data_list): | |||||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | |||||
xs, ys = [], [] | |||||
for l in data_list: | |||||
x, y = l | |||||
xs.append(x) | |||||
ys.append(y) | |||||
# 不需要转移到gpu,Trainer和Tester会将其转移到model所在的device | |||||
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys) | |||||
return {'x':x, 'y':y}, {'y':y} | |||||
udf_dataset = UdfDataSet(10) | |||||
dataset = TorchLoaderIter(udf_dataset, collate_fn=collact_fn) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(3, 1) | |||||
def forward(self, x, y): | |||||
return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()} | |||||
def predict(self, x): | |||||
return {'pred':self.fc(x).squeeze(0)} | |||||
model = Model() | |||||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | |||||
metrics=AccuracyMetric(target='y'), use_tqdm=False) | |||||
trainer.train(load_best_model=False) | |||||
除此之外,还可以通过该方法实现OnTheFly的训练,如下面的代码所示 | |||||
Example:: | |||||
import tempfile | |||||
import random | |||||
import torch | |||||
tmp_file_handler, tmp_file_path = tempfile.mkstemp(text=True) | |||||
try: | |||||
num_samples, data = 10, [] | |||||
for _ in range(num_samples): | |||||
x, y = [random.random() for _ in range(3)], random.random() | |||||
data.append(x + [y]) | |||||
with open(tmp_file_path, 'w') as f: | |||||
for d in data: | |||||
f.write(' '.join(map(str, d)) + '\n') | |||||
class FileDataSet: | |||||
def __init__(self, tmp_file): | |||||
num_samples = 0 | |||||
line_pos = [0] # 对应idx是某一行对应的位置 | |||||
self.tmp_file_handler = open(tmp_file, 'r', encoding='utf-8') | |||||
line = self.tmp_file_handler.readline() | |||||
while line: | |||||
if line.strip(): | |||||
num_samples += 1 | |||||
line_pos.append(self.tmp_file_handler.tell()) | |||||
line = self.tmp_file_handler.readline() | |||||
self.tmp_file_handler.seek(0) | |||||
self.num_samples = num_samples | |||||
self.line_pos = line_pos | |||||
def __getitem__(self, idx): | |||||
line_start, line_end = self.line_pos[idx], self.line_pos[idx + 1] | |||||
self.tmp_file_handler.seek(line_start) | |||||
line = self.tmp_file_handler.read(line_end - line_start).strip() | |||||
values = list(map(float, line.split())) | |||||
x, y = values[:3], values[-1] | |||||
return x, y | |||||
def __len__(self): | |||||
return self.num_samples | |||||
def collact_fn(data_list): | |||||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | |||||
xs, ys = [], [] | |||||
for l in data_list: | |||||
x, y = l | |||||
xs.append(x) | |||||
ys.append(y) | |||||
x, y = torch.FloatTensor(xs), torch.FloatTensor(ys) | |||||
return {'x': x, 'y': y}, {'y': y} | |||||
file_data = FileDataSet(tmp_file_path) | |||||
dataset = TorchLoaderIter(file_data, collate_fn=collact_fn) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(3, 1) | |||||
def forward(self, x, y): | |||||
return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} | |||||
def predict(self, x): | |||||
return {'pred': self.fc(x).squeeze(0)} | |||||
model = Model() | |||||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | |||||
metrics=AccuracyMetric(target='y'), use_tqdm=False, n_epochs=2) | |||||
trainer.train(load_best_model=False) | |||||
finally: | |||||
import os | |||||
if os.path.exists(tmp_file_path): | |||||
os.remove(tmp_file_path) | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size=1, sampler=None, | def __init__(self, dataset, batch_size=1, sampler=None, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
@@ -291,12 +398,13 @@ class TorchLoaderIter(BatchIter): | |||||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | ||||
:param timeout: 生成一个batch的timeout值 | :param timeout: 生成一个batch的timeout值 | ||||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | ||||
:param collate_fn: 用于将样本组合成batch的函数""" | |||||
:param collate_fn: 用于将样本组合成batch的函数。 | |||||
""" | |||||
assert len(dataset) > 0 | assert len(dataset) > 0 | ||||
ins = dataset[0] | ins = dataset[0] | ||||
assert len(ins) == 2 and \ | |||||
isinstance(ins[0], dict) and \ | |||||
isinstance(ins[1], dict), 'DataSet should return two dict, as X and Y' | |||||
if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None: | |||||
raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the" | |||||
" `collate_fn` must be provided.") | |||||
super().__init__( | super().__init__( | ||||
dataset=dataset, batch_size=batch_size, sampler=sampler, | dataset=dataset, batch_size=batch_size, sampler=sampler, | ||||
@@ -775,7 +775,8 @@ class DataSet(object): | |||||
def set_ignore_type(self, *field_names, flag=True): | def set_ignore_type(self, *field_names, flag=True): | ||||
""" | """ | ||||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | ||||
默认情况下也不进行pad。 | |||||
默认情况下也不进行pad。如果仍需要pad该field,可通过自定义Padder实现,若该field需要转换为tensor,需要在padder | |||||
中转换,但不需要在padder中移动到gpu。 | |||||
:param str field_names: field的名称 | :param str field_names: field的名称 | ||||
:param bool flag: 将field_name的ignore_type状态设置为flag | :param bool flag: 将field_name的ignore_type状态设置为flag | ||||
@@ -492,6 +492,7 @@ class Trainer(object): | |||||
elif isinstance(train_data, BatchIter): | elif isinstance(train_data, BatchIter): | ||||
self.data_iterator = train_data | self.data_iterator = train_data | ||||
train_data = train_data.dataset | train_data = train_data.dataset | ||||
check_code_level = -1 # 强制跳过校验 | |||||
else: | else: | ||||
raise TypeError("train_data type {} not support".format(type(train_data))) | raise TypeError("train_data type {} not support".format(type(train_data))) | ||||
@@ -873,26 +874,9 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL | |||||
dev_data=None, metric_key=None, check_level=0): | dev_data=None, metric_key=None, check_level=0): | ||||
# check get_loss 方法 | # check get_loss 方法 | ||||
model_device = _get_model_device(model=model) | model_device = _get_model_device(model=model) | ||||
def _iter(): | |||||
start_idx = 0 | |||||
while start_idx<len(dataset): | |||||
batch_x = {} | |||||
batch_y = {} | |||||
for field_name, field in dataset.get_all_fields().items(): | |||||
indices = list(range(start_idx, min(start_idx+batch_size, len(dataset)))) | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | |||||
if field.dtype is not None and \ | |||||
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): | |||||
batch, _ = _to_tensor(batch, field.dtype) | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | |||||
yield (batch_x, batch_y) | |||||
start_idx += batch_size | |||||
for batch_count, (batch_x, batch_y) in enumerate(_iter()): | |||||
_iter = DataSetIter(dataset, batch_size=batch_size, sampler=None) | |||||
for batch_count, (batch_x, batch_y) in enumerate(_iter): | |||||
_move_dict_value_to_device(batch_x, batch_y, device=model_device) | _move_dict_value_to_device(batch_x, batch_y, device=model_device) | ||||
# forward check | # forward check | ||||
if batch_count == 0: | if batch_count == 0: | ||||
@@ -64,7 +64,7 @@ class TestCase1(unittest.TestCase): | |||||
for _, _ in batch: | for _, _ in batch: | ||||
cnt += 1 | cnt += 1 | ||||
self.assertEqual(cnt, 10) | self.assertEqual(cnt, 10) | ||||
def test_dataset_batching(self): | def test_dataset_batching(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
@@ -151,6 +151,32 @@ class TestCase1(unittest.TestCase): | |||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
pass | pass | ||||
def test_udf_padder(self):
    """A user-defined padder on an ``ignore_type`` input field is applied to each batch."""
    from fastNLP.core.field import Padder

    letters = list('abcdefghijk')

    class UDFPadder(Padder):
        def __init__(self):
            super().__init__()

        def __call__(self, contents, field_name, field_ele_dtype, dim):
            # Map each integer to that many leading letters.
            return [letters[:length] for length in contents]

    batch_size = 32
    num_samples = 1000
    dataset = generate_fake_dataset(num_samples)
    contents = np.random.randint(5, size=(num_samples))
    dataset.add_field('test', contents, is_input=True, padder=UDFPadder(),
                      ignore_type=True)

    batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_x, batch_y in batch:
        padded_field = batch_x['test']
        # Look up the raw integers that belong to the instances of this batch.
        expected_lengths = contents[batch.cur_batch_indices]
        for length, padded in zip(expected_lengths, padded_field):
            self.assertEqual(letters[:length], padded)
def test_collect_fn(self): | def test_collect_fn(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
num_samples = 1000 | num_samples = 1000 | ||||
@@ -13,6 +13,8 @@ from fastNLP import AccuracyMetric | |||||
from fastNLP import SGD | from fastNLP import SGD | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
from fastNLP import TorchLoaderIter | |||||
def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
@@ -71,8 +73,11 @@ class TrainerTestGround(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path, | metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path, | ||||
use_tqdm=True, check_code_level=2) | use_tqdm=True, check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
import os | |||||
if os.path.exists(save_path): | |||||
import shutil | |||||
shutil.rmtree(save_path) | |||||
def test_trainer_suggestion1(self): | def test_trainer_suggestion1(self): | ||||
# 检查报错提示能否正确提醒用户。 | # 检查报错提示能否正确提醒用户。 | ||||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | ||||
@@ -222,7 +227,224 @@ class TrainerTestGround(unittest.TestCase): | |||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset, | trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset, | ||||
metrics=AccuracyMetric(), use_tqdm=False) | metrics=AccuracyMetric(), use_tqdm=False) | ||||
def test_udf_dataiter(self):
    """Train from a user-defined dataset wrapped in ``TorchLoaderIter`` with a collate function."""
    import random
    import torch

    class UdfDataSet:
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            # The index is ignored: every access yields a fresh random sample.
            features = [random.random() for _ in range(3)]
            label = random.random()
            return features, label

        def __len__(self):
            return self.num_samples

    def collect_fn(data_list):
        # data_list is [(x1, y1), (x2, y2), ...]: the __getitem__ results gathered into a list.
        xs = [x for x, _ in data_list]
        ys = [y for _, y in data_list]
        x = torch.FloatTensor(xs)
        y = torch.FloatTensor(ys)
        return {'x': x, 'y': y}, {'y': y}

    data_iter = TorchLoaderIter(UdfDataSet(10), collate_fn=collect_fn)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 1)

        def forward(self, x, y):
            residual = self.fc(x).squeeze(-1) - y
            return {'loss': torch.pow(residual, 2).sum()}

        def predict(self, x):
            return {'pred': self.fc(x).squeeze(0)}

    model = Model()
    trainer = Trainer(train_data=data_iter, model=model, loss=None, print_every=2, dev_data=data_iter,
                      metrics=AccuracyMetric(target='y'), use_tqdm=False)
    trainer.train(load_best_model=False)
def test_onthefly_iter(self):
    """Train from a dataset that reads its samples lazily from a temporary text file."""
    import os
    import random
    import tempfile

    import torch

    tmp_fd, tmp_file_path = tempfile.mkstemp(text=True)
    # mkstemp returns an already-open OS-level descriptor. Only the path is
    # used below, so close it immediately instead of leaking it.
    os.close(tmp_fd)
    file_data = None
    try:
        num_samples = 10
        data = []
        for _ in range(num_samples):
            x, y = [random.random() for _ in range(3)], random.random()
            data.append(x + [y])
        # Write with the same encoding the dataset uses for reading.
        with open(tmp_file_path, 'w', encoding='utf-8') as f:
            for d in data:
                f.write(' '.join(map(str, d)) + '\n')

        class FileDataSet:
            def __init__(self, tmp_file):
                num_samples = 0
                line_pos = [0]  # line_pos[idx] is the file offset where sample idx starts
                self.tmp_file_handler = open(tmp_file, 'r', encoding='utf-8')
                line = self.tmp_file_handler.readline()
                while line:
                    if line.strip():
                        num_samples += 1
                        line_pos.append(self.tmp_file_handler.tell())
                    line = self.tmp_file_handler.readline()
                self.tmp_file_handler.seek(0)
                self.num_samples = num_samples
                self.line_pos = line_pos

            def __getitem__(self, idx):
                line_start, line_end = self.line_pos[idx], self.line_pos[idx + 1]
                self.tmp_file_handler.seek(line_start)
                line = self.tmp_file_handler.read(line_end - line_start).strip()
                values = list(map(float, line.split()))
                gold_d = data[idx]
                assert all(g == v for g, v in zip(gold_d, values)), "Should have the same data"
                x, y = values[:3], values[-1]
                return x, y

            def __len__(self):
                return self.num_samples

        def collact_fn(data_list):
            # data_list is [(x1, y1), (x2, y2), ...]: the __getitem__ results gathered into a list.
            xs, ys = [], []
            for l in data_list:
                x, y = l
                xs.append(x)
                ys.append(y)
            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
            return {'x': x, 'y': y}, {'y': y}

        file_data = FileDataSet(tmp_file_path)
        dataset = TorchLoaderIter(file_data, collate_fn=collact_fn)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(3, 1)

            def forward(self, x, y):
                return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}

            def predict(self, x):
                return {'pred': self.fc(x).squeeze(0)}

        model = Model()
        trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
                          metrics=AccuracyMetric(target='y'), use_tqdm=False, n_epochs=2)
        trainer.train(load_best_model=False)
    finally:
        # Release the dataset's read handle before deleting the file; an open
        # handle blocks removal on some platforms.
        if file_data is not None:
            file_data.tmp_file_handler.close()
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
def test_collecct_fn(self):
    """Train with a collect function that adds an extra input ``x`` to every batch."""
    import torch

    dataset = prepare_fake_dataset2('x1', 'x2')
    dataset.set_input('x1', 'x2')
    dataset.set_target('y', 'x1')

    def fn(ins_list):
        # ins_list holds (index, instance) pairs; only the instance is used.
        combined = [ins['x1'] + ins['x2'] for _, ins in ins_list]
        return {'x': torch.FloatTensor(combined)}, {}

    dataset.add_collect_fn(fn)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 4)

        def forward(self, x1, x2, x):
            out1 = self.fc(x1)
            out2 = self.fc(x2)
            out = self.fc(x)
            total = out1 + out2 + out
            time.sleep(0.1)
            return {'pred': total}

    model = Model()
    trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(target='y'), print_every=2,
                      dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False)
    trainer.train()
def test_collect_fn2(self):
    """Check that a collect function can contribute entries to both batch_x and batch_y."""
    import torch

    dataset = prepare_fake_dataset2('x1', 'x2')
    dataset.set_input('x1', 'x2')
    dataset.set_target('y', 'x1')

    def fn(ins_list):
        # ins_list holds (index, instance) pairs; only the instance is used.
        combined = torch.FloatTensor([ins['x1'] + ins['x2'] for _, ins in ins_list])
        # The target is derived from the first four columns of the new input.
        target = combined[:, :4].argmax(dim=-1)
        return {'x': combined}, {'target': target}

    dataset.add_collect_fn(fn)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 4)

        def forward(self, x1, x2, x):
            out1 = self.fc(x1)
            out2 = self.fc(x2)
            out = self.fc(x)
            total = out1 + out2 + out
            time.sleep(0.1)
            return {'pred': total}

    model = Model()
    trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2,
                      dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False)
    trainer.train()
def test_collect_fn3(self):
    """
    Values returned by the collect function should override dataset fields of the same name.

    The collect function returns an all-zero ``x1`` and an all-zero ``target``;
    the model asserts that the ``x1`` it receives sums to zero.

    :return:
    """
    dataset = prepare_fake_dataset2('x1', 'x2')
    dataset.set_input('x1', 'x2')
    dataset.set_target('y')
    import torch

    def fn(ins_list):
        # ins_list holds (index, instance) pairs; only the instance is used.
        x = torch.FloatTensor([ins['x1'] + ins['x2'] for _, ins in ins_list])
        return {'x1': torch.zeros_like(x)}, {'target': torch.zeros(x.size(0)).long(), 'y': x}

    dataset.add_collect_fn(fn)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            # No bias, so an all-zero input yields an all-zero output.
            self.fc = nn.Linear(5, 1, bias=False)

        def forward(self, x1):
            x1 = self.fc(x1)
            assert x1.sum() == 0, "x1 should have been replaced with zeros by the collect function"
            return {'pred': x1}

    model = Model()
    trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2,
                      dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False, n_epochs=1)
    best_metric = trainer.train()['best_eval']['AccuracyMetric']['acc']
    # assertEqual reports both values when the check fails.
    self.assertEqual(best_metric, 1)
""" | """ | ||||
def test_trainer_multiprocess(self): | def test_trainer_multiprocess(self): | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||