|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- package cmd
-
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "os"
- "path/filepath"
-
- "github.com/spf13/cobra"
- "gitlink.org.cn/cloudream/jcs-pub/client/sdk/api"
- cliapi "gitlink.org.cn/cloudream/jcs-pub/client/sdk/api/v1"
- )
-
// File names of the TLS material the CLI falls back to when --ca, --cert and
// --key are not supplied. They are joined with the directory returned by
// searchCertDir.
const (
	// defaultCAFileName is the PEM file whose certificates are loaded into the
	// root CA pool handed to the API client. searchCertDir probes for it.
	defaultCAFileName = "ca_cert.pem"

	// defaultCertFileName is the client certificate loaded together with
	// defaultKeyFileName via tls.LoadX509KeyPair.
	defaultCertFileName = "client_cert.pem"

	// defaultKeyFileName is the private key paired with defaultCertFileName.
	defaultKeyFileName = "client_key.pem"
)
-
- var RootCmd = cobra.Command{}
-
- type CommandContext struct {
- Client *cliapi.Client
- RootCA *x509.CertPool
- Cert tls.Certificate
- }
-
- func GetCmdCtx(cmd *cobra.Command) *CommandContext {
- return cmd.Context().Value("cmdCtx").(*CommandContext)
- }
-
// Values of the persistent root flags, bound in RootExecute. The root pre-run
// hook replaces empty values with defaults before building the API client.
var (
	// caPath is the --ca flag: path to the CA certificate PEM file.
	caPath string

	// certPath is the --cert flag: path to the client certificate file.
	certPath string

	// keyPath is the --key flag: path to the client private key file.
	keyPath string

	// endpoint is the --endpoint flag: the API endpoint URL.
	endpoint string
)
-
- func RootExecute() {
-
- RootCmd.PersistentFlags().StringVar(&caPath, "ca", "", "CA certificate file path")
- RootCmd.PersistentFlags().StringVar(&certPath, "cert", "", "client certificate file path")
- RootCmd.PersistentFlags().StringVar(&keyPath, "key", "", "client key file path")
- RootCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "API endpoint")
- RootCmd.MarkFlagsRequiredTogether("ca", "cert", "key")
-
- RootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
- ctx := GetCmdCtx(cmd)
-
- if caPath == "" {
- certDir := searchCertDir()
- if certDir == "" {
- return fmt.Errorf("cert files not found, please specify --ca, --cert and --key")
- }
-
- caPath = filepath.Join(certDir, defaultCAFileName)
- certPath = filepath.Join(certDir, defaultCertFileName)
- keyPath = filepath.Join(certDir, defaultKeyFileName)
- }
-
- rootCAPool := x509.NewCertPool()
- rootCAPem, err := os.ReadFile(caPath)
- if err != nil {
- return fmt.Errorf("reading CA file: %v", err)
- }
-
- if !rootCAPool.AppendCertsFromPEM(rootCAPem) {
- return fmt.Errorf("parsing CA failed")
- }
-
- clientCert, err := tls.LoadX509KeyPair(certPath, keyPath)
- if err != nil {
- return fmt.Errorf("loading client cert/key: %v", err)
- }
-
- if endpoint == "" {
- endpoint = "https://127.0.0.1:7890"
- }
-
- cli := cliapi.NewClient(api.Config{
- EndPoint: endpoint,
- RootCA: rootCAPool,
- Cert: clientCert,
- })
-
- ctx.Cert = clientCert
- ctx.RootCA = rootCAPool
- ctx.Client = cli
- return nil
- }
-
- RootCmd.ExecuteContext(context.WithValue(context.Background(), "cmdCtx", &CommandContext{}))
- }
-
- func searchCertDir() string {
- execPath, err := os.Executable()
- if err == nil {
- execDir := filepath.Dir(execPath)
- ca, err := os.Stat(filepath.Join(execDir, defaultCAFileName))
-
- if err == nil && !ca.IsDir() {
- return execDir
- }
- }
-
- workDir, err := os.Getwd()
- if err == nil {
- ca, err := os.Stat(filepath.Join(workDir, defaultCAFileName))
-
- if err == nil && !ca.IsDir() {
- return workDir
- }
- }
-
- return ""
- }
|