// Package cmd holds the root cobra command and the shared state (API client
// and TLS material) that it prepares before a command runs.
package cmd

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"os"
	"path/filepath"

	"github.com/spf13/cobra"
	"gitlink.org.cn/cloudream/jcs-pub/client/sdk/api"
	cliapi "gitlink.org.cn/cloudream/jcs-pub/client/sdk/api/v1"
)

// Default file names looked up by searchCertDir when --ca, --cert and --key
// are not given on the command line.
const (
	defaultCAFileName   = "ca_cert.pem"
	defaultCertFileName = "client_cert.pem"
	defaultKeyFileName  = "client_key.pem"
)

// RootCmd is the root command that RootExecute configures and runs.
var RootCmd = cobra.Command{}

// CommandContext carries the state prepared by the root command's
// PersistentPreRunE hook: the API client and the TLS material it was built
// from.
type CommandContext struct {
	Client *cliapi.Client
	RootCA *x509.CertPool
	Cert   tls.Certificate
}

// GetCmdCtx returns the *CommandContext stored in cmd's context under the
// "cmdCtx" key. It panics if the context carries no such value, which is the
// case when the command was not started through RootExecute.
func GetCmdCtx(cmd *cobra.Command) *CommandContext {
	return cmd.Context().Value("cmdCtx").(*CommandContext)
}

// Targets of the persistent flags registered in RootExecute.
var (
	caPath   string
	certPath string
	keyPath  string
	endpoint string
)

// RootExecute registers the persistent flags on RootCmd, installs the
// PersistentPreRunE hook that loads the TLS material and builds the API
// client, and then runs the command.
//
// If command execution returns an error, the process exits with status 1 so
// that shells and scripts can detect the failure.
func RootExecute() {
	flags := RootCmd.PersistentFlags()
	flags.StringVar(&caPath, "ca", "", "CA certificate file path")
	flags.StringVar(&certPath, "cert", "", "client certificate file path")
	flags.StringVar(&keyPath, "key", "", "client key file path")
	flags.StringVar(&endpoint, "endpoint", "", "API endpoint")
	RootCmd.MarkFlagsRequiredTogether("ca", "cert", "key")

	RootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
		ctx := GetCmdCtx(cmd)

		// The three flags are required together, so an empty --ca means none
		// of them was given: fall back to the default file names in a
		// discovered directory.
		if caPath == "" {
			certDir := searchCertDir()
			if certDir == "" {
				return fmt.Errorf("cert files not found, please specify --ca, --cert and --key")
			}
			caPath = filepath.Join(certDir, defaultCAFileName)
			certPath = filepath.Join(certDir, defaultCertFileName)
			keyPath = filepath.Join(certDir, defaultKeyFileName)
		}

		rootCAPem, err := os.ReadFile(caPath)
		if err != nil {
			return fmt.Errorf("reading CA file: %w", err)
		}
		rootCAPool := x509.NewCertPool()
		if !rootCAPool.AppendCertsFromPEM(rootCAPem) {
			return fmt.Errorf("parsing CA failed")
		}

		clientCert, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			return fmt.Errorf("loading client cert/key: %w", err)
		}

		if endpoint == "" {
			endpoint = "https://127.0.0.1:7890"
		}

		ctx.Cert = clientCert
		ctx.RootCA = rootCAPool
		ctx.Client = cliapi.NewClient(api.Config{
			EndPoint: endpoint,
			RootCA:   rootCAPool,
			Cert:     clientCert,
		})
		return nil
	}

	rootCtx := context.WithValue(context.Background(), "cmdCtx", &CommandContext{})
	if err := RootCmd.ExecuteContext(rootCtx); err != nil {
		// The error is not printed here; cobra is expected to have reported it
		// already (its default unless SilenceErrors is set). Only the failure
		// status is propagated.
		os.Exit(1)
	}
}

// searchCertDir returns the first directory that contains a non-directory
// entry named defaultCAFileName, trying the directory of the running
// executable first and the current working directory second. It returns ""
// when neither qualifies.
//
// Only the CA file is probed; a missing client certificate or key surfaces
// later, when the key pair is loaded.
func searchCertDir() string {
	// Candidates are evaluated lazily and in priority order, so the working
	// directory is only queried when the executable directory does not match.
	candidates := []func() (string, error){
		func() (string, error) {
			execPath, err := os.Executable()
			if err != nil {
				return "", err
			}
			return filepath.Dir(execPath), nil
		},
		os.Getwd,
	}

	for _, candidate := range candidates {
		dir, err := candidate()
		if err != nil {
			continue
		}
		info, err := os.Stat(filepath.Join(dir, defaultCAFileName))
		if err == nil && !info.IsDir() {
			return dir
		}
	}
	return ""
}