|
|
@@ -10,7 +10,7 @@ from fastNLP.core.metrics import AccuracyMetric |
|
|
|
class TestTutorial(unittest.TestCase): |
|
|
|
def test_fastnlp_10min_tutorial(self): |
|
|
|
# 从csv读取数据到DataSet |
|
|
|
sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" |
|
|
|
sample_path = "data_for_tests/tutorial_sample_dataset.csv" |
|
|
|
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), |
|
|
|
sep='\t') |
|
|
|
print(len(dataset)) |
|
|
@@ -76,9 +76,7 @@ class TestTutorial(unittest.TestCase): |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 |
|
|
|
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 |
|
|
|
train_data.rename_field('label', 'label_seq') |
|
|
|
test_data.rename_field('words', 'word_seq') |
|
|
|
test_data.rename_field('label', 'label_seq') |
|
|
|
|
|
|
|
loss = CrossEntropyLoss(pred="output", target="label_seq") |
|
|
|