Browse Source

[MNT] unify the code format

tags/v0.3.2
bxdd 1 year ago
parent
commit
a09b619d82
30 changed files with 103 additions and 99 deletions
  1. BIN
      docs/_static/img/table_homo_labeled.png
  2. +0
    -0
      examples/__init__.py
  3. +0
    -1
      examples/dataset_image_workflow/config.py
  4. +2
    -2
      examples/dataset_image_workflow/utils.py
  5. +11
    -10
      examples/dataset_image_workflow/workflow.py
  6. +7
    -6
      examples/dataset_table_workflow/base.py
  7. +0
    -1
      examples/dataset_table_workflow/config.py
  8. +6
    -6
      examples/dataset_table_workflow/hetero.py
  9. +7
    -7
      examples/dataset_table_workflow/homo.py
  10. +2
    -2
      examples/dataset_table_workflow/methods.py
  11. +1
    -1
      examples/dataset_table_workflow/train.py
  12. +5
    -5
      examples/dataset_table_workflow/utils.py
  13. +5
    -5
      examples/dataset_table_workflow/workflow.py
  14. +0
    -1
      examples/dataset_text_workflow/config.py
  15. +9
    -8
      examples/dataset_text_workflow/workflow.py
  16. +2
    -1
      learnware/tests/benchmarks/__init__.py
  17. +1
    -1
      learnware/tests/benchmarks/config.py
  18. +1
    -0
      setup.py
  19. +4
    -4
      tests/test_function/test_search.py
  20. +3
    -3
      tests/test_learnware_client/test_all_learnware.py
  21. +1
    -1
      tests/test_learnware_client/test_check_learnware.py
  22. +1
    -0
      tests/test_learnware_client/test_container.py
  23. +1
    -0
      tests/test_learnware_client/test_load_learnware.py
  24. +2
    -2
      tests/test_learnware_client/test_upload.py
  25. +4
    -4
      tests/test_specification/test_hetero_spec.py
  26. +5
    -5
      tests/test_specification/test_image_rkme.py
  27. +4
    -4
      tests/test_specification/test_table_rkme.py
  28. +4
    -5
      tests/test_specification/test_text_rkme.py
  29. +9
    -9
      tests/test_workflow/test_hetero_workflow.py
  30. +6
    -5
      tests/test_workflow/test_workflow.py

BIN
docs/_static/img/table_homo_labeled.png View File

Before After
Width: 6927  |  Height: 4112  |  Size: 730 kB

+ 0
- 0
examples/__init__.py View File


+ 0
- 1
examples/dataset_image_workflow/config.py View File

@@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig


image_benchmark_config = BenchmarkConfig(
name="CIFAR-10",
user_num=100,


+ 2
- 2
examples/dataset_image_workflow/utils.py View File

@@ -1,6 +1,6 @@
import torch
import numpy as np
from torch import optim, nn
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from learnware.utils import choose_device


+ 11
- 10
examples/dataset_image_workflow/workflow.py View File

@@ -1,24 +1,25 @@
import os
import fire
import time
import torch
import pickle
import random
import tempfile
import numpy as np
import time

import fire
import matplotlib.pyplot as plt
import numpy as np
import torch
from config import image_benchmark_config
from model import ConvModel
from torch.utils.data import TensorDataset
from utils import evaluate, train_model

from learnware.utils import choose_device
from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from model import ConvModel
from utils import train_model, evaluate
from config import image_benchmark_config
from learnware.utils import choose_device

logger = get_module_logger("image_workflow", level="INFO")



+ 7
- 6
examples/dataset_table_workflow/base.py View File

@@ -1,20 +1,21 @@
import os
import time
import random
import requests
import tempfile
import time
import traceback

import numpy as np
import requests
from config import market_mapping_params
from methods import loss_func_rmse, test_methods
from utils import set_seed

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market
from learnware.reuse.utils import fill_data_with_mean
from learnware.tests.benchmarks import LearnwareBenchmark

from config import market_mapping_params
from methods import loss_func_rmse, test_methods
from utils import set_seed

logger = get_module_logger("base_table", level="INFO")




+ 0
- 1
examples/dataset_table_workflow/config.py View File

@@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig


homo_n_labeled_list = [100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]
homo_n_repeat_list = [10, 10, 10, 3, 3, 3, 3, 3, 3]
hetero_n_labeled_list = [10, 30, 50, 75, 100, 200, 500, 1000, 2000]


+ 6
- 6
examples/dataset_table_workflow/hetero.py View File

@@ -2,15 +2,15 @@ import os
import warnings

import numpy as np
from base import TableWorkflow
from config import align_model_params, hetero_n_labeled_list, hetero_n_repeat_list, user_semantic
from methods import loss_func_rmse
from utils import Recorder, plot_performance_curves, set_seed

from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.market import BaseUserInfo
from learnware.reuse import AveragingReuser, FeatureAlignLearnware

from methods import loss_func_rmse
from base import TableWorkflow
from config import align_model_params, user_semantic, hetero_n_labeled_list, hetero_n_repeat_list
from utils import Recorder, plot_performance_curves, set_seed
from learnware.specification import generate_stat_spec

warnings.filterwarnings("ignore")
logger = get_module_logger("hetero_test", level="INFO")


+ 7
- 7
examples/dataset_table_workflow/homo.py View File

@@ -1,17 +1,17 @@
import os
import warnings
import numpy as np

from learnware.market import BaseUserInfo
from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.reuse import AveragingReuser, JobSelectorReuser

from methods import loss_func_rmse
import numpy as np
from base import TableWorkflow
from config import homo_n_labeled_list, homo_n_repeat_list
from methods import loss_func_rmse
from utils import Recorder, plot_performance_curves

from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo
from learnware.reuse import AveragingReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec

warnings.filterwarnings("ignore")
logger = get_module_logger("homo_table", level="INFO")



+ 2
- 2
examples/dataset_table_workflow/methods.py View File

@@ -1,10 +1,10 @@
import numpy as np
from config import align_model_params
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from train import train_model

from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware
from config import align_model_params
from train import train_model


def loss_func_rmse(y_true, y_pred):


+ 1
- 1
examples/dataset_table_workflow/train.py View File

@@ -1,8 +1,8 @@
import lightgbm as lgb
from config import user_model_params
from lightgbm import early_stopping

from learnware.logger import get_module_logger
from config import user_model_params

logger = get_module_logger("train_table", level="INFO")



+ 5
- 5
examples/dataset_table_workflow/utils.py View File

@@ -1,14 +1,14 @@
import os
import json
import os
import random
from collections import defaultdict

import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import torch
from config import labels, styles

from learnware.logger import get_module_logger
from config import styles, labels

logger = get_module_logger("base_table", level="INFO")



+ 5
- 5
examples/dataset_table_workflow/workflow.py View File

@@ -1,13 +1,13 @@
import fire

from learnware.logger import get_module_logger
from homo import HomogeneousDatasetWorkflow
from hetero import HeterogeneousDatasetWorkflow
from config import (
homo_table_benchmark_config,
hetero_cross_feat_eng_benchmark_config,
hetero_cross_task_benchmark_config,
homo_table_benchmark_config,
)
from hetero import HeterogeneousDatasetWorkflow
from homo import HomogeneousDatasetWorkflow

from learnware.logger import get_module_logger

logger = get_module_logger("base_table", level="INFO")



+ 0
- 1
examples/dataset_text_workflow/config.py View File

@@ -1,6 +1,5 @@
from learnware.tests.benchmarks import BenchmarkConfig


text_benchmark_config = BenchmarkConfig(
name="20-Newsgroups",
user_num=10,


+ 9
- 8
examples/dataset_text_workflow/workflow.py View File

@@ -1,22 +1,23 @@
import os
import fire
import time
import random
import pickle
import random
import tempfile
import numpy as np
import time

import fire
import matplotlib.pyplot as plt
import numpy as np
from config import text_benchmark_config
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import RKMETextSpecification
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from config import text_benchmark_config

logger = get_module_logger("text_workflow", level="INFO")



+ 2
- 1
learnware/tests/benchmarks/__init__.py View File

@@ -2,10 +2,11 @@ import os
import pickle
import tempfile
import zipfile
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np

from .config import BenchmarkConfig, benchmark_configs
from ..data import GetData
from ...config import C


+ 1
- 1
learnware/tests/benchmarks/config.py View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Dict
from typing import Dict, List, Optional


@dataclass


+ 1
- 0
setup.py View File

@@ -1,4 +1,5 @@
import os

from setuptools import find_packages, setup




+ 4
- 4
tests/test_function/test_search.py View File

@@ -1,12 +1,12 @@
import logging
import os
import unittest
import tempfile
import logging
import unittest

import learnware
from learnware.learnware import Learnware
from learnware.client import LearnwareClient
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.learnware import Learnware
from learnware.market import BaseUserInfo, instantiate_learnware_market

learnware.init(logging_level=logging.WARNING)



+ 3
- 3
tests/test_learnware_client/test_all_learnware.py View File

@@ -1,11 +1,11 @@
import os
import json
import unittest
import os
import tempfile
import unittest

from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec
from learnware.market import BaseUserInfo
from learnware.specification import generate_semantic_spec


class TestAllLearnware(unittest.TestCase):


+ 1
- 1
tests/test_learnware_client/test_check_learnware.py View File

@@ -1,6 +1,6 @@
import os
import unittest
import tempfile
import unittest

from learnware.client import LearnwareClient



+ 1
- 0
tests/test_learnware_client/test_container.py View File

@@ -1,4 +1,5 @@
import unittest

import numpy as np

from learnware.client import LearnwareClient


+ 1
- 0
tests/test_learnware_client/test_load_learnware.py View File

@@ -1,5 +1,6 @@
import os
import unittest

import numpy as np

from learnware.client import LearnwareClient


+ 2
- 2
tests/test_learnware_client/test_upload.py View File

@@ -1,7 +1,7 @@
import os
import json
import unittest
import os
import tempfile
import unittest

from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec


+ 4
- 4
tests/test_specification/test_hetero_spec.py View File

@@ -1,12 +1,12 @@
import os
import json
import unittest
import os
import tempfile
import unittest

import numpy as np

from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification
from learnware.specification import generate_stat_spec
from learnware.market.heterogeneous.organizer import HeteroMap
from learnware.specification import HeteroMapTableSpecification, RKMETableSpecification, generate_stat_spec


class TestTableRKME(unittest.TestCase):


+ 5
- 5
tests/test_specification/test_image_rkme.py View File

@@ -1,12 +1,12 @@
import os
import json
import torch
import unittest
import os
import tempfile
import unittest

import numpy as np
import torch

from learnware.specification import RKMEImageSpecification
from learnware.specification import generate_stat_spec
from learnware.specification import RKMEImageSpecification, generate_stat_spec


class TestImageRKME(unittest.TestCase):


+ 4
- 4
tests/test_specification/test_table_rkme.py View File

@@ -1,11 +1,11 @@
import os
import json
import unittest
import os
import tempfile
import unittest

import numpy as np

from learnware.specification import RKMETableSpecification
from learnware.specification import generate_stat_spec
from learnware.specification import RKMETableSpecification, generate_stat_spec


class TestTableRKME(unittest.TestCase):


+ 4
- 5
tests/test_specification/test_text_rkme.py View File

@@ -1,12 +1,11 @@
import os
import json
import string
import os
import random
import unittest
import string
import tempfile
import unittest

from learnware.specification import RKMETextSpecification
from learnware.specification import generate_stat_spec
from learnware.specification import RKMETextSpecification, generate_stat_spec


class TestTextRKME(unittest.TestCase):


+ 9
- 9
tests/test_workflow/test_hetero_workflow.py View File

@@ -1,22 +1,22 @@
import torch
import pickle
import unittest
import os
import logging
import os
import pickle
import tempfile
import unittest
import zipfile
from sklearn.linear_model import Ridge

import torch
from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

import learnware
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate

from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list

learnware.init(logging_level=logging.WARNING)
curr_root = os.path.dirname(os.path.abspath(__file__))



+ 6
- 5
tests/test_workflow/test_workflow.py View File

@@ -1,18 +1,19 @@
import unittest
import os
import logging
import tempfile
import os
import pickle
import tempfile
import unittest
import zipfile

import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

import learnware
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, JobSelectorReuser
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate

learnware.init(logging_level=logging.WARNING)


Loading…
Cancel
Save