diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 20d791c4..94871222 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -92,8 +92,40 @@ class ConfigSection(object): setattr(self, key, value) def __contains__(self, item): + """ + :param item: The key of item. + :return: True if the key in self.__dict__.keys() else False. + """ return item in self.__dict__.keys() + def __eq__(self, other): + """Overwrite the == operator + + :param other: Another ConfigSection() object which to be compared. + :return: True if value of each key in each ConfigSection() object are equal to the other, else False. + """ + for k in self.__dict__.keys(): + if k not in other.__dict__.keys(): + return False + if getattr(self, k) != getattr(self, k): + return False + + for k in other.__dict__.keys(): + if k not in self.__dict__.keys(): + return False + if getattr(self, k) != getattr(self, k): + return False + + return True + + def __ne__(self, other): + """Overwrite the != operator + + :param other: + :return: + """ + return not self.__eq__(other) + @property def data(self): return self.__dict__