|
- """
- /**
- * Copyright 2020 Tianshu AI Platform. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================
- */
- """
- from entrance.executor import classify_by_textcnn as classify
-
-
def _init():
    """Build the TextCNN classifier and store it in the global ``classify_obj``.

    ``_classification`` reads ``classify_obj``, which is only assigned here,
    so this must be called before any classification request.
    """
    global classify_obj
    print('init classify_obj')
    classifier = classify.TextCNNClassifier()  # label_log
    classify_obj = classifier
-
-
def _classification(text_path_list, id_list, label_list):
    """Perform automatic text classification task.

    The inputs are padded up to a whole multiple of ``classify.BATCH_SIZE``
    by repeating their first element, then passed to
    ``classify_obj.inference``; results beyond the real inputs are dropped.
    ``_init()`` must have been called first so ``classify_obj`` exists.

    Args:
        text_path_list: Paths of the texts to classify. Not modified.
        id_list: Ids paired element-wise with ``text_path_list``. Not modified.
        label_list: Passed through to ``inference`` unchanged.

    Returns:
        The first ``len(text_path_list)`` entries of the inference result
        (presumably one annotation per input, in order — confirm against
        ``TextCNNClassifier.inference``). Empty input yields ``[]``.
    """
    textnum = len(text_path_list)
    if textnum == 0:
        # Nothing to classify, and no first element to pad with.
        return []
    # Filler items needed to reach the next multiple of BATCH_SIZE
    # (0 when textnum is already a multiple).
    pad = -textnum % classify.BATCH_SIZE
    # Build padded copies rather than appending to the caller's lists,
    # so callers never see the padding entries.
    padded_paths = list(text_path_list) + [text_path_list[0]] * pad
    padded_ids = list(id_list) + [id_list[0]] * pad
    annotations = classify_obj.inference(padded_paths, padded_ids, label_list)
    return annotations[:textnum]
-
-
if __name__ == "__main__":
    # Manual smoke test: classify the same sample file several times
    # against two candidate labels and print the result and its length.
    sample_path = "dubhe-dev/dataset/2738/origin/32_3_ts1607326726114630.txt"
    test_len = 22
    _init()
    result = _classification(
        [sample_path] * test_len,
        [1] * test_len,
        [111, 112],
    )
    print(result)
    print(len(result))
|