Browse Source

fix typo

master
Yingda Chen 3 years ago
parent
commit
3cbdcb1d3e
2 changed files with 9 additions and 9 deletions
  1. +2
    -2
      maas_lib/utils/constant.py
  2. +7
    -7
      tests/utils/test_registry.py

+ 2
- 2
maas_lib/utils/constant.py View File

@@ -9,7 +9,7 @@ class Fields(object):
cv = 'cv' cv = 'cv'
nlp = 'nlp' nlp = 'nlp'
audio = 'audio' audio = 'audio'
multi_modal = 'multi_modal'
multi_modal = 'multi-modal'




class Tasks(object): class Tasks(object):
@@ -21,7 +21,7 @@ class Tasks(object):
# vision tasks # vision tasks
image_to_text = 'image-to-text' image_to_text = 'image-to-text'
pose_estimation = 'pose-estimation' pose_estimation = 'pose-estimation'
image_classfication = 'image-classification'
image_classification = 'image-classification'
image_tagging = 'image-tagging' image_tagging = 'image-tagging'
object_detection = 'object-detection' object_detection = 'object-detection'
image_segmentation = 'image-segmentation' image_segmentation = 'image-segmentation'


+ 7
- 7
tests/utils/test_registry.py View File

@@ -25,13 +25,13 @@ class RegistryTest(unittest.TestCase):
def test_register_class_with_task(self): def test_register_class_with_task(self):
MODELS = Registry('models') MODELS = Registry('models')


@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object): class SwinTForCls(object):
pass pass


self.assertTrue(Tasks.image_classfication in MODELS.modules)
self.assertTrue(Tasks.image_classification in MODELS.modules)
self.assertTrue( self.assertTrue(
MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls)
MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls)


@MODELS.register_module(Tasks.sentiment_analysis, 'Bert') @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
class BertForSentimentAnalysis(object): class BertForSentimentAnalysis(object):
@@ -54,7 +54,7 @@ class RegistryTest(unittest.TestCase):
def test_list(self): def test_list(self):
MODELS = Registry('models') MODELS = Registry('models')


@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object): class SwinTForCls(object):
pass pass


@@ -68,7 +68,7 @@ class RegistryTest(unittest.TestCase):
def test_build(self): def test_build(self):
MODELS = Registry('models') MODELS = Registry('models')


@MODELS.register_module(Tasks.image_classfication, 'SwinT')
@MODELS.register_module(Tasks.image_classification, 'SwinT')
class SwinTForCls(object): class SwinTForCls(object):
pass pass


@@ -77,7 +77,7 @@ class RegistryTest(unittest.TestCase):
pass pass


cfg = dict(type='SwinT') cfg = dict(type='SwinT')
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
self.assertTrue(isinstance(model, SwinTForCls)) self.assertTrue(isinstance(model, SwinTForCls))


cfg = dict(type='Bert') cfg = dict(type='Bert')
@@ -86,7 +86,7 @@ class RegistryTest(unittest.TestCase):


with self.assertRaises(KeyError): with self.assertRaises(KeyError):
cfg = dict(type='Bert') cfg = dict(type='Bert')
model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
model = build_from_cfg(cfg, MODELS, Tasks.image_classification)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save