@@ -190,9 +190,6 @@ class EchoCallback(Callback): | |||||
def before_batch(self, batch_x, batch_y, indices): | def before_batch(self, batch_x, batch_y, indices): | ||||
print("before_batch") | print("before_batch") | ||||
print("batch_x:", batch_x) | |||||
print("batch_y:", batch_y) | |||||
print("indices: ", indices) | |||||
def before_loss(self, batch_y, predict_y): | def before_loss(self, batch_y, predict_y): | ||||
print("before_loss") | print("before_loss") | ||||
@@ -257,7 +257,7 @@ class Trainer(object): | |||||
self._update() | self._update() | ||||
# lr scheduler; lr_finder; one_cycle | # lr scheduler; lr_finder; one_cycle | ||||
self.callback_manager.after_step() | |||||
self.callback_manager.after_step(self.optimizer) | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | ||||
for name, param in self.model.named_parameters(): | for name, param in self.model.named_parameters(): | ||||
@@ -197,4 +197,4 @@ class TestDataSetIter(unittest.TestCase): | |||||
def test__repr__(self): | def test__repr__(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||
for iter in ds: | for iter in ds: | ||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") | |||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}") |
@@ -360,7 +360,8 @@ class TestBMESF1PreRecMetric(unittest.TestCase): | |||||
metric = BMESF1PreRecMetric() | metric = BMESF1PreRecMetric() | ||||
metric(pred_dict, target_dict) | metric(pred_dict, target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0}) | |||||
self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | |||||
class TestUsefulFunctions(unittest.TestCase): | class TestUsefulFunctions(unittest.TestCase): | ||||
# 测试metrics.py中一些看上去挺有用的函数 | # 测试metrics.py中一些看上去挺有用的函数 | ||||