|
|
@@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger): |
|
|
|
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
self.model = model |
|
|
|
self.marker = marker |
|
|
|
self.driver_name = driver |
|
|
|
if isinstance(driver, str): |
|
|
|
self.driver_name = driver |
|
|
|
else: |
|
|
|
self.driver_name = driver.__class__.__name__ |
|
|
|
self.device = device |
|
|
|
self.fp16 = fp16 |
|
|
|
self.input_mapping = input_mapping |
|
|
@@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger): |
|
|
|
elif accumulation_steps < 0: |
|
|
|
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.") |
|
|
|
self.accumulation_steps = accumulation_steps |
|
|
|
|
|
|
|
# todo 思路大概是,每个driver提供一下自己的参数是啥(需要对应回初始化的那个),然后trainer/evalutor在初始化的时候,就检测一下自己手上的参数和driver的是不是一致的,不一致的地方需要warn用户说这些值driver不太一样。感觉可以留到后面做吧 |
|
|
|
self.driver = choose_driver( |
|
|
|
model=model, |
|
|
|
driver=driver, |
|
|
@@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger): |
|
|
|
|
|
|
|
def wrapper(fn: Callable) -> Callable: |
|
|
|
cls._custom_callbacks[marker].append((event, fn)) |
|
|
|
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \ |
|
|
|
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \ |
|
|
|
"`Callback` with the same callback time." |
|
|
|
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] |
|
|
|
assert check_fn_not_empty_params(fn, len(callback_fn_args)), \ |
|
|
|
f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\ |
|
|
|
f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}." |
|
|
|
return fn |
|
|
|
|
|
|
|
return wrapper |
|
|
@@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger): |
|
|
|
def data_device(self): |
|
|
|
return self.driver.data_device |
|
|
|
|
|
|
|
@property |
|
|
|
def model(self): |
|
|
|
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; |
|
|
|
return self.driver.model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|