
imdb.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. imdb dataset parser.
  17. """
  18. import os
  19. from itertools import chain
  20. import gensim
  21. import numpy as np
class ImdbParser:
    """
    Parse aclImdb data into features and labels.
    sentence -> tokenized -> encoded -> padded -> features
    """

    def __init__(self, imdb_path, glove_path, embed_size=300):
        self.__segs = ['train', 'test']
        self.__label_dic = {'pos': 1, 'neg': 0}
        self.__imdb_path = imdb_path
        self.__glove_dim = embed_size
        self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')
        # properties
        self.__imdb_data = {}
        self.__features = {}
        self.__labels = {}
        self.__vocab = {}
        self.__word2idx = {}
        self.__weight_np = {}
        self.__wvmodel = None
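
    # Expected on-disk layout (the standard aclImdb release; paths are
    # supplied by the caller, not fixed by this module):
    #   <imdb_path>/{train,test}/{pos,neg}/*.txt  -- one review per file
    #   <glove_path>/glove.6B.<embed_size>d.txt   -- pretrained GloVe vectors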
    def parse(self):
        """
        Parse the IMDb data into memory.
        """
        # load_word2vec_format expects word2vec text format: a raw GloVe
        # file needs a "<vocab_size> <dim>" header line prepended first
        # (e.g. via gensim's glove2word2vec conversion).
        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)
        for seg in self.__segs:
            self.__parse_imdb_data(seg)
            self.__parse_features_and_labels(seg)
            self.__gen_weight_np(seg)
    def __parse_imdb_data(self, seg):
        """
        Load raw sentences and labels from the txt files of one segment.
        """
        data_lists = []
        for label_name, label_id in self.__label_dic.items():
            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
            for file in os.listdir(sentence_dir):
                with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
                    sentence = f.read().replace('\n', '')
                    data_lists.append([sentence, label_id])
        self.__imdb_data[seg] = data_lists
    def __parse_features_and_labels(self, seg):
        """
        Parse features and labels.
        """
        features = []
        labels = []
        for sentence, label in self.__imdb_data[seg]:
            features.append(sentence)
            labels.append(label)
        self.__features[seg] = features
        self.__labels[seg] = labels
        # tokenize the features
        self.__update_features_to_tokenized(seg)
        # build the vocabulary
        self.__parse_vocab(seg)
        # encode the features
        self.__encode_features(seg)
        # pad the features
        self.__padding_features(seg)
    def __update_features_to_tokenized(self, seg):
        """ split each sentence into lower-cased tokens """
        tokenized_features = []
        for sentence in self.__features[seg]:
            tokenized_sentence = [word.lower() for word in sentence.split(" ")]
            tokenized_features.append(tokenized_sentence)
        self.__features[seg] = tokenized_features
    def __parse_vocab(self, seg):
        """ collect the vocabulary and map each word to an index """
        tokenized_features = self.__features[seg]
        vocab = set(chain(*tokenized_features))
        self.__vocab[seg] = vocab
        # word_to_idx: {'hello': 1, 'world': 111, ..., '<unk>': 0}
        word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
        word_to_idx['<unk>'] = 0
        self.__word2idx[seg] = word_to_idx
    def __encode_features(self, seg):
        """ encode words to indices; the train vocabulary is used for every
        segment, so unseen test words fall back to '<unk>' (index 0) """
        word_to_idx = self.__word2idx['train']
        encoded_features = []
        for tokenized_sentence in self.__features[seg]:
            encoded_sentence = []
            for word in tokenized_sentence:
                encoded_sentence.append(word_to_idx.get(word, 0))
            encoded_features.append(encoded_sentence)
        self.__features[seg] = encoded_features
    def __padding_features(self, seg, maxlen=500, pad=0):
        """ truncate or pad every feature to the same length """
        padded_features = []
        for feature in self.__features[seg]:
            if len(feature) >= maxlen:
                padded_feature = feature[:maxlen]
            else:
                padded_feature = feature
                while len(padded_feature) < maxlen:
                    padded_feature.append(pad)
            padded_features.append(padded_feature)
        self.__features[seg] = padded_features
    def __gen_weight_np(self, seg):
        """
        Build the embedding weight matrix from the pretrained vectors;
        rows for words missing from GloVe (and for '<unk>') stay zero.
        """
        weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
        for word, idx in self.__word2idx[seg].items():
            if word not in self.__wvmodel:
                continue
            word_vector = self.__wvmodel.get_vector(word)
            weight_np[idx, :] = word_vector
        self.__weight_np[seg] = weight_np
    def get_datas(self, seg):
        """
        Return the features, labels, and embedding weight of one segment.
        """
        features = np.array(self.__features[seg]).astype(np.int32)
        labels = np.array(self.__labels[seg]).astype(np.int32)
        weight = np.array(self.__weight_np[seg])
        return features, labels, weight
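
A minimal usage sketch. The data paths are placeholders, not part of this module; parse() must run before get_datas():

# Assumes ./data/aclImdb holds the unpacked IMDb dataset and ./data/glove
# holds a glove.6B.300d.txt already converted to word2vec text format.
parser = ImdbParser('./data/aclImdb', './data/glove', embed_size=300)
parser.parse()
features, labels, weight = parser.get_datas('train')
print(features.shape, labels.shape, weight.shape)
# expected shapes: (num_reviews, 500), (num_reviews,), (vocab_size, 300)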