diff --git a/jcsctl/cmd/cmd.go b/jcsctl/cmd/cmd.go index 29b0c9c..a164ef9 100644 --- a/jcsctl/cmd/cmd.go +++ b/jcsctl/cmd/cmd.go @@ -31,63 +31,65 @@ func GetCmdCtx(cmd *cobra.Command) *CommandContext { return cmd.Context().Value("cmdCtx").(*CommandContext) } +var caPath string +var certPath string +var keyPath string +var endpoint string + func RootExecute() { - var ca string - var cert string - var key string - var endpoint string - - RootCmd.PersistentFlags().StringVar(&ca, "ca", "", "CA certificate file path") - RootCmd.PersistentFlags().StringVar(&cert, "cert", "", "client certificate file path") - RootCmd.PersistentFlags().StringVar(&key, "key", "", "client key file path") + + RootCmd.PersistentFlags().StringVar(&caPath, "ca", "", "CA certificate file path") + RootCmd.PersistentFlags().StringVar(&certPath, "cert", "", "client certificate file path") + RootCmd.PersistentFlags().StringVar(&keyPath, "key", "", "client key file path") RootCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "API endpoint") RootCmd.MarkFlagsRequiredTogether("ca", "cert", "key") - if ca == "" { - certDir := searchCertDir() - if certDir == "" { - fmt.Printf("cert files not found, please specify --ca, --cert and --key\n") - os.Exit(1) + RootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + ctx := GetCmdCtx(cmd) + + if caPath == "" { + certDir := searchCertDir() + if certDir == "" { + return fmt.Errorf("cert files not found, please specify --ca, --cert and --key") + } + + caPath = filepath.Join(certDir, defaultCAFileName) + certPath = filepath.Join(certDir, defaultCertFileName) + keyPath = filepath.Join(certDir, defaultKeyFileName) } - ca = filepath.Join(certDir, defaultCAFileName) - cert = filepath.Join(certDir, defaultCertFileName) - key = filepath.Join(certDir, defaultKeyFileName) - } + rootCAPool := x509.NewCertPool() + rootCAPem, err := os.ReadFile(caPath) + if err != nil { + return fmt.Errorf("reading CA file: %v", err) + } - rootCAPool := x509.NewCertPool() - rootCAPem, err := os.ReadFile(ca) - if err != nil { - fmt.Printf("reading CA file: %v\n", err) - os.Exit(1) - } + if !rootCAPool.AppendCertsFromPEM(rootCAPem) { + return fmt.Errorf("parsing CA failed") + } - if !rootCAPool.AppendCertsFromPEM(rootCAPem) { - fmt.Printf("parsing CA failed") - os.Exit(1) - } + clientCert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return fmt.Errorf("loading client cert/key: %v", err) + } - clientCert, err := tls.LoadX509KeyPair(cert, key) - if err != nil { - fmt.Printf("loading client cert/key: %v\n", err) - os.Exit(1) - } + if endpoint == "" { + endpoint = "https://127.0.0.1:7890" + } + + cli := cliapi.NewClient(api.Config{ + EndPoint: endpoint, + RootCA: rootCAPool, + Cert: clientCert, + }) - if endpoint == "" { - endpoint = "https://127.0.0.1:7890" + ctx.Cert = clientCert + ctx.RootCA = rootCAPool + ctx.Client = cli + return nil } - cli := cliapi.NewClient(api.Config{ - EndPoint: endpoint, - RootCA: rootCAPool, - Cert: clientCert, - }) - - RootCmd.ExecuteContext(context.WithValue(context.Background(), "cmdCtx", &CommandContext{ - Client: cli, - RootCA: rootCAPool, - Cert: clientCert, - })) + RootCmd.ExecuteContext(context.WithValue(context.Background(), "cmdCtx", &CommandContext{})) } func searchCertDir() string {