diff --git a/client/internal/cmdline/init.go b/client/internal/cmdline/init.go new file mode 100644 index 0000000..3ad3772 --- /dev/null +++ b/client/internal/cmdline/init.go @@ -0,0 +1,366 @@ +package cmdline + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/chzyer/readline" + "github.com/spf13/cobra" + "gitlink.org.cn/cloudream/common/pkgs/logger" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/accesstoken" + clicfg "gitlink.org.cn/cloudream/jcs-pub/client/internal/config" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/db" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/downloader" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/http" + mntcfg "gitlink.org.cn/cloudream/jcs-pub/client/internal/mount/config" + "gitlink.org.cn/cloudream/jcs-pub/client/internal/ticktock" + corrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/coordinator" + hubrpc "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/rpc/hub" + "gitlink.org.cn/cloudream/jcs-pub/common/pkgs/sysevent" +) + +var confs_file = path.Join("confs", "client.config.json") +var cli_certs_path = path.Join("confs", "cli_certs") +var pub_certs_path = path.Join("confs", "pub_certs") +var ca_key_file = "ca_key.pem" +var ca_cert_file = "ca_cert.pem" +var client_cert_file = "client_cert.pem" + +func init() { + cmd := cobra.Command{ + Use: "init", + Short: "initialize client configuration", + Run: func(c *cobra.Command, args []string) { + init2() + }, + } + RootCmd.AddCommand(&cmd) +} + +func getPath() (exePath string, dirPath string, err error) { + exePath, err = os.Executable() + if err != nil { + return "", "", fmt.Errorf("获取执行路径失败: %w", err) + } + dirPath = filepath.Dir(exePath) + return exePath, dirPath, nil +} + +func init2() error { + rl, err := readline.New("> ") + if err != nil { + fmt.Printf("初始化命令行失败: %v\n", err) + return err + } + defer rl.Close() + + // 1. 检查配置文件是否存在 + _, dirPath, err := getPath() + if err != nil { + return err + } + + configFilePath := filepath.Join(dirPath, confs_file) + cliCertsPath := filepath.Join(dirPath, cli_certs_path) + pubCertsPath := filepath.Join(dirPath, pub_certs_path) + + _, err = os.Stat(configFilePath) + if err == nil { + fmt.Println("\033[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m") + fmt.Println("\033[33m⚠ 配置文件已存在!重新初始化会覆盖原配置文件\033[0m") + fmt.Println("\033[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m") + checkLoop: + for { + // 2. 配置文件存在,询问是否覆盖 + rl.SetPrompt("\033[36m是否继续?(y/n): \033[0m") + overwrite, err := rl.Readline() + if err != nil { + return err + } + + switch strings.ToLower(strings.TrimSpace(overwrite)) { + case "y", "yes": + break checkLoop + + case "n", "no": + fmt.Printf("\033[32m已保留原有配置文件,退出初始化\033[0m\n") + return nil + + default: + fmt.Println("\033[31m无效输入!请输入 y/n 或 yes/no \033[0m") + } + } + } + + var cfg clicfg.Config + +mysqlLoop: + for { + // 3. 询问数据库配置 + rl.SetPrompt("\033[36m请输入Mysql连接地址(例如127.0.0.1:3306): \033[0m") + dbAddress, err := rl.Readline() + if err != nil { + return err + } + + rl.SetPrompt("\033[36m请输入Mysql用户名(例如root): \033[0m") + dbAccount, err := rl.Readline() + if err != nil { + return err + } + + dbPasswordBytes, err := rl.ReadPassword("\033[36m请输入Mysql密码(例如123456): \033[0m") + if err != nil { + return err + } + dbPassword := string(dbPasswordBytes) + + rl.SetPrompt("\033[36m请输入Mysql数据库名称(例如cloudream): \033[0m") + dbName, err := rl.Readline() + if err != nil { + return err + } + + cfg.DB = db.Config{ + Address: dbAddress, + Account: dbAccount, + Password: dbPassword, + DatabaseName: dbName, + } + + rl.SetPrompt("\033[36m是否测试数据库连接?(y/n): \033[0m") + needTest, err := rl.Readline() + if err != nil { + return err + } + + switch strings.ToLower(strings.TrimSpace(needTest)) { + case "y", "yes": + // 4. 测试数据库连接 + fmt.Printf("\033[33m正在测试数据库连接...\033[0m\n") + testDB, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", + dbAccount, + dbPassword, + dbAddress, + dbName)) + if err != nil { + fmt.Printf("\033[31m连接创建失败: %v\033[0m\n", err) + testDB.Close() + } else { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = testDB.PingContext(ctx) + if err != nil { + fmt.Printf("\033[31m数据库连接测试失败: %v\033[0m\n", err) + testDB.Close() + } else { + fmt.Printf("\033[32m✓ 数据库连接测试成功!\033[0m\n\n") + testDB.Close() + break mysqlLoop + } + } + + rl.SetPrompt("\033[36m连接测试失败,是否重新输入配置?(y/n): \033[0m") + retry, err := rl.Readline() + if err != nil { + return err + } + + if strings.ToLower(retry) != "y" { + fmt.Printf("\033[33m警告: 数据库连接尚未验证,将使用当前配置,如需后续更改,请查看配置文件client.config.json\033[0m\n") + break mysqlLoop + } + + case "n", "no": + fmt.Printf("\033[33m警告: 数据库连接未测试,将直接使用当前配置,如需后续更改,请查看配置文件client.config.json\033[0m\n") + break mysqlLoop + + default: + fmt.Println("\033[31m无效输入!请输入 y/n 或 yes/no \033[0m") + } + } + + // 5. 询问证书配置 + rl.SetPrompt("\033[36m请输入HTTP API监听地址(例如127.0.0.1:3306): \033[0m") + listen, err := rl.Readline() + if err != nil { + return err + } + + // 6. 生成证书 + err = os.MkdirAll(cliCertsPath, 0755) + if err != nil { + return err + } + + keyFilePath := filepath.Join(cliCertsPath, ca_key_file) + certFilePath := filepath.Join(cliCertsPath, ca_cert_file) + certRoot(cliCertsPath) + certServer(certFilePath, keyFilePath, cliCertsPath) + certClient(certFilePath, keyFilePath, cliCertsPath) + + cfg.HTTP = &http.ConfigJSON{ + Enabled: true, + Listen: listen, + UserSpaceID: 0, + RootCA: path.Join(cli_certs_path, "ca_cert.pem"), + ServerCert: path.Join(cli_certs_path, "server_cert.pem"), + ServerKey: path.Join(cli_certs_path, "server_key.pem"), + ClientCerts: []string{path.Join(cli_certs_path, "client_cert.pem")}, + MaxBodySize: 5242880, + } + +cloudLoop: + for { + // 7. 填写云际基础设施配置 + rl.SetPrompt("\033[36m是否连接云际存储基础设施?(y/n): \033[0m") + isConnect, err := rl.Readline() + if err != nil { + return err + } + + switch strings.ToLower(strings.TrimSpace(isConnect)) { + case "y", "yes": + rl.SetPrompt("\033[36m请输入Coordinator地址: \033[0m") + coorAddress, err := rl.Readline() + if err != nil { + return err + } + + cfg.CoordinatorRPC = corrpc.PoolConfigJSON{ + Address: coorAddress, + RootCA: path.Join(".", pub_certs_path, ca_cert_file), + ClientCert: path.Join(".", pub_certs_path, client_cert_file), + ClientKey: path.Join(".", pub_certs_path, client_cert_file), + } + + cfg.HubRPC = hubrpc.PoolConfigJSON{ + RootCA: path.Join(".", pub_certs_path, ca_cert_file), + ClientCert: path.Join(".", pub_certs_path, client_cert_file), + ClientKey: path.Join(".", pub_certs_path, client_cert_file), + } + + rl.SetPrompt("\033[36m请输入账户名(Account): \033[0m") + account, err := rl.Readline() + if err != nil { + return err + } + + passwordBytes, err := rl.ReadPassword("\033[36m请输入密码(Password): \033[0m") + if err != nil { + return err + } + password := string(passwordBytes) + + cfg.AccessToken = &accesstoken.Config{ + Account: account, + Password: password, + } + + fmt.Printf("\033[33m注意:请将JCS-pub证书文件放置于 %s 目录下\033[0m\n", pubCertsPath) + err = os.MkdirAll(pubCertsPath, 0755) + if err != nil { + return err + } + break cloudLoop + + case "n", "no": + cfg.CoordinatorRPC.Address = "127.0.0.1:5009" + cfg.AccessToken = &accesstoken.Config{} + break cloudLoop + + default: + fmt.Println("\033[31m无效输入!请输入 y/n 或 yes/no \033[0m") + } + } + + // 8. 生成配置文件 + err = saveConfig(&cfg, configFilePath) + if err != nil { + fmt.Printf("\033[31m保存配置文件失败: %v\033[0m\n", err) + return err + } + fmt.Printf("\033[32m配置文件已生成: %s\033[0m\n", configFilePath) + +dbLoop: + for { + // 9. 询问是否生成库表结构 + rl.SetPrompt("\033[36m是否生成数据库表?(y/n): \033[0m") + isCreate, err := rl.Readline() + if err != nil { + return err + } + + switch strings.ToLower(strings.TrimSpace(isCreate)) { + case "y", "yes": + // 10. 创建库表结构 + migrate(configFilePath) + break dbLoop + + case "n", "no": + fmt.Println("\033[33m请自行创建数据库表,如需更改配置,请查看配置文件client.config.json \033[0m") + break dbLoop + + default: + fmt.Println("\033[31m无效输入!请输入 y/n 或 yes/no \033[0m") + } + } + return nil +} + +func saveConfig(cfg *clicfg.Config, configPath string) error { + cfg.Logger = logger.Config{ + Level: "info", + Output: "file", + OutputFileName: "client.log", + OutputDirectory: path.Join(".", "logs"), + } + cfg.SysEvent = sysevent.Config{ + Enabled: false, + Address: "127.0.0.1:5672", + Account: "cloudream", + Password: "123456", + VHost: "/", + Exchange: "SysEvent", + Queue: "SysEvent", + } + cfg.Connectivity.TestInterval = 300 + cfg.Downloader = downloader.Config{ + MaxStripCacheCount: 100, + ECStripPrefetchCount: 1, + } + cfg.DownloadStrategy.HighLatencyHubMs = 35 + cfg.TickTock = ticktock.Config{ + ECFileSizeThreshold: 5242880, + AccessStatHistoryWeight: 0.8, + } + cfg.Mount = &mntcfg.Config{ + Enabled: false, + AttrTimeout: time.Second * 10, + UploadPendingTime: time.Second * 30, + CacheActiveTime: time.Minute * 1, + CacheExpireTime: time.Minute * 1, + ScanDataDirInterval: time.Minute * 10, + } + + configData, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("序列化配置失败: %w", err) + } + + configData = append(configData, '\n') + + err = os.WriteFile(configPath, configData, 0644) + if err != nil { + return fmt.Errorf("写入配置文件失败: %w", err) + } + return nil +}