|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Data operations, will be used in train.py and eval.py
- """
- import os
-
- import numpy as np
-
- from imdb import ImdbParser
- import mindspore.dataset as ds
- from mindspore.mindrecord import FileWriter
-
-
def create_dataset(data_home, batch_size, repeat_num=1, training=True):
    """
    Build the shuffled, batched and repeated dataset from a MindRecord file.

    Args:
        data_home: Directory that contains the MindRecord files.
        batch_size: Number of samples per batch; an incomplete final batch
            is dropped.
        repeat_num: Number of times the batched dataset is repeated.
        training: When truthy, read "aclImdb_train.mindrecord0"; otherwise
            read "aclImdb_test.mindrecord0".

    Returns:
        The dataset object produced by the shuffle -> batch -> repeat chain.
    """
    # Fix the seed of the dataset module before the pipeline is built.
    ds.config.set_seed(1)

    record_name = "aclImdb_train.mindrecord0" if training else "aclImdb_test.mindrecord0"
    record_path = os.path.join(data_home, record_name)

    # Only the "feature" and "label" columns are loaded.
    source = ds.MindDataset(record_path, columns_list=["feature", "label"], num_parallel_workers=4)

    # The shuffle buffer is as large as the dataset itself.
    total_samples = source.get_dataset_size()
    return (
        source.shuffle(buffer_size=total_samples)
        .batch(batch_size=batch_size, drop_remainder=True)
        .repeat(count=repeat_num)
    )
-
-
def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """
    Write features and labels into a MindRecord dataset under data_home.

    Args:
        data_home: Output directory for the MindRecord files (and for
            'weight.txt' when weight_np is given).
        features: Iterable of per-sample feature arrays; each one is
            flattened with reshape(-1) before being written.
        labels: Iterable of per-sample labels, converted with int().
        weight_np: Optional array saved as text to 'weight.txt'; skipped
            when None.
        training: When truthy, write "aclImdb_train.mindrecord"; otherwise
            write "aclImdb_test.mindrecord".
    """
    if weight_np is not None:
        weight_file = os.path.join(data_home, 'weight.txt')
        np.savetxt(weight_file, weight_np)

    # Pick the output file name from the requested split.
    record_name = "aclImdb_train.mindrecord" if training else "aclImdb_test.mindrecord"
    record_path = os.path.join(data_home, record_name)

    # Schema of one record: a running id, the label and a 1-D feature vector.
    nlp_schema = {
        "id": {"type": "int32"},
        "label": {"type": "int32"},
        "feature": {"type": "int32", "shape": [-1]},
    }

    writer = FileWriter(record_path, shard_num=4)

    # One dict per sample; labels and features are paired positionally.
    samples = [
        {"id": index, "label": int(label), "feature": feature.reshape(-1)}
        for index, (label, feature) in enumerate(zip(labels, features))
    ]

    writer.add_schema(nlp_schema, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(samples)
    writer.commit()
-
-
def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
    """
    Parse the IMDB dataset and store both splits as MindRecord datasets.

    Args:
        embed_size: Embedding size handed to ImdbParser.
        aclimdb_path: Path of the aclImdb data handed to ImdbParser.
        preprocess_path: Output directory; created if it does not exist.
        glove_path: Path of the GloVe data handed to ImdbParser.
    """
    imdb_parser = ImdbParser(aclimdb_path, glove_path, embed_size)
    imdb_parser.parse()

    # Create the output directory on first use and tell the user about it.
    output_missing = not os.path.exists(preprocess_path)
    if output_missing:
        print(f"preprocess path {preprocess_path} is not exist")
        os.makedirs(preprocess_path)

    # Train split: the third value from get_datas is also passed on as weight_np.
    features, labels, weights = imdb_parser.get_datas('train')
    _convert_to_mindrecord(preprocess_path, features, labels, weight_np=weights, training=True)

    # Test split: the third value from get_datas is discarded.
    features, labels, _ = imdb_parser.get_datas('test')
    _convert_to_mindrecord(preprocess_path, features, labels, training=False)
|