Browse Source

增加dataset自动创建对应的array

tags/v0.2.0
yh 5 years ago
parent
commit
d818e91380
3 changed files with 14 additions and 9 deletions
  1. +6
    -1
      fastNLP/core/dataset.py
  2. +1
    -1
      reproduction/chinese_word_segment/process/cws_processor.py
  3. +7
    -7
      reproduction/chinese_word_segment/process/span_converter.py

+ 6
- 1
fastNLP/core/dataset.py View File

@@ -33,7 +33,9 @@ class DataSet(object):
return self.dataset[name][self.idx]

def __setitem__(self, name, val):
    """Store ``val`` under field ``name`` at this object's current index.

    If ``name`` is not yet present in ``self.dataset``, a new field is
    first added via ``add_field`` holding one ``None`` placeholder per
    ``len(self.dataset)``; the value is then written at ``self.idx``.

    :param name: key of the field to write to.
    :param val: value stored at position ``self.idx`` of that field.
    """
    # TODO: check the newly created field.
    target = self.dataset
    if name not in target:
        # Create the missing field up front so the indexed write below
        # has a slot to land in.
        placeholders = [None] * len(target)
        target.add_field(name, placeholders)
    target[name][self.idx] = val

def __repr__(self):
@@ -45,6 +47,9 @@ class DataSet(object):
if instance is not None:
self._convert_ins(instance)

def __contains__(self, item):
    """Support ``item in obj`` by testing membership in ``self.field_arrays``.

    :param item: the key to look for.
    :return: result of ``item in self.field_arrays``.
    """
    known_fields = self.field_arrays
    return item in known_fields

def __iter__(self):
    """Return a new ``self.DataSetIter`` constructed over this object."""
    iterator_cls = self.DataSetIter
    return iterator_cls(self)



+ 1
- 1
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -7,7 +7,7 @@ from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet

from fastNLP.api.processor import Processor
from reproduction.chinese_word_segment.process.span_converter import *

_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'



+ 7
- 7
reproduction/chinese_word_segment/process/span_converter.py View File

@@ -2,9 +2,9 @@
import re


class SpanConverterBase:
class SpanConverter:
def __init__(self, replace_tag, pattern):
super(SpanConverterBase, self).__init__()
super(SpanConverter, self).__init__()

self.replace_tag = replace_tag
self.pattern = pattern
@@ -33,7 +33,7 @@ class SpanConverterBase:
return spans


class AlphaSpanConverter(SpanConverterBase):
class AlphaSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<ALPHA>'
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag).
@@ -42,7 +42,7 @@ class AlphaSpanConverter(SpanConverterBase):
super(AlphaSpanConverter, self).__init__(replace_tag, pattern)


class DigitSpanConverter(SpanConverterBase):
class DigitSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<NUM>'
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])'
@@ -71,7 +71,7 @@ class DigitSpanConverter(SpanConverterBase):
return '<NUM>'


class TimeConverter(SpanConverterBase):
class TimeConverter(SpanConverter):
def __init__(self):
replace_tag = '<TOC>'
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])'
@@ -80,7 +80,7 @@ class TimeConverter(SpanConverterBase):



class MixNumAlphaConverter(SpanConverterBase):
class MixNumAlphaConverter(SpanConverter):
def __init__(self):
replace_tag = '<MIX>'
pattern = None
@@ -177,7 +177,7 @@ class MixNumAlphaConverter(SpanConverterBase):



class EmailConverter(SpanConverterBase):
class EmailConverter(SpanConverter):
def __init__(self):
replaced_tag = "<EML>"
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'


Loading…
Cancel
Save