From dfdbc869fd1ba0c6c4aecfbddaa91b11459982da Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Thu, 27 Feb 2025 17:21:39 +0800 Subject: [PATCH] fix:add flavor_id in modelarts --- internal/storeLink/modelarts.go | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/internal/storeLink/modelarts.go b/internal/storeLink/modelarts.go index 349fe603..d1379ea7 100644 --- a/internal/storeLink/modelarts.go +++ b/internal/storeLink/modelarts.go @@ -30,6 +30,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-modelarts/client/modelartsservice" "gitlink.org.cn/JointCloud/pcm-modelarts/modelarts" modelartsclient "gitlink.org.cn/JointCloud/pcm-modelarts/modelarts" + "k8s.io/apimachinery/pkg/util/json" "log" "mime/multipart" "regexp" @@ -155,6 +156,8 @@ func (m *ModelArtsLink) SubmitTask(ctx context.Context, imageId string, cmd stri // modelArts提交任务 environments := make(map[string]string) parameters := make([]*modelarts.ParametersTrainJob, 0) + /* inputs := make([]*modelarts.InputTraining, 0) + outputs := make([]*modelarts.OutputTraining, 0)*/ for _, env := range envs { s := strings.Split(env, COMMA) environments[s[0]] = s[1] @@ -166,6 +169,22 @@ func (m *ModelArtsLink) SubmitTask(ctx context.Context, imageId string, cmd stri Value: s[1], }) } + /* inputs = append(inputs, &modelarts.InputTraining{ + Name: "data_url", + Remote: &modelarts.RemoteTra{ + Obs: &modelarts.Obs1{ + ObsUrl: "/test-wq/data/mnist.npz", + }, + }}) + + outputs = append(outputs, &modelarts.OutputTraining{ + Name: "train_url", + Remote: &modelarts.RemoteOut{ + Obs: &modelarts.ObsTra{ + ObsUrl: "/test-wq/model/", + }, + }, + })*/ req := &modelarts.CreateTrainingJobReq{ Kind: "job", Metadata: &modelarts.MetadataS{ @@ -180,15 +199,22 @@ func (m *ModelArtsLink) SubmitTask(ctx context.Context, imageId string, cmd stri Command: cmd, Environments: environments, Parameters: parameters, + //Inputs: inputs, + //Outputs: outputs, }, Spec: &modelarts.SpecsC{ Resource: &modelarts.ResourceCreateTraining{ - FlavorId: resourceId, + FlavorId: "modelarts.kat1.xlarge", NodeCount: 1, }, }, Platform: m.platform, } + marshal, err2 := json.Marshal(req) + if err2 != nil { + + } + println(string(marshal)) resp, err := m.modelArtsRpc.CreateTrainingJob(ctx, req) if err != nil { return nil, err