
movielens.py
import os
import wget
import zipfile
from collections import defaultdict as dd
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm

DATASETS = ["ml-1m", "ml-20m", "ml-25m"]
urls = {
    "ml-1m": "https://files.grouplens.org/datasets/movielens/ml-1m.zip",
    "ml-20m": "https://files.grouplens.org/datasets/movielens/ml-20m.zip",
    "ml-25m": "https://files.grouplens.org/datasets/movielens/ml-25m.zip",
}


def download(dataset, data_dir, num_negatives=4):
    """Download a MovieLens dataset and preprocess it into implicit-feedback files:
    each user's most recent interaction is held out for testing together with 99
    sampled negatives; the remaining interactions become training positives, each
    paired with num_negatives sampled negatives."""
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    assert dataset in DATASETS, 'Invalid dataset: %s.' % dataset
    data_subdir = os.path.join(data_dir, dataset)
    print('Data in', data_subdir)
    zip_file = os.path.join(data_dir, dataset + '.zip')
    # NOTE: ml-20m/ml-25m ship ratings.csv; ml-1m ships ratings.dat with '::'
    # separators, which this comma-based parser does not handle.
    ratings = os.path.join(data_subdir, 'ratings.csv')
    if not os.path.exists(ratings):
        if not os.path.exists(zip_file):
            print('Downloading movielens %s...' % dataset)
            wget.download(urls[dataset], zip_file)
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            print('Extracting movielens %s...' % dataset)
            zip_ref.extractall(data_dir)

    num_users, num_items = {
        'ml-1m': (6040, 3706),
        'ml-20m': (138493, 26744),
        'ml-25m': (162541, 59047),
    }[dataset]

    # Generate raw training and testing files
    item_reverse_mapping = {}  # original movie id -> contiguous item index
    cur_item_idx = 0
    latest = [(0, -1)] * num_users  # per user: (timestamp, item) of most recent rating
    mat = sp.dok_matrix((num_users, num_items), dtype=np.float32)
    with open(ratings, 'r') as fr:
        fr.readline()  # skip the CSV header
        for line in tqdm(fr):
            entries = line.strip().split(',')
            user = int(entries[0])
            item = int(entries[1])
            if item not in item_reverse_mapping:
                item_reverse_mapping[item] = cur_item_idx
                cur_item_idx += 1
            rating = float(entries[2])
            if rating <= 0:
                continue
            reitem = item_reverse_mapping[item]
            mat[user - 1, reitem] = 1  # implicit feedback: any positive rating is an interaction
            timestamp = int(entries[-1])
            if latest[user - 1][0] < timestamp:
                latest[user - 1] = (timestamp, reitem)
    print('#users:', num_users, '#items:', num_items)

    # Sample test data first: for each user, the held-out positive (column 0)
    # followed by 99 sampled negatives, using all observed interactions.
    new_lates = np.concatenate((np.array(latest, dtype=np.int32)[
        :, 1:], np.empty((num_users, 99), dtype=np.int32)), 1)
    for i, lat in enumerate(latest):
        new_lates[i][0] = lat[1]
        for k in range(1, 100):
            j = np.random.randint(num_items)
            while (i, j) in mat.keys():
                j = np.random.randint(num_items)
            new_lates[i][k] = j
    np.save(os.path.join(data_subdir, 'test.npy'), new_lates)

    # Sample train data: each positive interaction with num_negatives negative samples.
    all_num = (1 + num_negatives) * (len(mat.keys()) - num_users)
    user_input = np.empty((all_num,), dtype=np.int32)
    item_input = np.empty((all_num,), dtype=np.int32)
    labels = np.empty((all_num,), dtype=np.int32)
    idx = 0
    for (i, j) in mat.keys():
        if new_lates[i][0] == j:
            continue  # skip the held-out test item
        # positive instance
        user_input[idx] = i
        item_input[idx] = j
        labels[idx] = 1
        idx += 1
        # negative instances
        for t in range(num_negatives):
            k = np.random.randint(num_items)
            while (i, k) in mat.keys():
                k = np.random.randint(num_items)
            user_input[idx] = i
            item_input[idx] = k
            labels[idx] = 0
            idx += 1
    assert all_num == idx
    np.savez(os.path.join(data_subdir, 'train.npz'),
             user_input=user_input, item_input=item_input, labels=labels)


def getdata(dataset, data_dir='datasets'):
    assert dataset in DATASETS, 'Invalid dataset: %s.' % dataset
    data_subdir = os.path.join(data_dir, dataset)
    file_paths = [os.path.join(data_subdir, data)
                  for data in ['train.npz', 'test.npy']]
    if any([not os.path.exists(path) for path in file_paths]):
        download(dataset, data_dir)
    return np.load(file_paths[0]), np.load(file_paths[1])


if __name__ == "__main__":
    download('ml-25m', 'datasets')
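
For reference, a minimal sketch of how the generated files could be consumed, assuming the layout produced above: train.npz holds parallel user_input/item_input/labels arrays, and each row of test.npy is a user's held-out positive item followed by 99 sampled negatives. The hit_rate_at_10 helper and the score_fn callable are illustrative names introduced here, not part of movielens.py.

import numpy as np
from movielens import getdata

# Load (and, if necessary, download and preprocess) the dataset.
train, test = getdata('ml-25m', 'datasets')

# Training triples (user, item, 0/1 label), suitable for a pointwise ranking loss.
users, items, labels = train['user_input'], train['item_input'], train['labels']
print(users.shape, items.shape, labels.shape)

def hit_rate_at_10(score_fn, test):
    """Leave-one-out evaluation: rank each user's held-out item (column 0)
    against its 99 sampled negatives and report HR@10."""
    hits = 0
    for user, candidates in enumerate(test):
        scores = score_fn(user, candidates)          # shape (100,), higher = better
        top10 = candidates[np.argsort(-scores)[:10]]
        hits += int(candidates[0] in top10)
    return hits / len(test)

# Example with random scores; a trained model's scoring function would go here.
print(hit_rate_at_10(lambda u, c: np.random.rand(len(c)), test))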