|
|
@@ -25,13 +25,13 @@ class RegistryTest(unittest.TestCase): |
|
|
|
def test_register_class_with_task(self): |
|
|
|
MODELS = Registry('models') |
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.image_classfication, 'SwinT') |
|
|
|
@MODELS.register_module(Tasks.image_classification, 'SwinT') |
|
|
|
class SwinTForCls(object): |
|
|
|
pass |
|
|
|
|
|
|
|
self.assertTrue(Tasks.image_classfication in MODELS.modules) |
|
|
|
self.assertTrue(Tasks.image_classification in MODELS.modules) |
|
|
|
self.assertTrue( |
|
|
|
MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls) |
|
|
|
MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls) |
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.sentiment_analysis, 'Bert') |
|
|
|
class BertForSentimentAnalysis(object): |
|
|
@@ -54,7 +54,7 @@ class RegistryTest(unittest.TestCase): |
|
|
|
def test_list(self): |
|
|
|
MODELS = Registry('models') |
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.image_classfication, 'SwinT') |
|
|
|
@MODELS.register_module(Tasks.image_classification, 'SwinT') |
|
|
|
class SwinTForCls(object): |
|
|
|
pass |
|
|
|
|
|
|
@@ -68,7 +68,7 @@ class RegistryTest(unittest.TestCase): |
|
|
|
def test_build(self): |
|
|
|
MODELS = Registry('models') |
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.image_classfication, 'SwinT') |
|
|
|
@MODELS.register_module(Tasks.image_classification, 'SwinT') |
|
|
|
class SwinTForCls(object): |
|
|
|
pass |
|
|
|
|
|
|
@@ -77,7 +77,7 @@ class RegistryTest(unittest.TestCase): |
|
|
|
pass |
|
|
|
|
|
|
|
cfg = dict(type='SwinT') |
|
|
|
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication) |
|
|
|
model = build_from_cfg(cfg, MODELS, Tasks.image_classification) |
|
|
|
self.assertTrue(isinstance(model, SwinTForCls)) |
|
|
|
|
|
|
|
cfg = dict(type='Bert') |
|
|
@@ -86,7 +86,7 @@ class RegistryTest(unittest.TestCase): |
|
|
|
|
|
|
|
with self.assertRaises(KeyError): |
|
|
|
cfg = dict(type='Bert') |
|
|
|
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication) |
|
|
|
model = build_from_cfg(cfg, MODELS, Tasks.image_classification) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|