|
@@ -99,12 +99,26 @@ class DataSet(list): |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def update_vocab(self, **name_vocab):
    """Update vocabularies with the contents of the named fields.

    Each keyword argument maps a field name to the vocabulary object that
    should absorb that field. For every instance in the dataset,
    ``vocab.update(instance[field_name].contents())`` is called.

    e.g. ::

        # update word vocab and label vocab separately
        dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab)

    :param name_vocab: ``field_name=vocab`` pairs; each ``vocab`` must
        provide an ``update`` method, and each instance must be indexable
        by ``field_name`` with a value exposing ``contents()``.
    :return: ``self``, so calls can be chained.
    """
    for field_name, vocab in name_vocab.items():
        # Feed every instance's field contents into the matching vocabulary.
        for ins in self:
            vocab.update(ins[field_name].contents())
    return self
|
|
|
|
|
|
|
|
def set_origin_len(self, origin_field, origin_len_name=None): |
|
|
def set_origin_len(self, origin_field, origin_len_name=None): |
|
|
|
|
|
"""make dataset tensor output contain origin_len field. |
|
|
|
|
|
|
|
|
|
|
|
e.g. :: |
|
|
|
|
|
|
|
|
|
|
|
# output "word_seq_origin_len", lengths based on "word_seq" field |
|
|
|
|
|
dataset.set_origin_len("word_seq") |
|
|
|
|
|
""" |
|
|
if origin_field is None: |
|
|
if origin_field is None: |
|
|
self.origin_len = None |
|
|
self.origin_len = None |
|
|
else: |
|
|
else: |
|
|