|
|
@@ -6,9 +6,7 @@ from ...utils.constant import Tasks |
|
|
|
from ..base import Model, Tensor |
|
|
|
from ..builder import MODELS |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
'StructBertForMaskedLM', 'VecoForMaskedLM', 'AliceMindBaseForMaskedLM' |
|
|
|
] |
|
|
|
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM'] |
|
|
|
|
|
|
|
|
|
|
|
class AliceMindBaseForMaskedLM(Model): |
|
|
@@ -40,9 +38,13 @@ class AliceMindBaseForMaskedLM(Model): |
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.fill_mask, module_name=r'sbert') |
|
|
|
class StructBertForMaskedLM(AliceMindBaseForMaskedLM): |
|
|
|
# The StructBert for MaskedLM uses the same underlying model structure |
|
|
|
# as the base model class. |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.fill_mask, module_name=r'veco') |
|
|
|
class VecoForMaskedLM(AliceMindBaseForMaskedLM): |
|
|
|
# The Veco for MaskedLM uses the same underlying model structure |
|
|
|
# as the base model class. |
|
|
|
pass |