From 902a3a6bcd0b2f5667fdab387f9f550d1b8068bc Mon Sep 17 00:00:00 2001 From: ChenXin Date: Tue, 14 May 2019 16:48:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 46 ++++++++++++++++--------------------- fastNLP/core/utils.py | 5 ++-- test/models/model_runner.py | 1 + 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 7228842f..b506dfae 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -212,39 +212,33 @@ target和input,这种情况下,fastNLP默认不进行pad。另外,当某个field已经被设置为了target或者input后,之后append的 instance对应的field必须要和前面已有的内容一致,否则会报错。 - 可以查看field的dtype - - Example:: + 可以查看field的dtype:: - from fastNLP import DataSet + from fastNLP import DataSet - d = DataSet({'a': [0, 1, 3], 'b':[[1.0, 2.0], [0.1, 0.2], [3]]}) - d.set_input('a', 'b') - d.a.dtype - >> numpy.int64 - d.b.dtype - >> numpy.float64 - # 默认情况下'a'这个field将被转换为torch.LongTensor,但如果需要其为torch.FloatTensor可以手动修改dtype - d.a.dtype = float # 请确保该field的确可以全部转换为float。 + d = DataSet({'a': [0, 1, 3], 'b':[[1.0, 2.0], [0.1, 0.2], [3]]}) + d.set_input('a', 'b') + d.a.dtype + >> numpy.int64 + d.b.dtype + >> numpy.float64 + # 默认情况下'a'这个field将被转换为torch.LongTensor,但如果需要其为torch.FloatTensor可以手动修改dtype + d.a.dtype = float # 请确保该field的确可以全部转换为float。 如果某个field中出现了多种类型混合(比如一部分为str,一部分为int)的情况,fastNLP无法判断该field的类型,会报如下的 - 错误: - - Example:: + 错误:: - from fastNLP import DataSet - d = DataSet({'data': [1, 'a']}) - d.set_input('data') - >> RuntimeError: Mixed data types in Field data: [, ] - - 可以通过设置以忽略对该field进行类型检查 + from fastNLP import DataSet + d = DataSet({'data': [1, 'a']}) + d.set_input('data') + >> RuntimeError: Mixed data types in Field data: [, ] - Example:: + 可以通过设置以忽略对该field进行类型检查:: - from fastNLP import DataSet - d = DataSet({'data': [1, 'a']}) - d.set_ignore_type('data') - d.set_input('data') + from fastNLP import DataSet + d = DataSet({'data': [1, 'a']}) + d.set_ignore_type('data') + d.set_input('data') 当某个field被设置为忽略type之后,fastNLP将不对其进行pad。 diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index f7539fd7..a9a7ac0c 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -35,9 +35,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): """ 别名::class:`fastNLP.cache_results` :class:`fastNLP.core.uitls.cache_results` - cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用 - - Example:: + cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: import time import numpy as np @@ -607,6 +605,7 @@ def seq_len_to_mask(seq_len): 转变 1-d seq_len到2-d mask. Example:: + >>> seq_len = torch.arange(2, 16) >>> mask = seq_len_to_mask(seq_len) >>> print(mask.size()) diff --git a/test/models/model_runner.py b/test/models/model_runner.py index 3f4e1200..405aa7d6 100644 --- a/test/models/model_runner.py +++ b/test/models/model_runner.py @@ -6,6 +6,7 @@ 此模块的测试仅保证模型能使用fastNLP进行训练和测试,不测试模型实际性能 Example:: + # import 全大写变量... from model_runner import *