diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 823e0ee0..4198fd2b 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -5,17 +5,25 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.predictor import Predictor +from fastNLP.api.model_zoo import load_url + +model_urls = { + 'cws': "", + +} + class API: def __init__(self): self.pipeline = None - self.model = None def predict(self, *args, **kwargs): raise NotImplementedError - def load(self, name): - _dict = torch.load(name) + def load(self, path): + + + _dict = torch.load(path) self.pipeline = _dict['pipeline'] @@ -61,8 +69,13 @@ class POS_tagger(API): class CWS(API): - def __init__(self, model_path='xxx'): + def __init__(self, model_path=None, pretrain=True): super(CWS, self).__init__() + # 1. change this to a check + if model_path is None: + model_path = model_urls['cws'] + + self.load(model_path) def predict(self, sentence, pretrain=False): @@ -94,3 +107,6 @@ class CWS(API): if __name__ == "__main__": tagger = POS_tagger() print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]])) + + from torchvision import models + models.resnet18()
diff --git a/fastNLP/api/model_zoo.py b/fastNLP/api/model_zoo.py
new file mode 100644
index 00000000..fcfc966e
--- /dev/null
+++ b/fastNLP/api/model_zoo.py
@@ -0,0 +1,187 @@
+import torch
+
+import hashlib
+import os
+import re
+import shutil
+import sys
+import tempfile
+
+try:
+    from requests.utils import urlparse
+    from requests import get as urlopen
+    requests_available = True
+except ImportError:
+    requests_available = False
+    if sys.version_info[0] == 2:
+        from urlparse import urlparse  # noqa f811
+        from urllib2 import urlopen  # noqa f811
+    else:
+        from urllib.request import urlopen
+        from urllib.parse import urlparse
+try:
+    from tqdm import tqdm
+except ImportError:
+    tqdm = None  # defined below
+
+# matches bfd8deac from resnet18-bfd8deac.pth
+HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+
+
+def load_url(url, model_dir=None, map_location=None, progress=True):
+    r"""Loads the Torch serialized object at the given URL.
+
+    If the object is already present in `model_dir`, it's deserialized and
+    returned without being downloaded again. The cached file is named after
+    the last path component of the URL.
+
+    The default value of `model_dir` is ``$fastNLP_HOME/models`` where
+    ``$fastNLP_HOME`` defaults to ``~/.fastNLP``. The default directory can be
+    overridden with the ``$fastNLP_MODEL_ZOO`` environment variable.
+
+    Hash verification of the downloaded file is currently disabled: the
+    download is always requested with ``hash_prefix=None``.
+
+    Args:
+        url (string): URL of the object to download
+        model_dir (string, optional): directory in which to save the object
+        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
+        progress (bool, optional): whether or not to display a progress bar to stderr
+
+    Returns:
+        The object returned by ``torch.load`` for the cached file.
+
+    Example:
+        # >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
+
+    """
+    if model_dir is None:
+        torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP'))
+        model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models'))
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+        # Hash checking is switched off (hash_prefix=None). The disabled lookup was:
+        # hash_prefix = HASH_REGEX.search(filename).group(1)
+        _download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    # NOTE(review): torch.load unpickles the file; only load from trusted URLs.
+    return torch.load(cached_file, map_location=map_location)
+
+
+def _content_length(response):
+    """Return the response's Content-Length as an int, or None when absent."""
+    if requests_available:
+        value = response.headers.get("Content-Length")
+    else:
+        meta = response.info()
+        if hasattr(meta, 'getheaders'):
+            values = meta.getheaders("Content-Length")
+        else:
+            values = meta.get_all("Content-Length")
+        value = values[0] if values else None
+    return int(value) if value is not None else None
+
+
+def _download_url_to_file(url, dst, hash_prefix, progress):
+    """Download `url` into the file `dst`.
+
+    The data is written to a temporary file first and only moved to `dst`
+    once the download (and the optional hash check) has succeeded.
+
+    Args:
+        url (string): URL to download
+        dst (string): path of the file to create
+        hash_prefix (string or None): if given, the SHA256 hex digest of the
+            downloaded data must start with this prefix
+        progress (bool): whether or not to display a progress bar to stderr
+
+    Raises:
+        RuntimeError: if `hash_prefix` is given and the digest does not match.
+        requests.HTTPError: if `requests` is used and the server answers with
+            an error status.
+    """
+    if requests_available:
+        response = urlopen(url, stream=True)
+    else:
+        response = urlopen(url)
+
+    f = None
+    try:
+        if requests_available:
+            # requests does not raise for HTTP error statuses on its own;
+            # without this check an error page would be cached as the model.
+            response.raise_for_status()
+            u = response.raw
+        else:
+            u = response
+        # None (header missing) lets the progress bar run without a total.
+        file_size = _content_length(response)
+
+        f = tempfile.NamedTemporaryFile(delete=False)
+        if hash_prefix is not None:
+            sha256 = hashlib.sha256()
+        with tqdm(total=file_size, disable=not progress) as pbar:
+            while True:
+                buffer = u.read(8192)
+                if len(buffer) == 0:
+                    break
+                f.write(buffer)
+                if hash_prefix is not None:
+                    sha256.update(buffer)
+                pbar.update(len(buffer))
+
+        f.close()
+        if hash_prefix is not None:
+            digest = sha256.hexdigest()
+            if digest[:len(hash_prefix)] != hash_prefix:
+                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
+                                   .format(hash_prefix, digest))
+        shutil.move(f.name, dst)
+    finally:
+        if f is not None:
+            f.close()
+            if os.path.exists(f.name):
+                os.remove(f.name)
+        # Release the connection whether or not the download succeeded.
+        response.close()
+
+
+if tqdm is None:
+    # fake tqdm if it's not installed
+    class tqdm(object):
+
+        def __init__(self, total=None, disable=False):
+            self.total = total
+            self.disable = disable
+            self.n = 0
+
+        def update(self, n):
+            if self.disable:
+                return
+
+            self.n += n
+            if self.total:
+                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
+            else:
+                # Size unknown (None) or zero: no percentage can be computed,
+                # so show the number of bytes received instead.
+                sys.stderr.write("\r{0} bytes".format(self.n))
+            sys.stderr.flush()
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if self.disable:
+                return
+
+            sys.stderr.write('\n')
+
+
+if __name__ == '__main__':
+    pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context.pkl', model_dir='.')
+    print(type(pipeline))
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 652bc97e..6ba2f4d3 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -45,14 +45,14 @@ class RandomSampler(BaseSampler): class BucketSampler(BaseSampler): - def __init__(self, num_buckets=10, batch_size=32): + def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'): self.num_buckets = num_buckets self.batch_size = batch_size + self.seq_lens_field_name = seq_lens_field_name def __call__(self, data_set): - assert 'seq_lens'
in data_set, "BuckectSampler only support data_set with seq_lens right now." - seq_lens = data_set['seq_lens'].content + seq_lens = data_set[self.seq_lens_field_name].content total_sample_num = len(seq_lens) bucket_indexes = []