From fc70bfe44990592bd8bd53174875ea0579c561dd Mon Sep 17 00:00:00 2001
From: Xu Yige
Date: Sun, 1 Jul 2018 19:08:51 +0800
Subject: [PATCH] Add files via upload

---
 fastNLP/loader/config_loader.py | 31 +++++++++++++++-
 fastNLP/loader/dataset_loader.py | 62 ++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py
index fa1d446d..4bac49bb 100644
--- a/fastNLP/loader/config_loader.py
+++ b/fastNLP/loader/config_loader.py
@@ -1,4 +1,8 @@
-from loader.base_loader import BaseLoader
+from fastNLP.loader.base_loader import BaseLoader
+
+import configparser
+import traceback
+import json
 
 
 class ConfigLoader(BaseLoader):
@@ -11,3 +15,28 @@ class ConfigLoader(BaseLoader):
     @staticmethod
     def parse(string):
         raise NotImplementedError
+
+    @staticmethod
+    def loadConfig(filePath, sections):
+        """
+        :param filePath: the path of config file
+        :param sections: the dict of sections
+        :return:
+        """
+        cfg = configparser.ConfigParser()
+        cfg.read(filePath)
+        for s in sections:
+            attr_list = [i for i in type(sections[s]).__dict__.keys() if
+                         not callable(getattr(sections[s], i)) and not i.startswith("__")]
+            gen_sec = cfg[s]
+            for attr in attr_list:
+                try:
+                    val = json.loads(gen_sec[attr])
+                    print(s, attr, val, type(val))
+                    assert type(val) == type(getattr(sections[s], attr)), \
+                        'type not match, expected %s but got %s' % \
+                        (type(getattr(sections[s], attr)), type(val))
+                    setattr(sections[s], attr, val)
+                except Exception as e:
+                    traceback.print_exc()
+                    raise ValueError('something wrong in "%s" entry' % attr)
diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index f692b22d..0cec50e5 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -1,4 +1,5 @@
 from fastNLP.loader.base_loader import BaseLoader
+import os
 
 
 class DatasetLoader(BaseLoader):
@@ -8,6 +9,67 @@ class DatasetLoader(BaseLoader):
     def __init__(self, data_name, data_path):
         super(DatasetLoader, self).__init__(data_name, data_path)
 
 
+class POSDatasetLoader(DatasetLoader):
+    """loader for pos data sets"""
+
+    def __init__(self, data_name, data_path):
+        super(POSDatasetLoader, self).__init__(data_name, data_path)
+        #self.data_set = self.load()
+
+
+    def load(self):
+        assert os.path.exists(self.data_path)
+        with open(self.data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        return self.parse(lines)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param lines: lines from dataset
+        :return: list(list(list())): the three level of lists are
+                token, sentence, and dataset
+        """
+        dataset = list()
+        for line in lines:
+            sentence = list()
+            words = line.split(" ")
+            for w in words:
+                tokens = list()
+                tokens.append(w.split('/')[0])
+                tokens.append(w.split('/')[1])
+                sentence.append(tokens)
+            dataset.append(sentence)
+        return dataset
+
+class ClassficationDatasetLoader(DatasetLoader):
+    """loader for classfication data sets"""
+
+    def __init__(self, data_name, data_path):
+        super(ClassficationDatasetLoader, self).__init__(data_name, data_path)
+
+    def load(self):
+        assert os.path.exists(self.data_path)
+        with open(self.data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        return self.parse(lines)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param lines: lines from dataset
+        :return: list(list(list())): the three level of lists are
+                words, sentence, and dataset
+        """
+        dataset = list()
+        for line in lines:
+            label = line.split(" ")[0]
+            words = line.split(" ")[1:]
+            word = list([w for w in words])
+            sentence = list([word, label])
+            dataset.append(sentence)
+        return dataset
+
 class ConllLoader(DatasetLoader):
     """loader for conll format files"""