Browse Source

merge update

tags/v0.4.10
yh_cc 6 years ago
parent
commit
bddce51b05
3 changed files with 7 additions and 6 deletions
  1. +5
    -3
      fastNLP/core/metrics.py
  2. +1
    -2
      reproduction/seqence_labelling/cws/model/module.py
  3. +1
    -1
      reproduction/seqence_labelling/cws/train_shift_relay.py

+ 5
- 3
fastNLP/core/metrics.py View File

@@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import seq_len_to_mask from .utils import seq_len_to_mask
from .vocabulary import Vocabulary from .vocabulary import Vocabulary
from abc import abstractmethod


class MetricBase(object): class MetricBase(object):
""" """
@@ -117,10 +117,12 @@ class MetricBase(object):
def __init__(self): def __init__(self):
self.param_map = {} # key is param in function, value is input param. self.param_map = {} # key is param in function, value is input param.
self._checked = False self._checked = False

@abstractmethod
def evaluate(self, *args, **kwargs): def evaluate(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError

@abstractmethod
def get_metric(self, reset=True): def get_metric(self, reset=True):
raise NotImplemented raise NotImplemented


+ 1
- 2
reproduction/seqence_labelling/cws/model/module.py View File

@@ -1,11 +1,10 @@
from torch import nn from torch import nn
import torch import torch
from fastNLP.modules import Embedding
import numpy as np import numpy as np


class SemiCRFShiftRelay(nn.Module): class SemiCRFShiftRelay(nn.Module):
""" """
该模块是一个decoder,但
该模块是一个decoder,但当前不支持含有tag的decode。


""" """
def __init__(self, L): def __init__(self, L):


+ 1
- 1
reproduction/seqence_labelling/cws/train_shift_relay.py View File

@@ -32,7 +32,7 @@ lr = 0.02
#########hyper #########hyper
device = 0 device = 0


# !!!!这里前往不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到
# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到
# 你们的reproduction路径下,然后设置.gitignore # 你们的reproduction路径下,然后设置.gitignore
file_dir = '/path/to/pku' file_dir = '/path/to/pku'
char_embed_path = '/path/to/1grams_t3_m50_corpus.txt' char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'


Loading…
Cancel
Save