diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 131ba28d..18da9bd7 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -33,7 +33,9 @@ class DataSet(object): return self.dataset[name][self.idx] def __setitem__(self, name, val): - # TODO check new field. + if name not in self.dataset: + new_fields = [None]*len(self.dataset) + self.dataset.add_field(name, new_fields) self.dataset[name][self.idx] = val def __repr__(self): @@ -45,6 +47,9 @@ class DataSet(object): if instance is not None: self._convert_ins(instance) + def __contains__(self, item): + return item in self.field_arrays + def __iter__(self): return self.DataSetIter(self) diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index bb76b974..3e6b9c3b 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -7,7 +7,7 @@ from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.dataset import DataSet from fastNLP.api.processor import Processor - +from reproduction.chinese_word_segment.process.span_converter import * _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' diff --git a/reproduction/chinese_word_segment/process/span_converter.py b/reproduction/chinese_word_segment/process/span_converter.py index 23e590c4..2635df0e 100644 --- a/reproduction/chinese_word_segment/process/span_converter.py +++ b/reproduction/chinese_word_segment/process/span_converter.py @@ -2,9 +2,9 @@ import re -class SpanConverterBase: +class SpanConverter: def __init__(self, replace_tag, pattern): - super(SpanConverterBase, self).__init__() + super(SpanConverter, self).__init__() self.replace_tag = replace_tag self.pattern = pattern @@ -33,7 +33,7 @@ class SpanConverterBase: return spans -class AlphaSpanConverter(SpanConverterBase): +class AlphaSpanConverter(SpanConverter): def __init__(self): replace_tag = '' # 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). @@ -42,7 +42,7 @@ class AlphaSpanConverter(SpanConverterBase): super(AlphaSpanConverter, self).__init__(replace_tag, pattern) -class DigitSpanConverter(SpanConverterBase): +class DigitSpanConverter(SpanConverter): def __init__(self): replace_tag = '' pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' @@ -71,7 +71,7 @@ class DigitSpanConverter(SpanConverterBase): return '' -class TimeConverter(SpanConverterBase): +class TimeConverter(SpanConverter): def __init__(self): replace_tag = '' pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' @@ -80,7 +80,7 @@ class TimeConverter(SpanConverterBase): -class MixNumAlphaConverter(SpanConverterBase): +class MixNumAlphaConverter(SpanConverter): def __init__(self): replace_tag = '' pattern = None @@ -177,7 +177,7 @@ class MixNumAlphaConverter(SpanConverterBase): -class EmailConverter(SpanConverterBase): +class EmailConverter(SpanConverter): def __init__(self): replaced_tag = "" pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'