Browse Source

修改bucket sampler, 增加url下载功能

tags/v0.2.0
yh 5 years ago
parent
commit
b899b1edd8
3 changed files with 161 additions and 7 deletions
  1. +20
    -4
      fastNLP/api/api.py
  2. +138
    -0
      fastNLP/api/model_zoo.py
  3. +3
    -3
      fastNLP/core/sampler.py

+ 20
- 4
fastNLP/api/api.py View File

@@ -5,17 +5,25 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.predictor import Predictor

from fastNLP.api.model_zoo import load_url

model_urls = {
'cws': "",

}


class API:
def __init__(self):
self.pipeline = None
self.model = None

def predict(self, *args, **kwargs):
raise NotImplementedError

def load(self, name):
_dict = torch.load(name)
def load(self, path):


_dict = torch.load(path)
self.pipeline = _dict['pipeline']


@@ -61,8 +69,13 @@ class POS_tagger(API):


class CWS(API):
def __init__(self, model_path='xxx'):
def __init__(self, model_path=None, pretrain=True):
super(CWS, self).__init__()
# 1. 这里修改为检查
if model_path is None:
model_path = model_urls['cws']


self.load(model_path)

def predict(self, sentence, pretrain=False):
@@ -94,3 +107,6 @@ class CWS(API):
if __name__ == "__main__":
tagger = POS_tagger()
print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]]))

from torchvision import models
models.resnet18()

+ 138
- 0
fastNLP/api/model_zoo.py View File

@@ -0,0 +1,138 @@
import torch

import hashlib
import os
import re
import shutil
import sys
import tempfile

try:
from requests.utils import urlparse
from requests import get as urlopen
requests_available = True
except ImportError:
requests_available = False
if sys.version_info[0] == 2:
from urlparse import urlparse # noqa f811
from urllib2 import urlopen # noqa f811
else:
from urllib.request import urlopen
from urllib.parse import urlparse
try:
from tqdm import tqdm
except ImportError:
tqdm = None # defined below

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def load_url(url, model_dir=None, map_location=None, progress=True):
r"""Loads the Torch serialized object at the given URL.

If the object is already present in `model_dir`, it's deserialized and
returned. The filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.

The default value of `model_dir` is ``$TORCH_HOME/models`` where
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar to stderr

Example:
# >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

"""
if model_dir is None:
torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP'))
model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models'))
if not os.path.exists(model_dir):
os.makedirs(model_dir)
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
# hash_prefix = HASH_REGEX.search(filename).group(1)
_download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return torch.load(cached_file, map_location=map_location)


def _download_url_to_file(url, dst, hash_prefix, progress):
if requests_available:
u = urlopen(url, stream=True)
file_size = int(u.headers["Content-Length"])
u = u.raw
else:
u = urlopen(url)
meta = u.info()
if hasattr(meta, 'getheaders'):
file_size = int(meta.getheaders("Content-Length")[0])
else:
file_size = int(meta.get_all("Content-Length")[0])

f = tempfile.NamedTemporaryFile(delete=False)
try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
with tqdm(total=file_size, disable=not progress) as pbar:
while True:
buffer = u.read(8192)
if len(buffer) == 0:
break
f.write(buffer)
if hash_prefix is not None:
sha256.update(buffer)
pbar.update(len(buffer))

f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()
if digest[:len(hash_prefix)] != hash_prefix:
raise RuntimeError('invalid hash value (expected "{}", got "{}")'
.format(hash_prefix, digest))
shutil.move(f.name, dst)
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)


if tqdm is None:
# fake tqdm if it's not installed
class tqdm(object):

def __init__(self, total, disable=False):
self.total = total
self.disable = disable
self.n = 0

def update(self, n):
if self.disable:
return

self.n += n
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
sys.stderr.flush()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.disable:
return

sys.stderr.write('\n')


if __name__ == '__main__':
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context.pkl', model_dir='.')
print(type(pipeline))

+ 3
- 3
fastNLP/core/sampler.py View File

@@ -45,14 +45,14 @@ class RandomSampler(BaseSampler):

class BucketSampler(BaseSampler):

def __init__(self, num_buckets=10, batch_size=32):
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'):
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_lens_field_name = seq_lens_field_name

def __call__(self, data_set):
assert 'seq_lens' in data_set, "BuckectSampler only support data_set with seq_lens right now."

seq_lens = data_set['seq_lens'].content
seq_lens = data_set[self.seq_lens_field_name].content
total_sample_num = len(seq_lens)

bucket_indexes = []


Loading…
Cancel
Save