Browse Source

Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh 5 years ago
parent
commit
164f74e12a
2 changed files with 6 additions and 4 deletions
  1. +5
    -4
      fastNLP/io/file_reader.py
  2. +1
    -0
      requirements.txt

+ 5
- 4
fastNLP/io/file_reader.py View File

@@ -5,6 +5,7 @@
__all__ = [] __all__ = []


import json import json
import csv


from ..core import logger from ..core import logger


@@ -21,17 +22,17 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
:if False, raise ValueError when reading invalid data. default: True :if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, csv item) :return: generator, every time yield (line number, csv item)
""" """
with open(path, 'r', encoding=encoding) as f:
with open(path, 'r', encoding=encoding) as csv_file:
f = csv.reader(csv_file, delimiter=sep)
start_idx = 0 start_idx = 0
if headers is None: if headers is None:
headers = f.readline().rstrip('\r\n')
headers = headers.split(sep)
headers = next(f)
start_idx += 1 start_idx += 1
elif not isinstance(headers, (list, tuple)): elif not isinstance(headers, (list, tuple)):
raise TypeError("headers should be list or tuple, not {}." \ raise TypeError("headers should be list or tuple, not {}." \
.format(type(headers))) .format(type(headers)))
for line_idx, line in enumerate(f, start_idx): for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
contents = line
if len(contents) != len(headers): if len(contents) != len(headers):
if dropna: if dropna:
continue continue


+ 1
- 0
requirements.txt View File

@@ -2,6 +2,7 @@ numpy>=1.14.2
torch>=1.0.0 torch>=1.0.0
tqdm>=4.28.1 tqdm>=4.28.1
nltk>=3.4.1 nltk>=3.4.1
prettytable>=0.7.2
requests requests
spacy spacy
prettytable>=0.7.2 prettytable>=0.7.2

Loading…
Cancel
Save