Browse Source

fix typo

master
Yingda Chen 3 years ago
parent
commit
3cbdcb1d3e
2 changed files with 9 additions and 9 deletions
  1. +2
    -2
      maas_lib/utils/constant.py
  2. +7
    -7
      tests/utils/test_registry.py

+ 2
- 2
maas_lib/utils/constant.py View File

@@ -9,7 +9,7 @@ class Fields(object):
cv = 'cv'
nlp = 'nlp'
audio = 'audio'
multi_modal = 'multi_modal'
multi_modal = 'multi-modal'


class Tasks(object):
@@ -21,7 +21,7 @@ class Tasks(object):
# vision tasks
image_to_text = 'image-to-text'
pose_estimation = 'pose-estimation'
image_classfication = 'image-classification'
image_classification = 'image-classification'
image_tagging = 'image-tagging'
object_detection = 'object-detection'
image_segmentation = 'image-segmentation'


+ 7
- 7
tests/utils/test_registry.py View File

@@ -25,13 +25,13 @@ class RegistryTest(unittest.TestCase):
def test_register_class_with_task(self):
MODELS = Registry('models')

@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object):
pass

self.assertTrue(Tasks.image_classfication in MODELS.modules)
self.assertTrue(Tasks.image_classification in MODELS.modules)
self.assertTrue(
MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls)
MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls)

@MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
class BertForSentimentAnalysis(object):
@@ -54,7 +54,7 @@ class RegistryTest(unittest.TestCase):
def test_list(self):
MODELS = Registry('models')

@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object):
pass

@@ -68,7 +68,7 @@ class RegistryTest(unittest.TestCase):
def test_build(self):
MODELS = Registry('models')

@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object):
pass

@@ -77,7 +77,7 @@ class RegistryTest(unittest.TestCase):
pass

cfg = dict(type='SwinT')
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
self.assertTrue(isinstance(model, SwinTForCls))

cfg = dict(type='Bert')
@@ -86,7 +86,7 @@ class RegistryTest(unittest.TestCase):

with self.assertRaises(KeyError):
cfg = dict(type='Bert')
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
model = build_from_cfg(cfg, MODELS, Tasks.image_classification)


if __name__ == '__main__':


Loading…
Cancel
Save