diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2b2b2b35 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +.gitignore + +.DS_Store +.ipynb_checkpoints +*.pyc +__pycache__ +*.swp +.vscode/ +.idea/** + +caches + +# fitlog +.fitlog +logs/ +.fitconfig diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f994bd31..9ccdaf67 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -476,7 +476,7 @@ class SpanFPreRecMetric(MetricBase): label的f1, pre, rec :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) - :param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. + :param float beta: f_beta分数, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ @@ -708,7 +708,7 @@ class SQuADMetric(MetricBase): :param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` :param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` :param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` - :param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. + :param float beta: f_beta分数, :math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 :param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 diff --git a/reproduction/README.md b/reproduction/README.md index 8d14d36d..bb21c067 100644 --- a/reproduction/README.md +++ b/reproduction/README.md @@ -2,43 +2,28 @@ 这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 复现的模型有: -- Star-Transformer +- [Star-Transformer](Star_transformer/) - ... +# 任务复现 +## Text Classification (文本分类) +- still in progress + + +## Matching (自然语言推理/句子匹配) +- still in progress + + +## Sequence Labeling (序列标注) +- still in progress + + +## Coreference resolution (指代消解) +- still in progress + + +## Summarization (摘要) +- still in progress -## Star-Transformer -[reference](https://arxiv.org/abs/1902.09113) -### Performance (still in progress) -|任务| 数据集 | SOTA | 模型表现 | -|------|------| ------| ------| -|Pos Tagging|CTB 9.0|-|ACC 92.31| -|Pos Tagging|CONLL 2012|-|ACC 96.51| -|Named Entity Recognition|CONLL 2012|-|F1 85.66| -|Text Classification|SST|-|49.18| -|Natural Language Inference|SNLI|-|83.76| - -### Usage -``` python -# for sequence labeling(ner, pos tagging, etc) -from fastNLP.models.star_transformer import STSeqLabel -model = STSeqLabel( - vocab_size=10000, num_cls=50, - emb_dim=300) - - -# for sequence classification -from fastNLP.models.star_transformer import STSeqCls -model = STSeqCls( - vocab_size=10000, num_cls=50, - emb_dim=300) - - -# for natural language inference -from fastNLP.models.star_transformer import STNLICls -model = STNLICls( - vocab_size=10000, num_cls=50, - emb_dim=300) - -``` ## ... diff --git a/reproduction/Star_transformer/README.md b/reproduction/Star_transformer/README.md new file mode 100644 index 00000000..37c5f1e9 --- /dev/null +++ b/reproduction/Star_transformer/README.md @@ -0,0 +1,34 @@ +# Star-Transformer +paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) +## Performance (still in progress) +|任务| 数据集 | SOTA | 模型表现 | +|------|------| ------| ------| +|Pos Tagging|CTB 9.0|-|ACC 92.31| +|Pos Tagging|CONLL 2012|-|ACC 96.51| +|Named Entity Recognition|CONLL 2012|-|F1 85.66| +|Text Classification|SST|-|49.18| +|Natural Language Inference|SNLI|-|83.76| + +## Usage +``` python +# for sequence labeling(ner, pos tagging, etc) +from fastNLP.models.star_transformer import STSeqLabel +model = STSeqLabel( + vocab_size=10000, num_cls=50, + emb_dim=300) + + +# for sequence classification +from fastNLP.models.star_transformer import STSeqCls +model = STSeqCls( + vocab_size=10000, num_cls=50, + emb_dim=300) + + +# for natural language inference +from fastNLP.models.star_transformer import STNLICls +model = STNLICls( + vocab_size=10000, num_cls=50, + emb_dim=300) + +``` diff --git a/reproduction/matching/test/test_snlidataloader.py b/reproduction/matching/test/test_snlidataloader.py index 9a0fb9ee..bd5c58b6 100644 --- a/reproduction/matching/test/test_snlidataloader.py +++ b/reproduction/matching/test/test_snlidataloader.py @@ -1,6 +1,6 @@ import unittest -from reproduction.matching.data import SNLIDataLoader -from fastNLP.core.vocabulary import VocabularyOption +from ..data import SNLIDataLoader +from fastNLP.core.vocabulary import Vocabulary class TestCWSDataLoader(unittest.TestCase): diff --git a/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py b/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py index 0b9cb633..f4260849 100644 --- a/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py +++ b/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py @@ -1,7 +1,7 @@ import unittest -from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader +from ..data.CWSDataLoader import SigHanLoader from fastNLP.core.vocabulary import VocabularyOption