Browse Source

Merge pull request #154 from JoeyHwong-gk/lifelong

[Lifelong example]: fix the problem from backend and constant
tags/v0.3.1
KubeEdge Bot GitHub 4 years ago
parent
commit
6ddf750653
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions
  1. +1
    -1
      lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
  2. +10
    -9
      lib/sedna/backend/base.py
  3. +2
    -2
      lib/sedna/core/lifelong_learning/lifelong_learning.py

+ 1
- 1
lib/sedna/algorithms/multi_task_learning/multi_task_learning.py View File

@@ -286,7 +286,7 @@ class MulTaskLearning:
**kwargs):
from sklearn import metrics as sk_metrics

result, tasks = self.predict(data, kwargs=kwargs)
result, tasks = self.predict(data, **kwargs)
m_dict = {}
if metrics:
if callable(metrics): # if metrics is a function


+ 10
- 9
lib/sedna/backend/base.py View File

@@ -49,28 +49,29 @@ class BackendBase:
return kwargs
return {k: v for k, v in kwargs.items() if k in need_kw.args}

def train(self, **kwargs):
def train(self, *args, **kwargs):
"""Train model."""
if callable(self.estimator):
varkw = self.parse_kwargs(self.estimator, **kwargs)
self.estimator = self.estimator(**varkw)
varkw = self.parse_kwargs(self.estimator.train, **kwargs)
return self.estimator.train(**varkw)
fit_method = getattr(self.estimator, "fit", self.estimator.train)
varkw = self.parse_kwargs(fit_method, **kwargs)
return fit_method(*args, **varkw)

def predict(self, **kwargs):
def predict(self, *args, **kwargs):
"""Inference model."""
varkw = self.parse_kwargs(self.estimator.predict, **kwargs)
return self.estimator.predict(**varkw)
return self.estimator.predict(*args, **varkw)

def predict_proba(self, **kwargs):
def predict_proba(self, *args, **kwargs):
"""Compute probabilities of possible outcomes for samples in X."""
varkw = self.parse_kwargs(self.estimator.predict_proba, **kwargs)
return self.estimator.predict_proba(**varkw)
return self.estimator.predict_proba(*args, **varkw)

def evaluate(self, **kwargs):
def evaluate(self, *args, **kwargs):
"""evaluate model."""
varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs)
return self.estimator.evaluate(**varkw)
return self.estimator.evaluate(*args, **varkw)

def save(self, model_url="", model_name=None):
mname = model_name or self.model_name


+ 2
- 2
lib/sedna/core/lifelong_learning/lifelong_learning.py View File

@@ -65,7 +65,7 @@ class LifelongLearning(JobBase):
output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
)
task_index = FileOps.join_path(config['output_url'],
KBResourceConstant.KB_INDEX_NAME)
KBResourceConstant.KB_INDEX_NAME.value)
config['task_index'] = task_index
super(LifelongLearning, self).__init__(
estimator=e, config=config
@@ -141,7 +141,7 @@ class LifelongLearning(JobBase):

save_extractor = FileOps.join_path(
self.config.output_url,
KBResourceConstant.TASK_EXTRACTOR_NAME
KBResourceConstant.TASK_EXTRACTOR_NAME.value
)
extractor = FileOps.dump(extractor, save_extractor)
try:


Loading…
Cancel
Save