| @@ -31,63 +31,65 @@ func GetCmdCtx(cmd *cobra.Command) *CommandContext { | |||
| return cmd.Context().Value("cmdCtx").(*CommandContext) | |||
| } | |||
| var caPath string | |||
| var certPath string | |||
| var keyPath string | |||
| var endpoint string | |||
| func RootExecute() { | |||
| var ca string | |||
| var cert string | |||
| var key string | |||
| var endpoint string | |||
| RootCmd.PersistentFlags().StringVar(&ca, "ca", "", "CA certificate file path") | |||
| RootCmd.PersistentFlags().StringVar(&cert, "cert", "", "client certificate file path") | |||
| RootCmd.PersistentFlags().StringVar(&key, "key", "", "client key file path") | |||
| RootCmd.PersistentFlags().StringVar(&caPath, "ca", "", "CA certificate file path") | |||
| RootCmd.PersistentFlags().StringVar(&certPath, "cert", "", "client certificate file path") | |||
| RootCmd.PersistentFlags().StringVar(&keyPath, "key", "", "client key file path") | |||
| RootCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "API endpoint") | |||
| RootCmd.MarkFlagsRequiredTogether("ca", "cert", "key") | |||
| if ca == "" { | |||
| certDir := searchCertDir() | |||
| if certDir == "" { | |||
| fmt.Printf("cert files not found, please specify --ca, --cert and --key\n") | |||
| os.Exit(1) | |||
| RootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { | |||
| ctx := GetCmdCtx(cmd) | |||
| if caPath == "" { | |||
| certDir := searchCertDir() | |||
| if certDir == "" { | |||
| return fmt.Errorf("cert files not found, please specify --ca, --cert and --key") | |||
| } | |||
| caPath = filepath.Join(certDir, defaultCAFileName) | |||
| certPath = filepath.Join(certDir, defaultCertFileName) | |||
| keyPath = filepath.Join(certDir, defaultKeyFileName) | |||
| } | |||
| ca = filepath.Join(certDir, defaultCAFileName) | |||
| cert = filepath.Join(certDir, defaultCertFileName) | |||
| key = filepath.Join(certDir, defaultKeyFileName) | |||
| } | |||
| rootCAPool := x509.NewCertPool() | |||
| rootCAPem, err := os.ReadFile(caPath) | |||
| if err != nil { | |||
| return fmt.Errorf("reading CA file: %v", err) | |||
| } | |||
| rootCAPool := x509.NewCertPool() | |||
| rootCAPem, err := os.ReadFile(ca) | |||
| if err != nil { | |||
| fmt.Printf("reading CA file: %v\n", err) | |||
| os.Exit(1) | |||
| } | |||
| if !rootCAPool.AppendCertsFromPEM(rootCAPem) { | |||
| return fmt.Errorf("parsing CA failed") | |||
| } | |||
| if !rootCAPool.AppendCertsFromPEM(rootCAPem) { | |||
| fmt.Printf("parsing CA failed") | |||
| os.Exit(1) | |||
| } | |||
| clientCert, err := tls.LoadX509KeyPair(certPath, keyPath) | |||
| if err != nil { | |||
| return fmt.Errorf("loading client cert/key: %v", err) | |||
| } | |||
| clientCert, err := tls.LoadX509KeyPair(cert, key) | |||
| if err != nil { | |||
| fmt.Printf("loading client cert/key: %v\n", err) | |||
| os.Exit(1) | |||
| } | |||
| if endpoint == "" { | |||
| endpoint = "https://127.0.0.1:7890" | |||
| } | |||
| cli := cliapi.NewClient(api.Config{ | |||
| EndPoint: endpoint, | |||
| RootCA: rootCAPool, | |||
| Cert: clientCert, | |||
| }) | |||
| if endpoint == "" { | |||
| endpoint = "https://127.0.0.1:7890" | |||
| ctx.Cert = clientCert | |||
| ctx.RootCA = rootCAPool | |||
| ctx.Client = cli | |||
| return nil | |||
| } | |||
| cli := cliapi.NewClient(api.Config{ | |||
| EndPoint: endpoint, | |||
| RootCA: rootCAPool, | |||
| Cert: clientCert, | |||
| }) | |||
| RootCmd.ExecuteContext(context.WithValue(context.Background(), "cmdCtx", &CommandContext{ | |||
| Client: cli, | |||
| RootCA: rootCAPool, | |||
| Cert: clientCert, | |||
| })) | |||
| RootCmd.ExecuteContext(context.WithValue(context.Background(), "cmdCtx", &CommandContext{})) | |||
| } | |||
| func searchCertDir() string { | |||