diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 38658bcf..f5bce312 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -19,7 +19,9 @@ from fastNLP.api.pipeline import Pipeline from fastNLP.core.metrics import SeqLabelEvaluator2 from fastNLP.core.tester import Tester +# TODO add pretrain urls model_urls = { + } @@ -182,8 +184,6 @@ class CWS(API): return f1, pre, rec -<<<<<<< HEAD -======= class Parser(API): def __init__(self, model_path=None, device='cpu'): super(Parser, self).__init__() @@ -250,7 +250,6 @@ class Parser(API): return uas ->>>>>>> b182b39... * fixing unit tests class Analyzer: def __init__(self, seg=True, pos=True, parser=True, device='cpu'): @@ -265,13 +264,9 @@ class Analyzer: if parser: self.parser = None -<<<<<<< HEAD - def predict(self, content): -======= def predict(self, content, seg=False, pos=False, parser=False): if seg is False and pos is False and parser is False: seg = True ->>>>>>> b182b39... * fixing unit tests output_dict = {} if self.seg: seg_output = self.cws.predict(content) @@ -310,11 +305,6 @@ if __name__ == "__main__": # print(pos.predict(s)) # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' -<<<<<<< HEAD - cws = CWS(device='cpu') - s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , - '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', -======= # cws = CWS(device='cpu') # s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , # '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', @@ -326,7 +316,6 @@ if __name__ == "__main__": # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', ->>>>>>> b182b39... * fixing unit tests '那么这款无人机到底有多厉害?'] print(cws.test('/Users/yh/Desktop/test_data/small_test.conll')) print(cws.predict(s)) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 26602dc9..10d8cfab 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -398,55 +398,3 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): if _error_strs: raise ValueError('\n' + '\n'.join(_error_strs)) - - -if __name__ == '__main__': - import torch - from torch import nn - from fastNLP.core.dataset import DataSet - import numpy as np - - class Model(nn.Module): - def __init__(self): - super().__init__() - - self.fc1 = nn.Linear(10, 2) - - def forward(self, words, chars): - output = {} - output['prediction'] = torch.randn(3, 4) - # output['words'] = words - return output - - def get_loss(self, prediction, labels, words): - return torch.mean(self.fc1.weight) - - def evaluate(self, prediction, labels, demo=2): - return {} - - - model = Model() - - num_samples = 4 - fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6), - 'labels': np.random.randint(2, size=(num_samples,)), 'seq_lens': [1, 3, 4, 6]} - - - dataset = DataSet(fake_data_dict) - dataset.set_input(words=True, chars=True) - dataset.set_target(labels=True, words=True) - - # trainer = Trainer(dataset, model) - - _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=1) - - # _check_forward_error(model=model, model_func=model.forward, check_level=1, - # batch_x=fake_data_dict) - - # import inspect - # print(inspect.getfullargspec(model.forward)) - - import pandas - df = pandas.DataFrame({'a':0}) - -