diff --git a/fastNLP/modules/prototype/example.py b/fastNLP/modules/prototype/example.py index 782937fe..a19898c6 100644 --- a/fastNLP/modules/prototype/example.py +++ b/fastNLP/modules/prototype/example.py @@ -60,7 +60,7 @@ def train(model_dict=None, using_cuda=True, learning_rate=0.06,\ optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum) criterion = nn.CrossEntropyLoss() - dataset = dataloader.DataLoader("test_set.pkl", batch_size, using_cuda=using_cuda) + dataset = dataloader.DataLoader("train_set.pkl", batch_size, using_cuda=using_cuda) #statistics loss_count = 0